Added AUTH command and PLAIN sasl mechanism

This commit is contained in:
Gnarwhal 2024-10-02 07:56:56 +00:00
parent 3523f1cf4e
commit 86dc81451a
Signed by: Gnarwhal
GPG key ID: 0989A73D8C421174
7 changed files with 75 additions and 24 deletions

2
.gitignore vendored
View file

@ -21,3 +21,5 @@
# ...even if they are in subdirectories
!*/
!hash.sh

View file

@ -24,12 +24,17 @@ import (
)
type Config struct {
LogLevel string
Host string
General GeneralConfig
Plain PlainConfig
TLS TLSConfig
}
type GeneralConfig struct {
LogLevel string
Host string
PasswordHash string
}
type PlainConfig struct {
Enabled bool
Port string

View file

@ -62,8 +62,8 @@ func main() {
log.Fatal().Msgf("%v", err)
}
if log_level == "" {
if config.LogLevel != "" {
log_level = config.LogLevel
if config.General.LogLevel != "" {
log_level = config.General.LogLevel
} else {
log_level = "info"
}
@ -107,7 +107,8 @@ func main() {
}
err = smtp.Run(
config.Host,
config.General.Host,
config.General.PasswordHash,
plain_config,
tls_config,
)

3
hash.sh Executable file
View file

@ -0,0 +1,3 @@
#!/bin/sh
printf "$@" | base64 | tr -d '\n' | sha256sum

View file

@ -68,6 +68,7 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
err := smtp_session.Write(
fmt.Sprintf(
"250-%v is shy" + "\r\n" +
"250-AUTH PLAIN" + "\r\n" +
"250-8BITMIME" + "\r\n" +
"250 SMTPUTF8" + "\r\n",
smtp_session.GetHost(),
@ -80,6 +81,31 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
return false, nil
}
func Auth(smtp_session *SMTPSession, message string) (bool, error) {
parts := strings.Split(message, " ")[1:]
if parts[0] == "PLAIN" {
var password string
if len(parts) > 1 {
password = parts[1][:len(parts[1]) - 1]
} else {
err := smtp_session.Write("334 \n")
if err != nil {
return true, nil
}
password, err = smtp_session.ReadUntil("\r\n")
if err != nil {
return true, nil
}
}
if smtp_session.ValidatePassword(password) {
smtp_session.Write("235 Authentication successful\n")
} else {
smtp_session.Write("535 Authentication credentials invalid\n")
}
}
return false, nil
}
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
match := ReversePath.FindStringSubmatch(message)
if match == nil {
@ -131,7 +157,7 @@ func Data(smtp_session *SMTPSession, message string) (bool, error) {
if err != nil {
return true, err
}
data, err := smtp_session.Read("\r\n.\r\n")
data, err := smtp_session.ReadUntil("\r\n.\r\n")
if err != nil {
smtp_session.Write(
"550 Action not taken\n",

View file

@ -36,14 +36,14 @@ type TLSConfig struct {
TlsConfig tls.Config
}
func handle(connection net.Conn, host string) {
func handle(connection net.Conn, host string, password_hash string) {
log.Info().Msgf(
"New connection %v. Starting session.",
connection.RemoteAddr(),
)
defer connection.Close()
session := MakeSMTPSession(connection, host)
session := MakeSMTPSession(connection, host, password_hash)
err := session.Run()
if err != nil {
log.Error().Msgf(
@ -59,7 +59,7 @@ func handle(connection net.Conn, host string) {
}
}
func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error {
var wait_group sync.WaitGroup
if plain_config != nil {
listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port))
@ -68,7 +68,7 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
}
log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host)
wait_group.Add(1)
go Listen(wait_group, host, listener)
go Listen(wait_group, host, password_hash, listener)
}
if tls_config != nil {
listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig)
@ -77,18 +77,18 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
}
log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host)
wait_group.Add(1)
go Listen(wait_group, host, listener)
go Listen(wait_group, host, password_hash, listener)
}
wait_group.Wait()
return nil
}
func Listen(wait_group sync.WaitGroup, host string, listener net.Listener) {
func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) {
defer wait_group.Done()
for {
connection, _ := listener.Accept()
go handle(connection, host)
go handle(connection, host, password_hash)
}
}

View file

@ -24,6 +24,7 @@ import (
"fmt"
"strings"
"crypto/tls"
"crypto/sha256"
"github.com/rs/zerolog/log"
)
@ -31,16 +32,18 @@ import (
type SMTPSession struct {
connection net.Conn
host string
password_hash string
buffer [4096]byte
ReversePathBuffer *string
ForwardPathBuffer []string
}
func MakeSMTPSession(connection net.Conn, host string) SMTPSession {
func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession {
return SMTPSession{
connection: connection,
host: host,
password_hash: password_hash,
}
}
@ -52,6 +55,7 @@ func (self *SMTPSession) Run() error {
COMMANDS := []Command{
Command{ "HELO", Helo },
Command{ "EHLO", Ehlo },
Command{ "AUTH", Auth },
Command{ "MAIL FROM:", MailFrom },
Command{ "RCPT TO:", RcptTo },
Command{ "DATA", Data },
@ -61,7 +65,7 @@ func (self *SMTPSession) Run() error {
}
quit := false
for !quit {
message, err := self.Read("\n")
message, err := self.ReadUntil("\n")
if err != nil {
return err
}
@ -102,7 +106,7 @@ func (self *SMTPSession) TraceSession(direction string, message string) {
}
func (self *SMTPSession) Read(delimiter string) (string, error) {
func (self *SMTPSession) ReadUntil(delimiter string) (string, error) {
var message string
for !strings.Contains(message, delimiter) {
num_read, err := self.connection.Read(self.buffer[:])
@ -116,6 +120,12 @@ func (self *SMTPSession) Read(delimiter string) (string, error) {
return message[:len(message) - len(delimiter)], nil
}
func (self *SMTPSession) ReadCount(num_bytes int) ([]byte, error) {
bytes := make([]byte, num_bytes)
_, err := self.connection.Read(bytes[:])
return bytes, err
}
func (self *SMTPSession) Write(message string) error {
self.TraceSession("<- ", message)
_, err := self.connection.Write([]byte(message))
@ -126,6 +136,10 @@ func (self *SMTPSession) GetHost() string {
return self.host
}
func (self *SMTPSession) ValidatePassword(password string) bool {
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
}
func (self *SMTPSession) MailFrom(from string) {
self.Reset()
self.ReversePathBuffer = &from