diff --git a/cmd/server/config.go b/cmd/server/config.go index 712ab6e..392c9e9 100644 --- a/cmd/server/config.go +++ b/cmd/server/config.go @@ -24,27 +24,16 @@ import ( ) type Config struct { - General GeneralConfig - Plain PlainConfig - TLS TLSConfig + LogLevel string + Host string + PasswordHash string + Plain string + TLS string + CertPath string + PrivateKeyPath string } type GeneralConfig struct { - LogLevel string - Host string - PasswordHash string -} - -type PlainConfig struct { - Enabled bool - Port string -} - -type TLSConfig struct { - Enabled bool - Port string - CertPath string - PrivateKeyPath string } func LoadConfig(path string) (*Config, error) { @@ -53,10 +42,8 @@ func LoadConfig(path string) (*Config, error) { return nil, err } var config Config - config.Plain.Enabled = false - config.Plain.Port = "25" - config.TLS.Enabled = false - config.TLS.Port = "465" + config.Plain = "disabled" + config.TLS = "disabled" err = json.Unmarshal(contents, &config) if err != nil { return nil, err diff --git a/cmd/server/main.go b/cmd/server/main.go index a9a7fdd..63915ec 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,8 +62,8 @@ func main() { log.Fatal().Msgf("%v", err) } if log_level == "" { - if config.General.LogLevel != "" { - log_level = config.General.LogLevel + if config.LogLevel != "" { + log_level = config.LogLevel } else { log_level = "info" } @@ -83,34 +83,28 @@ func main() { log.Info().Msgf("Starting diodemail v%v", Version) log.Info().Msgf("Loaded config from: %v", config_path) - var plain_config *smtp.PlainConfig - if config.Plain.Enabled { - plain_config = &smtp.PlainConfig { - config.Plain.Port, - } + if config.CertPath == "" || config.PrivateKeyPath == "" { + log.Fatal().Msgf( + "Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')", + config.CertPath, + config.PrivateKeyPath, + ) } - var tls_config *smtp.TLSConfig - if config.TLS.Enabled { - certificate, err := tls.LoadX509KeyPair( - config.TLS.CertPath, - config.TLS.PrivateKeyPath, - ) - if err != nil { - log.Error().Msgf("Failed to load TLS config: %v", err) - } else { - tls_config = &smtp.TLSConfig { - config.TLS.Port, - tls.Config{Certificates: []tls.Certificate{certificate}}, - } - } + certificate, err := tls.LoadX509KeyPair( + config.CertPath, + config.PrivateKeyPath, + ) + if err != nil { + log.Fatal().Msgf("Failed to load TLS config: %v", err) } err = smtp.Run( - config.General.Host, - config.General.PasswordHash, - plain_config, - tls_config, + config.Host, + config.PasswordHash, + config.Plain, + config.TLS, + tls.Config{Certificates: []tls.Certificate{certificate}}, ) if err != nil { log.Fatal().Msgf("%v", err) diff --git a/smtp/server.go b/smtp/server.go index ac4f9d9..40fa91e 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -27,23 +27,14 @@ import ( "github.com/rs/zerolog/log" ) -type PlainConfig struct { - Port string -} - -type TLSConfig struct { - Port string - TlsConfig tls.Config -} - -func handle(connection net.Conn, host string, password_hash string) { +func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) { log.Info().Msgf( "New connection %v. Starting session.", connection.RemoteAddr(), ) defer connection.Close() - session := MakeSMTPSession(connection, host, password_hash) + session := MakeSMTPSession(connection, host, password_hash, tls_config) err := session.Run() if err != nil { log.Error().Msgf( @@ -59,36 +50,51 @@ func handle(connection net.Conn, host string, password_hash string) { } } -func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error { +func Run( + host string, + password_hash string, + plain_port string, + tls_port string, + tls_config tls.Config, +) error { var wait_group sync.WaitGroup - if plain_config != nil { - listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_config.Port)) + if plain_port != "disabled" { + listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_port)) if err != nil { return err } - log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host) + log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host) wait_group.Add(1) - go Listen(wait_group, host, password_hash, listener) + go Listen(wait_group, host, password_hash, tls_config, listener) } - if tls_config != nil { - listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_config.Port), &tls_config.TlsConfig) + if tls_port != "disabled" { + listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config) if err != nil { return err } - log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host) + log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host) wait_group.Add(1) - go Listen(wait_group, host, password_hash, listener) + go Listen(wait_group, host, password_hash, tls_config, listener) } wait_group.Wait() return nil } -func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) { +func Listen( + wait_group sync.WaitGroup, + host string, + password_hash string, + tls_config tls.Config, + listener net.Listener, +) { defer wait_group.Done() for { - connection, _ := listener.Accept() - - go handle(connection, host, password_hash) + connection, err := listener.Accept() + if err != nil { + log.Error().Msgf("Failed to accept client: %v", err) + } else { + go handle(connection, host, password_hash, tls_config) + } } } diff --git a/smtp/session.go b/smtp/session.go index 55b84f0..ef648b5 100644 --- a/smtp/session.go +++ b/smtp/session.go @@ -33,13 +33,14 @@ type SMTPSession struct { connection net.Conn host string password_hash string + tls_config tls.Config buffer [4096]byte ReversePathBuffer *string ForwardPathBuffer []string } -func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession { +func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession { return SMTPSession{ connection: connection, host: host, @@ -161,11 +162,7 @@ func (self *SMTPSession) SendMail(data string) error { } else { smtp_hostname = mx[0].Host } - tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:465", smtp_hostname), nil) - if err != nil { - return err - } - smtp_client, err := smtp.NewClient(tls_connection, "") + smtp_client, err := smtp.Dial(fmt.Sprintf("%v:25", smtp_hostname)) if err != nil { return err } @@ -175,6 +172,12 @@ func (self *SMTPSession) SendMail(data string) error { if err != nil { return err } + tls_config := self.tls_config + tls_config.ServerName = smtp_hostname + err = smtp_client.StartTLS(&tls_config) + if err != nil { + return err + } err = smtp_client.Mail(*self.ReversePathBuffer) if err != nil { return err