From 5ce65e6be9b23ba3b51e12cc4356e47447a9337d Mon Sep 17 00:00:00 2001 From: Gnarwhal Date: Fri, 4 Oct 2024 21:44:10 +0000 Subject: [PATCH] Actual AUTH and also enforce EHLO/HELO first --- cmd/server/config.go | 27 ++++++++++++----- cmd/server/main.go | 23 ++++++++------ smtp/commands.go | 72 +++++++++++++++++++++++++++++++++++++++++++- smtp/server.go | 22 +++++++++----- smtp/session.go | 15 ++++++++- 5 files changed, 133 insertions(+), 26 deletions(-) diff --git a/cmd/server/config.go b/cmd/server/config.go index 392c9e9..624948e 100644 --- a/cmd/server/config.go +++ b/cmd/server/config.go @@ -26,14 +26,24 @@ import ( type Config struct { LogLevel string Host string - PasswordHash string - Plain string - TLS string - CertPath string - PrivateKeyPath string + Ports PortConfig + Certificate CertConfig + Auth AuthConfig } -type GeneralConfig struct { +type PortConfig struct { + Plain string + TLS string +} + +type CertConfig struct { + CertFile string + KeyFile string +} + +type AuthConfig struct { + Enabled bool + PasswordHash string } func LoadConfig(path string) (*Config, error) { @@ -42,8 +52,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } var config Config - config.Plain = "disabled" - config.TLS = "disabled" + config.Ports.Plain = "disabled" + config.Ports.TLS = "disabled" + config.Auth.Enabled = true err = json.Unmarshal(contents, &config) if err != nil { return nil, err diff --git a/cmd/server/main.go b/cmd/server/main.go index 63915ec..67560be 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -83,28 +83,33 @@ func main() { log.Info().Msgf("Starting diodemail v%v", Version) log.Info().Msgf("Loaded config from: %v", config_path) - if config.CertPath == "" || config.PrivateKeyPath == "" { + if config.Certificate.CertFile == "" || config.Certificate.KeyFile == "" { log.Fatal().Msgf( - "Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')", - config.CertPath, - config.PrivateKeyPath, + "Must provide CertFile (got '%v') and KeyFile (got '%v')", + config.Certificate.CertFile, + config.Certificate.KeyFile, ) } certificate, err := tls.LoadX509KeyPair( - config.CertPath, - config.PrivateKeyPath, + config.Certificate.CertFile, + config.Certificate.KeyFile, ) if err != nil { log.Fatal().Msgf("Failed to load TLS config: %v", err) } + if config.Auth.Enabled && config.Auth.PasswordHash == "" { + log.Fatal().Msgf("Authentication is enabled but no password hash was supplied") + } + err = smtp.Run( config.Host, - config.PasswordHash, - config.Plain, - config.TLS, + config.Ports.Plain, + config.Ports.TLS, tls.Config{Certificates: []tls.Certificate{certificate}}, + config.Auth.Enabled, + config.Auth.PasswordHash, ) if err != nil { log.Fatal().Msgf("%v", err) diff --git a/smtp/commands.go b/smtp/commands.go index 0ab403d..8c25842 100644 --- a/smtp/commands.go +++ b/smtp/commands.go @@ -61,27 +61,43 @@ func Helo(smtp_session *SMTPSession, message string) (bool, error) { return true, err } + smtp_session.HasHelloed = true return false, nil } func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { + enable_auth := "" + if smtp_session.RequiresAuthentication() { + enable_auth = "250-AUTH PLAIN\r\n" + } err := smtp_session.Write( fmt.Sprintf( "250-%v is shy" + "\r\n" + - "250-AUTH PLAIN" + "\r\n" + + "%v" + "250-8BITMIME" + "\r\n" + "250 SMTPUTF8" + "\r\n", smtp_session.GetHost(), + enable_auth, ), ) if err != nil { return true, err } + smtp_session.HasHelloed = true return false, nil } func Auth(smtp_session *SMTPSession, message string) (bool, error) { + if !smtp_session.HasHelloed { + err := smtp_session.Write( + "503 Must HELO/EHLO first\n", + ) + if err != nil { + return true, err + } + return false, nil + } parts := strings.Split(message, " ")[1:] if parts[0] == "PLAIN" { var password string @@ -107,6 +123,24 @@ func Auth(smtp_session *SMTPSession, message string) (bool, error) { } func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { + if !smtp_session.HasHelloed { + err := smtp_session.Write( + "503 Must HELO/EHLO first\n", + ) + if err != nil { + return true, err + } + return false, nil + } + if !smtp_session.AuthenticationSatisfied() { + err := smtp_session.Write( + "530 Authentication required\n", + ) + if err != nil { + return true, err + } + return false, nil + } match := ReversePath.FindStringSubmatch(message) if match == nil { err := smtp_session.Write( @@ -129,6 +163,24 @@ func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { } func RcptTo(smtp_session *SMTPSession, message string) (bool, error) { + if !smtp_session.HasHelloed { + err := smtp_session.Write( + "503 Must HELO/EHLO first\n", + ) + if err != nil { + return true, err + } + return false, nil + } + if !smtp_session.AuthenticationSatisfied() { + err := smtp_session.Write( + "530 Authentication required\n", + ) + if err != nil { + return true, err + } + return false, nil + } match := ForwardPath.FindStringSubmatch(message) if match == nil { err := smtp_session.Write( @@ -151,6 +203,24 @@ func RcptTo(smtp_session *SMTPSession, message string) (bool, error) { } func Data(smtp_session *SMTPSession, message string) (bool, error) { + if !smtp_session.HasHelloed { + err := smtp_session.Write( + "503 Must HELO/EHLO first\n", + ) + if err != nil { + return true, err + } + return false, nil + } + if !smtp_session.AuthenticationSatisfied() { + err := smtp_session.Write( + "530 Authentication required\n", + ) + if err != nil { + return true, err + } + return false, nil + } err := smtp_session.Write( "354 Start Input\n", ) diff --git a/smtp/server.go b/smtp/server.go index 40fa91e..d5498dd 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -27,14 +27,20 @@ import ( "github.com/rs/zerolog/log" ) -func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) { +func handle( + connection net.Conn, + host string, + tls_config tls.Config, + auth bool, + password_hash string, +) { log.Info().Msgf( "New connection %v. Starting session.", connection.RemoteAddr(), ) defer connection.Close() - session := MakeSMTPSession(connection, host, password_hash, tls_config) + session := MakeSMTPSession(connection, host, tls_config, auth, password_hash) err := session.Run() if err != nil { log.Error().Msgf( @@ -52,10 +58,11 @@ func handle(connection net.Conn, host string, password_hash string, tls_config t func Run( host string, - password_hash string, plain_port string, tls_port string, tls_config tls.Config, + auth bool, + password_hash string, ) error { var wait_group sync.WaitGroup if plain_port != "disabled" { @@ -65,7 +72,7 @@ func Run( } log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host) wait_group.Add(1) - go Listen(wait_group, host, password_hash, tls_config, listener) + go Listen(wait_group, host, tls_config, auth, password_hash, listener) } if tls_port != "disabled" { listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config) @@ -74,7 +81,7 @@ func Run( } log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host) wait_group.Add(1) - go Listen(wait_group, host, password_hash, tls_config, listener) + go Listen(wait_group, host, tls_config, auth, password_hash, listener) } wait_group.Wait() @@ -84,8 +91,9 @@ func Run( func Listen( wait_group sync.WaitGroup, host string, - password_hash string, tls_config tls.Config, + auth bool, + password_hash string, listener net.Listener, ) { defer wait_group.Done() @@ -94,7 +102,7 @@ func Listen( if err != nil { log.Error().Msgf("Failed to accept client: %v", err) } else { - go handle(connection, host, password_hash, tls_config) + go handle(connection, host, tls_config, auth, password_hash) } } } diff --git a/smtp/session.go b/smtp/session.go index ef648b5..eebf53a 100644 --- a/smtp/session.go +++ b/smtp/session.go @@ -35,16 +35,20 @@ type SMTPSession struct { password_hash string tls_config tls.Config buffer [4096]byte + HasHelloed bool + requires_auth bool + is_authed bool ReversePathBuffer *string ForwardPathBuffer []string } -func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession { +func MakeSMTPSession(connection net.Conn, host string, tls_config tls.Config, auth bool, password_hash string) SMTPSession { return SMTPSession{ connection: connection, host: host, password_hash: password_hash, + requires_auth: auth, } } @@ -137,6 +141,15 @@ func (self *SMTPSession) GetHost() string { return self.host } +func (self *SMTPSession) RequiresAuthentication() bool { + return self.requires_auth +} + +func (self *SMTPSession) AuthenticationSatisfied() bool { + fmt.Println(self.requires_auth, self.is_authed) + return !self.requires_auth || self.is_authed +} + func (self *SMTPSession) ValidatePassword(password string) bool { return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash }