Added AUTH command and PLAIN sasl mechanism

This commit is contained in:
Gnarwhal 2024-10-02 07:56:56 +00:00
parent 3523f1cf4e
commit 86dc81451a
Signed by: Gnarwhal
GPG key ID: 0989A73D8C421174
7 changed files with 75 additions and 24 deletions

2
.gitignore vendored
View file

@ -21,3 +21,5 @@
# ...even if they are in subdirectories # ...even if they are in subdirectories
!*/ !*/
!hash.sh

View file

@ -24,12 +24,17 @@ import (
) )
type Config struct { type Config struct {
LogLevel string General GeneralConfig
Host string
Plain PlainConfig Plain PlainConfig
TLS TLSConfig TLS TLSConfig
} }
type GeneralConfig struct {
LogLevel string
Host string
PasswordHash string
}
type PlainConfig struct { type PlainConfig struct {
Enabled bool Enabled bool
Port string Port string

View file

@ -62,8 +62,8 @@ func main() {
log.Fatal().Msgf("%v", err) log.Fatal().Msgf("%v", err)
} }
if log_level == "" { if log_level == "" {
if config.LogLevel != "" { if config.General.LogLevel != "" {
log_level = config.LogLevel log_level = config.General.LogLevel
} else { } else {
log_level = "info" log_level = "info"
} }
@ -107,7 +107,8 @@ func main() {
} }
err = smtp.Run( err = smtp.Run(
config.Host, config.General.Host,
config.General.PasswordHash,
plain_config, plain_config,
tls_config, tls_config,
) )

3
hash.sh Executable file
View file

@ -0,0 +1,3 @@
#!/bin/sh
printf "$@" | base64 | tr -d '\n' | sha256sum

View file

@ -67,9 +67,10 @@ func Helo(smtp_session *SMTPSession, message string) (bool, error) {
func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
err := smtp_session.Write( err := smtp_session.Write(
fmt.Sprintf( fmt.Sprintf(
"250-%v is shy" + "\r\n" + "250-%v is shy" + "\r\n" +
"250-8BITMIME" + "\r\n" + "250-AUTH PLAIN" + "\r\n" +
"250 SMTPUTF8" + "\r\n", "250-8BITMIME" + "\r\n" +
"250 SMTPUTF8" + "\r\n",
smtp_session.GetHost(), smtp_session.GetHost(),
), ),
) )
@ -80,6 +81,31 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
return false, nil return false, nil
} }
func Auth(smtp_session *SMTPSession, message string) (bool, error) {
parts := strings.Split(message, " ")[1:]
if parts[0] == "PLAIN" {
var password string
if len(parts) > 1 {
password = parts[1][:len(parts[1]) - 1]
} else {
err := smtp_session.Write("334 \n")
if err != nil {
return true, nil
}
password, err = smtp_session.ReadUntil("\r\n")
if err != nil {
return true, nil
}
}
if smtp_session.ValidatePassword(password) {
smtp_session.Write("235 Authentication successful\n")
} else {
smtp_session.Write("535 Authentication credentials invalid\n")
}
}
return false, nil
}
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
match := ReversePath.FindStringSubmatch(message) match := ReversePath.FindStringSubmatch(message)
if match == nil { if match == nil {
@ -131,7 +157,7 @@ func Data(smtp_session *SMTPSession, message string) (bool, error) {
if err != nil { if err != nil {
return true, err return true, err
} }
data, err := smtp_session.Read("\r\n.\r\n") data, err := smtp_session.ReadUntil("\r\n.\r\n")
if err != nil { if err != nil {
smtp_session.Write( smtp_session.Write(
"550 Action not taken\n", "550 Action not taken\n",

View file

@ -36,14 +36,14 @@ type TLSConfig struct {
TlsConfig tls.Config TlsConfig tls.Config
} }
func handle(connection net.Conn, host string) { func handle(connection net.Conn, host string, password_hash string) {
log.Info().Msgf( log.Info().Msgf(
"New connection %v. Starting session.", "New connection %v. Starting session.",
connection.RemoteAddr(), connection.RemoteAddr(),
) )
defer connection.Close() defer connection.Close()
session := MakeSMTPSession(connection, host) session := MakeSMTPSession(connection, host, password_hash)
err := session.Run() err := session.Run()
if err != nil { if err != nil {
log.Error().Msgf( log.Error().Msgf(
@ -59,7 +59,7 @@ func handle(connection net.Conn, host string) {
} }
} }
func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error { func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error {
var wait_group sync.WaitGroup var wait_group sync.WaitGroup
if plain_config != nil { if plain_config != nil {
listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port)) listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port))
@ -68,7 +68,7 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
} }
log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host) log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, listener) go Listen(wait_group, host, password_hash, listener)
} }
if tls_config != nil { if tls_config != nil {
listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig) listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig)
@ -77,18 +77,18 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
} }
log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host) log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, listener) go Listen(wait_group, host, password_hash, listener)
} }
wait_group.Wait() wait_group.Wait()
return nil return nil
} }
func Listen(wait_group sync.WaitGroup, host string, listener net.Listener) { func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) {
defer wait_group.Done() defer wait_group.Done()
for { for {
connection, _ := listener.Accept() connection, _ := listener.Accept()
go handle(connection, host) go handle(connection, host, password_hash)
} }
} }

View file

@ -24,23 +24,26 @@ import (
"fmt" "fmt"
"strings" "strings"
"crypto/tls" "crypto/tls"
"crypto/sha256"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type SMTPSession struct { type SMTPSession struct {
connection net.Conn connection net.Conn
host string host string
buffer [4096]byte password_hash string
buffer [4096]byte
ReversePathBuffer *string ReversePathBuffer *string
ForwardPathBuffer []string ForwardPathBuffer []string
} }
func MakeSMTPSession(connection net.Conn, host string) SMTPSession { func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession {
return SMTPSession{ return SMTPSession{
connection: connection, connection: connection,
host: host, host: host,
password_hash: password_hash,
} }
} }
@ -52,6 +55,7 @@ func (self *SMTPSession) Run() error {
COMMANDS := []Command{ COMMANDS := []Command{
Command{ "HELO", Helo }, Command{ "HELO", Helo },
Command{ "EHLO", Ehlo }, Command{ "EHLO", Ehlo },
Command{ "AUTH", Auth },
Command{ "MAIL FROM:", MailFrom }, Command{ "MAIL FROM:", MailFrom },
Command{ "RCPT TO:", RcptTo }, Command{ "RCPT TO:", RcptTo },
Command{ "DATA", Data }, Command{ "DATA", Data },
@ -61,7 +65,7 @@ func (self *SMTPSession) Run() error {
} }
quit := false quit := false
for !quit { for !quit {
message, err := self.Read("\n") message, err := self.ReadUntil("\n")
if err != nil { if err != nil {
return err return err
} }
@ -102,7 +106,7 @@ func (self *SMTPSession) TraceSession(direction string, message string) {
} }
func (self *SMTPSession) Read(delimiter string) (string, error) { func (self *SMTPSession) ReadUntil(delimiter string) (string, error) {
var message string var message string
for !strings.Contains(message, delimiter) { for !strings.Contains(message, delimiter) {
num_read, err := self.connection.Read(self.buffer[:]) num_read, err := self.connection.Read(self.buffer[:])
@ -116,6 +120,12 @@ func (self *SMTPSession) Read(delimiter string) (string, error) {
return message[:len(message) - len(delimiter)], nil return message[:len(message) - len(delimiter)], nil
} }
func (self *SMTPSession) ReadCount(num_bytes int) ([]byte, error) {
bytes := make([]byte, num_bytes)
_, err := self.connection.Read(bytes[:])
return bytes, err
}
func (self *SMTPSession) Write(message string) error { func (self *SMTPSession) Write(message string) error {
self.TraceSession("<- ", message) self.TraceSession("<- ", message)
_, err := self.connection.Write([]byte(message)) _, err := self.connection.Write([]byte(message))
@ -126,6 +136,10 @@ func (self *SMTPSession) GetHost() string {
return self.host return self.host
} }
func (self *SMTPSession) ValidatePassword(password string) bool {
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
}
func (self *SMTPSession) MailFrom(from string) { func (self *SMTPSession) MailFrom(from string) {
self.Reset() self.Reset()
self.ReversePathBuffer = &from self.ReversePathBuffer = &from