From dafe3096cac9e495075e5661704f9b8b0740956a Mon Sep 17 00:00:00 2001 From: Gnarwhal Date: Wed, 2 Oct 2024 00:28:14 +0000 Subject: [PATCH] Wow it forwards mail :D --- cmd/server/main.go | 6 +- smtp/commands.go | 79 ++++++++++++++++---------- smtp/parsers.go | 1 + smtp/server.go | 12 ++-- smtp/{connection.go => session.go} | 91 ++++++++++++++++++++++++------ 5 files changed, 136 insertions(+), 53 deletions(-) rename smtp/{connection.go => session.go} (57%) diff --git a/cmd/server/main.go b/cmd/server/main.go index c66fe75..035e477 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -25,7 +25,11 @@ import ( ) func main() { - err := smtp.Run(":4650", false) + err := smtp.Run( + "localhost", + "4650", + false, + ) if err != nil { log.Fatal(err) } diff --git a/smtp/commands.go b/smtp/commands.go index 45c40a4..eb427a8 100644 --- a/smtp/commands.go +++ b/smtp/commands.go @@ -25,10 +25,10 @@ import ( type Command struct { name string - Exec func(SMTPSession, string)(bool, error) + Exec func(*SMTPSession, string)(bool, error) } -func (self Command) Name() string { +func (self Command) GetName() string { return self.name } @@ -36,9 +36,12 @@ func (self Command) Check(message string) bool { return strings.HasPrefix(message, self.name) } -func Greet(smtp_session SMTPSession) error { +func Greet(smtp_session *SMTPSession) error { err := smtp_session.Write( - "220 localhost ESMTPSession diodemail -- Service ready" + "\n", + fmt.Sprintf( + "220 %v ESMTP diodemail -- Service ready" + "\n", + smtp_session.GetHost(), + ), ) if err != nil { return err @@ -47,93 +50,111 @@ func Greet(smtp_session SMTPSession) error { return nil } -func Helo(smtp_session SMTPSession, message string) (bool, error) { +func Helo(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( fmt.Sprintf( "250 %v is shy" + "\r\n", - smtp_session.GetAddress(), + smtp_session.GetHost(), ), ) if err != nil { - return false, err + return true, err } return false, nil } -func Ehlo(smtp_session SMTPSession, message string) (bool, error) { +func Ehlo(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( fmt.Sprintf( "250 %v is shy" + "\r\n", - smtp_session.GetAddress(), + smtp_session.GetHost(), ), ) if err != nil { - return false, err + return true, err } return false, nil } -func MailFrom(smtp_session SMTPSession, message string) (bool, error) { +func MailFrom(smtp_session *SMTPSession, message string) (bool, error) { match := ReversePath.FindStringSubmatch(message) if match == nil { - smtp_session.Write( - "501 Could not parse reverse-path", + err := smtp_session.Write( + "501 Could not parse reverse-path\n", ) + if err != nil { + return true, err + } } else if len(match) > 1 { smtp_session.MailFrom(match[1]) err := smtp_session.Write( "250 OK\n", ) if err != nil { - return false, err + return true, err } } return false, nil } -func RcptTo(smtp_session SMTPSession, message string) (bool, error) { +func RcptTo(smtp_session *SMTPSession, message string) (bool, error) { match := ForwardPath.FindStringSubmatch(message) if match == nil { - smtp_session.Write( - "501 Could not parse forward-path", + err := smtp_session.Write( + "501 Could not parse forward-path\n", ) + if err != nil { + return true, err + } } else { smtp_session.AddRecipient(match[1]) err := smtp_session.Write( "250 OK\n", ) if err != nil { - return false, err + return true, err } } return false, nil } -func Data(smtp_session SMTPSession, message string) (bool, error) { +func Data(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( "354 Start Input\n", ) if err != nil { - return false, err + return true, err } data, err := smtp_session.Read("\r\n.\r\n") - smtp_session.SetMailBuffer(data) + if err != nil { + smtp_session.Write( + "550 Action not taken\n", + ) + return false, err + } + err = smtp_session.SendMail(data) + if err != nil { + smtp_session.Write( + "550 Action not taken\n", + ) + return false, err + } + smtp_session.Reset() err = smtp_session.Write( "250 OK\n", ) - smtp_session.Reset() if err != nil { - return false, err + return true, err } return false, nil } -func Quit(smtp_session SMTPSession, message string) (bool, error) { +func Quit(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( "221 Goodbye :)\n", ) @@ -144,24 +165,24 @@ func Quit(smtp_session SMTPSession, message string) (bool, error) { return true, nil } -func Noop(smtp_session SMTPSession, message string) (bool, error) { +func Noop(smtp_session *SMTPSession, message string) (bool, error) { err := smtp_session.Write( "250 OK\n", ) if err != nil { - return false, err + return true, err } return false, nil } -func Rset(smtp_session SMTPSession, message string) (bool, error) { +func Rset(smtp_session *SMTPSession, message string) (bool, error) { + smtp_session.Reset() err := smtp_session.Write( "250 Reset\n", ) - smtp_session.Reset() if err != nil { - return false, err + return true, err } return false, nil diff --git a/smtp/parsers.go b/smtp/parsers.go index efe7279..c090ac7 100644 --- a/smtp/parsers.go +++ b/smtp/parsers.go @@ -25,6 +25,7 @@ import ( var ReversePath = regexp.MustCompile(fmt.Sprintf("(?:%v)|<>", path)) var ForwardPath = regexp.MustCompile(path) +var Domain = regexp.MustCompile("\\w+@(\\w+(?:\\.\\w+)*)") // https://datatracker.ietf.org/doc/html/rfc5321#page-41 // Is this...legal, m'lord? (no, but ¯\_(ツ)_/¯) diff --git a/smtp/server.go b/smtp/server.go index 16a2727..1d07ad5 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -19,6 +19,7 @@ package smtp import ( + "fmt" "net" "os" @@ -30,14 +31,14 @@ type PlainListener struct { listener net.Listener } -func handle(connection net.Conn) { +func handle(connection net.Conn, host string) { log.Info().Msgf( "New connection %v. Starting session.", connection.RemoteAddr(), ) defer connection.Close() - session := MakeSMTPSession(connection) + session := MakeSMTPSession(connection, host) err := session.Run() if err != nil { log.Error().Msgf( @@ -53,7 +54,7 @@ func handle(connection net.Conn) { } } -func Run(host string, implicit_tls bool) error { +func Run(host string, port string, implicit_tls bool) error { log.Logger = zerolog. New(zerolog.ConsoleWriter{Out: os.Stderr}). With(). @@ -61,18 +62,19 @@ func Run(host string, implicit_tls bool) error { Logger(). Level(zerolog.TraceLevel) - listener, err := net.Listen("tcp", host) + listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, port)) if err != nil { return err } + log.Info().Msgf("Server started on port %v for host %v", port, host) for { connection, err := listener.Accept() if err != nil { return err } - go handle(connection) + go handle(connection, host) } } diff --git a/smtp/connection.go b/smtp/session.go similarity index 57% rename from smtp/connection.go rename to smtp/session.go index a1dfdee..448ebbe 100644 --- a/smtp/connection.go +++ b/smtp/session.go @@ -20,30 +20,32 @@ package smtp import ( "net" + "net/smtp" "fmt" "strings" + "crypto/tls" "github.com/rs/zerolog/log" ) type SMTPSession struct { connection net.Conn + host string ReversePathBuffer *string ForwardPathBuffer []string - MailBuffer *string } -func MakeSMTPSession(connection net.Conn) SMTPSession { +func MakeSMTPSession(connection net.Conn, host string) SMTPSession { return SMTPSession{ connection, + host, nil, []string{}, - nil, } } -func (self SMTPSession) Run() error { +func (self *SMTPSession) Run() error { err := Greet(self) if err != nil { return err @@ -69,9 +71,13 @@ func (self SMTPSession) Run() error { for _, command := range COMMANDS { if command.Check(message) { command_found = true - quit, err = command.Exec(self, message[len(command.Name()):]) + quit, err = command.Exec(self, message[len(command.GetName()):]) if err != nil { - return err + if quit { + return err + } else { + log.Error().Msgf("%v: %v", self.connection.RemoteAddr(), err) + } } break } @@ -83,14 +89,14 @@ func (self SMTPSession) Run() error { return nil } -func (self SMTPSession) TraceSession(direction string, message string) { +func (self *SMTPSession) TraceSession(direction string, message string) { lines := strings.Split(message, "\n") for index, line := range lines { if index != len(lines) - 1 || line != "" { if index == 0 { log.Trace().Msgf("%v %v %v", self.connection.RemoteAddr(), direction, line) } else { - log.Trace().Msgf("%v %v", self.connection.RemoteAddr(), line) + log.Trace().Msgf("%v %v", self.connection.RemoteAddr(), line) } } } @@ -98,7 +104,7 @@ func (self SMTPSession) TraceSession(direction string, message string) { } var buffer = [1024]byte{} -func (self SMTPSession) Read(delimiter string) (string, error) { +func (self *SMTPSession) Read(delimiter string) (string, error) { var message string for !strings.Contains(message, delimiter) { num_read, err := self.connection.Read(buffer[:]) @@ -112,31 +118,80 @@ func (self SMTPSession) Read(delimiter string) (string, error) { return message[:len(message) - len(delimiter)], nil } -func (self SMTPSession) Write(message string) error { +func (self *SMTPSession) Write(message string) error { self.TraceSession("<- ", message) _, err := self.connection.Write([]byte(message)) return err } -func (self SMTPSession) GetAddress() string { - return fmt.Sprint(self.connection.LocalAddr()) +func (self *SMTPSession) GetHost() string { + return self.host } -func (self SMTPSession) MailFrom(from string) { +func (self *SMTPSession) MailFrom(from string) { self.Reset() self.ReversePathBuffer = &from } -func (self SMTPSession) AddRecipient(recipient string) { +func (self *SMTPSession) AddRecipient(recipient string) { self.ForwardPathBuffer = append(self.ForwardPathBuffer, recipient) } -func (self SMTPSession) SetMailBuffer(data string) { - self.MailBuffer = &data +func (self *SMTPSession) SendMail(data string) error { + mx_hostname := Domain.FindStringSubmatch(self.ForwardPathBuffer[0])[1] + mx, err := net.LookupMX(mx_hostname) + if err != nil { + return err + } + var smtp_hostname string + if len(mx) == 0 { + smtp_hostname = self.host + } else { + smtp_hostname = mx[0].Host + } + tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:4560", smtp_hostname), nil) + if err != nil { + return err + } + smtp_client, err := smtp.NewClient(tls_connection, "") + if err != nil { + return err + } + defer smtp_client.Close() + + err = smtp_client.Hello(self.host) + if err != nil { + return err + } + err = smtp_client.Mail(*self.ReversePathBuffer) + if err != nil { + return err + } + err = smtp_client.Rcpt(self.ForwardPathBuffer[0]) + if err != nil { + return err + } + writer, err := smtp_client.Data() + if err != nil { + return err + } + _, err = writer.Write([]byte(data)) + if err != nil { + return err + } + err = writer.Close() + if err != nil { + return err + } + err = smtp_client.Quit() + if err != nil { + return err + } + + return nil } -func (self SMTPSession) Reset() { +func (self *SMTPSession) Reset() { self.ReversePathBuffer = nil self.ForwardPathBuffer = []string{} - self.MailBuffer = nil }