Actual AUTH and also enforce EHLO/HELO first

This commit is contained in:
Gnarwhal 2024-10-04 21:44:10 +00:00
parent f656866e7d
commit 5ce65e6be9
Signed by: Gnarwhal
GPG key ID: 0989A73D8C421174
5 changed files with 133 additions and 26 deletions

View file

@ -26,14 +26,24 @@ import (
type Config struct { type Config struct {
LogLevel string LogLevel string
Host string Host string
PasswordHash string Ports PortConfig
Plain string Certificate CertConfig
TLS string Auth AuthConfig
CertPath string
PrivateKeyPath string
} }
type GeneralConfig struct { type PortConfig struct {
Plain string
TLS string
}
type CertConfig struct {
CertFile string
KeyFile string
}
type AuthConfig struct {
Enabled bool
PasswordHash string
} }
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (*Config, error) {
@ -42,8 +52,9 @@ func LoadConfig(path string) (*Config, error) {
return nil, err return nil, err
} }
var config Config var config Config
config.Plain = "disabled" config.Ports.Plain = "disabled"
config.TLS = "disabled" config.Ports.TLS = "disabled"
config.Auth.Enabled = true
err = json.Unmarshal(contents, &config) err = json.Unmarshal(contents, &config)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -83,28 +83,33 @@ func main() {
log.Info().Msgf("Starting diodemail v%v", Version) log.Info().Msgf("Starting diodemail v%v", Version)
log.Info().Msgf("Loaded config from: %v", config_path) log.Info().Msgf("Loaded config from: %v", config_path)
if config.CertPath == "" || config.PrivateKeyPath == "" { if config.Certificate.CertFile == "" || config.Certificate.KeyFile == "" {
log.Fatal().Msgf( log.Fatal().Msgf(
"Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')", "Must provide CertFile (got '%v') and KeyFile (got '%v')",
config.CertPath, config.Certificate.CertFile,
config.PrivateKeyPath, config.Certificate.KeyFile,
) )
} }
certificate, err := tls.LoadX509KeyPair( certificate, err := tls.LoadX509KeyPair(
config.CertPath, config.Certificate.CertFile,
config.PrivateKeyPath, config.Certificate.KeyFile,
) )
if err != nil { if err != nil {
log.Fatal().Msgf("Failed to load TLS config: %v", err) log.Fatal().Msgf("Failed to load TLS config: %v", err)
} }
if config.Auth.Enabled && config.Auth.PasswordHash == "" {
log.Fatal().Msgf("Authentication is enabled but no password hash was supplied")
}
err = smtp.Run( err = smtp.Run(
config.Host, config.Host,
config.PasswordHash, config.Ports.Plain,
config.Plain, config.Ports.TLS,
config.TLS,
tls.Config{Certificates: []tls.Certificate{certificate}}, tls.Config{Certificates: []tls.Certificate{certificate}},
config.Auth.Enabled,
config.Auth.PasswordHash,
) )
if err != nil { if err != nil {
log.Fatal().Msgf("%v", err) log.Fatal().Msgf("%v", err)

View file

@ -61,27 +61,43 @@ func Helo(smtp_session *SMTPSession, message string) (bool, error) {
return true, err return true, err
} }
smtp_session.HasHelloed = true
return false, nil return false, nil
} }
func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
enable_auth := ""
if smtp_session.RequiresAuthentication() {
enable_auth = "250-AUTH PLAIN\r\n"
}
err := smtp_session.Write( err := smtp_session.Write(
fmt.Sprintf( fmt.Sprintf(
"250-%v is shy" + "\r\n" + "250-%v is shy" + "\r\n" +
"250-AUTH PLAIN" + "\r\n" + "%v" +
"250-8BITMIME" + "\r\n" + "250-8BITMIME" + "\r\n" +
"250 SMTPUTF8" + "\r\n", "250 SMTPUTF8" + "\r\n",
smtp_session.GetHost(), smtp_session.GetHost(),
enable_auth,
), ),
) )
if err != nil { if err != nil {
return true, err return true, err
} }
smtp_session.HasHelloed = true
return false, nil return false, nil
} }
func Auth(smtp_session *SMTPSession, message string) (bool, error) { func Auth(smtp_session *SMTPSession, message string) (bool, error) {
if !smtp_session.HasHelloed {
err := smtp_session.Write(
"503 Must HELO/EHLO first\n",
)
if err != nil {
return true, err
}
return false, nil
}
parts := strings.Split(message, " ")[1:] parts := strings.Split(message, " ")[1:]
if parts[0] == "PLAIN" { if parts[0] == "PLAIN" {
var password string var password string
@ -107,6 +123,24 @@ func Auth(smtp_session *SMTPSession, message string) (bool, error) {
} }
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
if !smtp_session.HasHelloed {
err := smtp_session.Write(
"503 Must HELO/EHLO first\n",
)
if err != nil {
return true, err
}
return false, nil
}
if !smtp_session.AuthenticationSatisfied() {
err := smtp_session.Write(
"530 Authentication required\n",
)
if err != nil {
return true, err
}
return false, nil
}
match := ReversePath.FindStringSubmatch(message) match := ReversePath.FindStringSubmatch(message)
if match == nil { if match == nil {
err := smtp_session.Write( err := smtp_session.Write(
@ -129,6 +163,24 @@ func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
} }
func RcptTo(smtp_session *SMTPSession, message string) (bool, error) { func RcptTo(smtp_session *SMTPSession, message string) (bool, error) {
if !smtp_session.HasHelloed {
err := smtp_session.Write(
"503 Must HELO/EHLO first\n",
)
if err != nil {
return true, err
}
return false, nil
}
if !smtp_session.AuthenticationSatisfied() {
err := smtp_session.Write(
"530 Authentication required\n",
)
if err != nil {
return true, err
}
return false, nil
}
match := ForwardPath.FindStringSubmatch(message) match := ForwardPath.FindStringSubmatch(message)
if match == nil { if match == nil {
err := smtp_session.Write( err := smtp_session.Write(
@ -151,6 +203,24 @@ func RcptTo(smtp_session *SMTPSession, message string) (bool, error) {
} }
func Data(smtp_session *SMTPSession, message string) (bool, error) { func Data(smtp_session *SMTPSession, message string) (bool, error) {
if !smtp_session.HasHelloed {
err := smtp_session.Write(
"503 Must HELO/EHLO first\n",
)
if err != nil {
return true, err
}
return false, nil
}
if !smtp_session.AuthenticationSatisfied() {
err := smtp_session.Write(
"530 Authentication required\n",
)
if err != nil {
return true, err
}
return false, nil
}
err := smtp_session.Write( err := smtp_session.Write(
"354 Start Input\n", "354 Start Input\n",
) )

View file

@ -27,14 +27,20 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) { func handle(
connection net.Conn,
host string,
tls_config tls.Config,
auth bool,
password_hash string,
) {
log.Info().Msgf( log.Info().Msgf(
"New connection %v. Starting session.", "New connection %v. Starting session.",
connection.RemoteAddr(), connection.RemoteAddr(),
) )
defer connection.Close() defer connection.Close()
session := MakeSMTPSession(connection, host, password_hash, tls_config) session := MakeSMTPSession(connection, host, tls_config, auth, password_hash)
err := session.Run() err := session.Run()
if err != nil { if err != nil {
log.Error().Msgf( log.Error().Msgf(
@ -52,10 +58,11 @@ func handle(connection net.Conn, host string, password_hash string, tls_config t
func Run( func Run(
host string, host string,
password_hash string,
plain_port string, plain_port string,
tls_port string, tls_port string,
tls_config tls.Config, tls_config tls.Config,
auth bool,
password_hash string,
) error { ) error {
var wait_group sync.WaitGroup var wait_group sync.WaitGroup
if plain_port != "disabled" { if plain_port != "disabled" {
@ -65,7 +72,7 @@ func Run(
} }
log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host) log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, password_hash, tls_config, listener) go Listen(wait_group, host, tls_config, auth, password_hash, listener)
} }
if tls_port != "disabled" { if tls_port != "disabled" {
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config) listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config)
@ -74,7 +81,7 @@ func Run(
} }
log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host) log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, password_hash, tls_config, listener) go Listen(wait_group, host, tls_config, auth, password_hash, listener)
} }
wait_group.Wait() wait_group.Wait()
@ -84,8 +91,9 @@ func Run(
func Listen( func Listen(
wait_group sync.WaitGroup, wait_group sync.WaitGroup,
host string, host string,
password_hash string,
tls_config tls.Config, tls_config tls.Config,
auth bool,
password_hash string,
listener net.Listener, listener net.Listener,
) { ) {
defer wait_group.Done() defer wait_group.Done()
@ -94,7 +102,7 @@ func Listen(
if err != nil { if err != nil {
log.Error().Msgf("Failed to accept client: %v", err) log.Error().Msgf("Failed to accept client: %v", err)
} else { } else {
go handle(connection, host, password_hash, tls_config) go handle(connection, host, tls_config, auth, password_hash)
} }
} }
} }

View file

@ -35,16 +35,20 @@ type SMTPSession struct {
password_hash string password_hash string
tls_config tls.Config tls_config tls.Config
buffer [4096]byte buffer [4096]byte
HasHelloed bool
requires_auth bool
is_authed bool
ReversePathBuffer *string ReversePathBuffer *string
ForwardPathBuffer []string ForwardPathBuffer []string
} }
func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession { func MakeSMTPSession(connection net.Conn, host string, tls_config tls.Config, auth bool, password_hash string) SMTPSession {
return SMTPSession{ return SMTPSession{
connection: connection, connection: connection,
host: host, host: host,
password_hash: password_hash, password_hash: password_hash,
requires_auth: auth,
} }
} }
@ -137,6 +141,15 @@ func (self *SMTPSession) GetHost() string {
return self.host return self.host
} }
func (self *SMTPSession) RequiresAuthentication() bool {
return self.requires_auth
}
func (self *SMTPSession) AuthenticationSatisfied() bool {
fmt.Println(self.requires_auth, self.is_authed)
return !self.requires_auth || self.is_authed
}
func (self *SMTPSession) ValidatePassword(password string) bool { func (self *SMTPSession) ValidatePassword(password string) bool {
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
} }