From 86dc81451a2fa8f15d35847c62fe24925c2bdc1d Mon Sep 17 00:00:00 2001 From: Gnarwhal Date: Wed, 2 Oct 2024 07:56:56 +0000 Subject: [PATCH] Added AUTH command and PLAIN sasl mechanism --- .gitignore | 2 ++ cmd/server/config.go | 9 +++++++-- cmd/server/main.go | 7 ++++--- hash.sh | 3 +++ smtp/commands.go | 34 ++++++++++++++++++++++++++++++---- smtp/server.go | 14 +++++++------- smtp/session.go | 30 ++++++++++++++++++++++-------- 7 files changed, 75 insertions(+), 24 deletions(-) create mode 100755 hash.sh diff --git a/.gitignore b/.gitignore index a309a01..133f18b 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ # ...even if they are in subdirectories !*/ + +!hash.sh diff --git a/cmd/server/config.go b/cmd/server/config.go index ecde6f6..aac1aea 100644 --- a/cmd/server/config.go +++ b/cmd/server/config.go @@ -24,12 +24,17 @@ import ( ) type Config struct { - LogLevel string - Host string + General GeneralConfig Plain PlainConfig TLS TLSConfig } +type GeneralConfig struct { + LogLevel string + Host string + PasswordHash string +} + type PlainConfig struct { Enabled bool Port string diff --git a/cmd/server/main.go b/cmd/server/main.go index 97364ea..e2de162 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,8 +62,8 @@ func main() { log.Fatal().Msgf("%v", err) } if log_level == "" { - if config.LogLevel != "" { - log_level = config.LogLevel + if config.General.LogLevel != "" { + log_level = config.General.LogLevel } else { log_level = "info" } @@ -107,7 +107,8 @@ func main() { } err = smtp.Run( - config.Host, + config.General.Host, + config.General.PasswordHash, plain_config, tls_config, ) diff --git a/hash.sh b/hash.sh new file mode 100755 index 0000000..8bd8db3 --- /dev/null +++ b/hash.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +printf "$@" | base64 | tr -d '\n' | sha256sum diff --git a/smtp/commands.go b/smtp/commands.go index 3933770..f183a6b 100644 --- a/smtp/commands.go +++ b/smtp/commands.go @@ -67,9 +67,10 @@ func Helo(smtp_session *SMTPSession, message string) (bool, error) { func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( fmt.Sprintf( - "250-%v is shy" + "\r\n" + - "250-8BITMIME" + "\r\n" + - "250 SMTPUTF8" + "\r\n", + "250-%v is shy" + "\r\n" + + "250-AUTH PLAIN" + "\r\n" + + "250-8BITMIME" + "\r\n" + + "250 SMTPUTF8" + "\r\n", smtp_session.GetHost(), ), ) @@ -80,6 +81,31 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { return false, nil } +func Auth(smtp_session *SMTPSession, message string) (bool, error) { + parts := strings.Split(message, " ")[1:] + if parts[0] == "PLAIN" { + var password string + if len(parts) > 1 { + password = parts[1][:len(parts[1]) - 1] + } else { + err := smtp_session.Write("334 \n") + if err != nil { + return true, nil + } + password, err = smtp_session.ReadUntil("\r\n") + if err != nil { + return true, nil + } + } + if smtp_session.ValidatePassword(password) { + smtp_session.Write("235 Authentication successful\n") + } else { + smtp_session.Write("535 Authentication credentials invalid\n") + } + } + return false, nil +} + func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { match := ReversePath.FindStringSubmatch(message) if match == nil { @@ -131,7 +157,7 @@ func Data(smtp_session *SMTPSession, message string) (bool, error) { if err != nil { return true, err } - data, err := smtp_session.Read("\r\n.\r\n") + data, err := smtp_session.ReadUntil("\r\n.\r\n") if err != nil { smtp_session.Write( "550 Action not taken\n", diff --git a/smtp/server.go b/smtp/server.go index 22822f5..e408033 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -36,14 +36,14 @@ type TLSConfig struct { TlsConfig tls.Config } -func handle(connection net.Conn, host string) { +func handle(connection net.Conn, host string, password_hash string) { log.Info().Msgf( "New connection %v. Starting session.", connection.RemoteAddr(), ) defer connection.Close() - session := MakeSMTPSession(connection, host) + session := MakeSMTPSession(connection, host, password_hash) err := session.Run() if err != nil { log.Error().Msgf( @@ -59,7 +59,7 @@ func handle(connection net.Conn, host string) { } } -func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error { +func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error { var wait_group sync.WaitGroup if plain_config != nil { listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port)) @@ -68,7 +68,7 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error { } log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host) wait_group.Add(1) - go Listen(wait_group, host, listener) + go Listen(wait_group, host, password_hash, listener) } if tls_config != nil { listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig) @@ -77,18 +77,18 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error { } log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host) wait_group.Add(1) - go Listen(wait_group, host, listener) + go Listen(wait_group, host, password_hash, listener) } wait_group.Wait() return nil } -func Listen(wait_group sync.WaitGroup, host string, listener net.Listener) { +func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) { defer wait_group.Done() for { connection, _ := listener.Accept() - go handle(connection, host) + go handle(connection, host, password_hash) } } diff --git a/smtp/session.go b/smtp/session.go index 2922688..b1c4313 100644 --- a/smtp/session.go +++ b/smtp/session.go @@ -24,23 +24,26 @@ import ( "fmt" "strings" "crypto/tls" + "crypto/sha256" "github.com/rs/zerolog/log" ) type SMTPSession struct { - connection net.Conn - host string - buffer [4096]byte + connection net.Conn + host string + password_hash string + buffer [4096]byte ReversePathBuffer *string ForwardPathBuffer []string } -func MakeSMTPSession(connection net.Conn, host string) SMTPSession { +func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession { return SMTPSession{ - connection: connection, - host: host, + connection: connection, + host: host, + password_hash: password_hash, } } @@ -52,6 +55,7 @@ func (self *SMTPSession) Run() error { COMMANDS := []Command{ Command{ "HELO", Helo }, Command{ "EHLO", Ehlo }, + Command{ "AUTH", Auth }, Command{ "MAIL FROM:", MailFrom }, Command{ "RCPT TO:", RcptTo }, Command{ "DATA", Data }, @@ -61,7 +65,7 @@ func (self *SMTPSession) Run() error { } quit := false for !quit { - message, err := self.Read("\n") + message, err := self.ReadUntil("\n") if err != nil { return err } @@ -102,7 +106,7 @@ func (self *SMTPSession) TraceSession(direction string, message string) { } -func (self *SMTPSession) Read(delimiter string) (string, error) { +func (self *SMTPSession) ReadUntil(delimiter string) (string, error) { var message string for !strings.Contains(message, delimiter) { num_read, err := self.connection.Read(self.buffer[:]) @@ -116,6 +120,12 @@ func (self *SMTPSession) Read(delimiter string) (string, error) { return message[:len(message) - len(delimiter)], nil } +func (self *SMTPSession) ReadCount(num_bytes int) ([]byte, error) { + bytes := make([]byte, num_bytes) + _, err := self.connection.Read(bytes[:]) + return bytes, err +} + func (self *SMTPSession) Write(message string) error { self.TraceSession("<- ", message) _, err := self.connection.Write([]byte(message)) @@ -126,6 +136,10 @@ func (self *SMTPSession) GetHost() string { return self.host } +func (self *SMTPSession) ValidatePassword(password string) bool { + return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash +} + func (self *SMTPSession) MailFrom(from string) { self.Reset() self.ReversePathBuffer = &from