diff --git a/smtp/commands.go b/smtp/commands.go new file mode 100644 index 0000000..45c40a4 --- /dev/null +++ b/smtp/commands.go @@ -0,0 +1,168 @@ +/* diodemail - send-only smtp server + * Copyright (c) 2024 Gnarwhal + * + * This file is part of SSHare. + * + * SSHare is free software: you can redistribute it and/or modify it under the terms of + * the GNU General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * SSHare is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * SSHare. If not, see . + */ + +package smtp + +import ( + "fmt" + "strings" +) + +type Command struct { + name string + Exec func(SMTPSession, string)(bool, error) +} + +func (self Command) Name() string { + return self.name +} + +func (self Command) Check(message string) bool { + return strings.HasPrefix(message, self.name) +} + +func Greet(smtp_session SMTPSession) error { + err := smtp_session.Write( + "220 localhost ESMTPSession diodemail -- Service ready" + "\n", + ) + if err != nil { + return err + } + + return nil +} + +func Helo(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + fmt.Sprintf( + "250 %v is shy" + "\r\n", + smtp_session.GetAddress(), + ), + ) + if err != nil { + return false, err + } + + return false, nil +} + +func Ehlo(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + fmt.Sprintf( + "250 %v is shy" + "\r\n", + smtp_session.GetAddress(), + ), + ) + if err != nil { + return false, err + } + + return false, nil +} + +func MailFrom(smtp_session SMTPSession, message string) (bool, error) { + match := ReversePath.FindStringSubmatch(message) + if match == nil { + smtp_session.Write( + "501 Could not parse reverse-path", + ) + } else if len(match) > 1 { + smtp_session.MailFrom(match[1]) + err := smtp_session.Write( + "250 OK\n", + ) + if err != nil { + return false, err + } + } + + return false, nil +} + +func RcptTo(smtp_session SMTPSession, message string) (bool, error) { + match := ForwardPath.FindStringSubmatch(message) + if match == nil { + smtp_session.Write( + "501 Could not parse forward-path", + ) + } else { + smtp_session.AddRecipient(match[1]) + err := smtp_session.Write( + "250 OK\n", + ) + if err != nil { + return false, err + } + } + + return false, nil +} + +func Data(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + "354 Start Input\n", + ) + if err != nil { + return false, err + } + data, err := smtp_session.Read("\r\n.\r\n") + smtp_session.SetMailBuffer(data) + err = smtp_session.Write( + "250 OK\n", + ) + smtp_session.Reset() + if err != nil { + return false, err + } + + return false, nil +} + +func Quit(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + "221 Goodbye :)\n", + ) + if err != nil { + return true, err + } + + return true, nil +} + +func Noop(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + "250 OK\n", + ) + if err != nil { + return false, err + } + + return false, nil +} + +func Rset(smtp_session SMTPSession, message string) (bool, error) { + err := smtp_session.Write( + "250 Reset\n", + ) + smtp_session.Reset() + if err != nil { + return false, err + } + + return false, nil +} diff --git a/smtp/connection.go b/smtp/connection.go index 7fbee1d..a1dfdee 100644 --- a/smtp/connection.go +++ b/smtp/connection.go @@ -20,13 +20,13 @@ package smtp import ( "net" - // "fmt" + "fmt" "strings" "github.com/rs/zerolog/log" ) -type Connection struct { +type SMTPSession struct { connection net.Conn ReversePathBuffer *string @@ -34,8 +34,8 @@ type Connection struct { MailBuffer *string } -func NewConnection(connection net.Conn) Connection { - return Connection{ +func MakeSMTPSession(connection net.Conn) SMTPSession { + return SMTPSession{ connection, nil, []string{}, @@ -43,8 +43,62 @@ func NewConnection(connection net.Conn) Connection { } } +func (self SMTPSession) Run() error { + err := Greet(self) + if err != nil { + return err + } + COMMANDS := []Command{ + Command{ "HELO", Helo }, + Command{ "EHLO", Ehlo }, + Command{ "MAIL FROM:", MailFrom }, + Command{ "RCPT TO:", RcptTo }, + Command{ "DATA", Data }, + Command{ "QUIT", Quit }, + Command{ "NOOP", Noop }, + Command{ "RSET", Rset }, + } + quit := false + for !quit { + message, err := self.Read("\n") + if err != nil { + return err + } + + command_found := false + for _, command := range COMMANDS { + if command.Check(message) { + command_found = true + quit, err = command.Exec(self, message[len(command.Name()):]) + if err != nil { + return err + } + break + } + } + if !command_found { + self.Write("500 Unrecognized Command\n") + } + } + return nil +} + +func (self SMTPSession) TraceSession(direction string, message string) { + lines := strings.Split(message, "\n") + for index, line := range lines { + if index != len(lines) - 1 || line != "" { + if index == 0 { + log.Trace().Msgf("%v %v %v", self.connection.RemoteAddr(), direction, line) + } else { + log.Trace().Msgf("%v %v", self.connection.RemoteAddr(), line) + } + } + } + +} + var buffer = [1024]byte{} -func (self Connection) Read(delimiter string) (string, error) { +func (self SMTPSession) Read(delimiter string) (string, error) { var message string for !strings.Contains(message, delimiter) { num_read, err := self.connection.Read(buffer[:]) @@ -53,50 +107,36 @@ func (self Connection) Read(delimiter string) (string, error) { } message += string(buffer[:num_read]) } - - lines := strings.Split(message, "\n") - for index, line := range lines { - if index != len(lines) - 1 || line != "" { - if index == 0 { - log.Trace().Msgf("%v -> %v", self.RemoteAddr(), line) - } else { - log.Trace().Msgf("%v %v", self.RemoteAddr(), line) - } - } - } + self.TraceSession(" ->", message) return message[:len(message) - len(delimiter)], nil } -func (self Connection) Write(message string) error { - lines := strings.Split(message, "\n") - for index, line := range lines { - if index != len(lines) - 1 || line != "" { - if index == 0 { - log.Trace().Msgf("%v <- %v", self.RemoteAddr(), line) - } else { - log.Trace().Msgf("%v %v", self.RemoteAddr(), line) - } - } - } +func (self SMTPSession) Write(message string) error { + self.TraceSession("<- ", message) _, err := self.connection.Write([]byte(message)) return err } -func (self Connection) AddRecipient(recipient string) { +func (self SMTPSession) GetAddress() string { + return fmt.Sprint(self.connection.LocalAddr()) +} + +func (self SMTPSession) MailFrom(from string) { + self.Reset() + self.ReversePathBuffer = &from +} + +func (self SMTPSession) AddRecipient(recipient string) { self.ForwardPathBuffer = append(self.ForwardPathBuffer, recipient) } -func (self Connection) Reset() { +func (self SMTPSession) SetMailBuffer(data string) { + self.MailBuffer = &data +} + +func (self SMTPSession) Reset() { self.ReversePathBuffer = nil self.ForwardPathBuffer = []string{} self.MailBuffer = nil } - -func (self Connection) Close() { - self.connection.Close() -} - -func (self Connection) RemoteAddr() net.Addr { - return self.connection.RemoteAddr() -} diff --git a/smtp/handlers.go b/smtp/handlers.go deleted file mode 100644 index 8ba5560..0000000 --- a/smtp/handlers.go +++ /dev/null @@ -1,214 +0,0 @@ -/* diodemail - send-only smtp server - * Copyright (c) 2024 Gnarwhal - * - * This file is part of SSHare. - * - * SSHare is free software: you can redistribute it and/or modify it under the terms of - * the GNU General Public License as published by the Free Software Foundation, - * either version 3 of the License, or (at your option) any later version. - * - * SSHare is distributed in the hope that it will be useful, but WITHOUT ANY - * WARRANTY; without even the implied warranty of MERCHANTABILITY - * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along with - * SSHare. If not, see . - */ - -package smtp - -import ( - "fmt" - "strings" -) - -type Command struct { - name string - Exec func(Connection, string)(bool, error) -} - -func (self Command) Name() string { - return self.name -} - -func (self Command) Check(message string) bool { - return strings.HasPrefix(message, self.name) -} - -func (self Connection) Chain() error { - err := Greet(self) - if err != nil { - return err - } - COMMANDS := []Command{ - Command{ "HELO", Helo }, - Command{ "EHLO", Ehlo }, - Command{ "MAIL FROM:", MailFrom }, - Command{ "RCPT TO:", RcptTo }, - Command{ "DATA", Data }, - Command{ "QUIT", Quit }, - Command{ "NOOP", Noop }, - Command{ "RSET", Rset }, - } - quit := false - for !quit { - message, err := self.Read("\n") - if err != nil { - return err - } - - command_found := false - for _, command := range COMMANDS { - if command.Check(message) { - command_found = true - quit, err = command.Exec(self, message[len(command.Name()):]) - if err != nil { - return err - } - break - } - } - if !command_found { - self.Write("500 Unrecognized Command\n") - } - } - return nil -} - -/* --- GREETING RESPONSE --- */ - -func Greet(connection Connection) error { - err := connection.Write( - "220 localhost ESMTP diodemail -- Service ready" + "\n", - ) - if err != nil { - return err - } - - return nil -} - -/* --- HELO/EHLO RESPONSE --- */ - -func Helo(connection Connection, message string) (bool, error) { - err := connection.Write( - fmt.Sprintf( - "250 %v is shy" + "\r\n", - connection.connection.LocalAddr(), - ), - ) - if err != nil { - return false, err - } - - return false, nil -} - -func Ehlo(connection Connection, message string) (bool, error) { - err := connection.Write( - fmt.Sprintf( - "250 %v is shy" + "\r\n", - connection.connection.LocalAddr(), - ), - ) - if err != nil { - return false, err - } - - return false, nil -} - -/* --- MAIL FROM RESPONSE --- */ - -func MailFrom(connection Connection, message string) (bool, error) { - match := ReversePath.FindStringSubmatch(message) - if match == nil { - connection.Write( - "501 Could not parse reverse-path", - ) - } else if len(match) > 1 { - connection.ReversePathBuffer = &match[1] - err := connection.Write( - "250 OK\n", - ) - if err != nil { - return false, err - } - } - - return false, nil -} - -func RcptTo(connection Connection, message string) (bool, error) { - match := ForwardPath.FindStringSubmatch(message) - if match == nil { - connection.Write( - "501 Could not parse forward-path", - ) - } else { - connection.AddRecipient(match[1]) - err := connection.Write( - "250 OK\n", - ) - if err != nil { - return false, err - } - } - - return false, nil -} - -func Data(connection Connection, message string) (bool, error) { - err := connection.Write( - "354 Start Input\n", - ) - if err != nil { - return false, err - } - data, err := connection.Read("\r\n.\r\n") - connection.MailBuffer = &data - err = connection.Write( - "250 OK\n", - ) - connection.Reset() - if err != nil { - return false, err - } - - return false, nil -} - -func Quit(connection Connection, message string) (bool, error) { - err := connection.Write( - "221 Goodbye :)\n", - ) - if err != nil { - return true, err - } - - return true, nil -} - -func Noop(connection Connection, message string) (bool, error) { - err := connection.Write( - "250 OK\n", - ) - if err != nil { - return false, err - } - - return false, nil -} - -func Rset(connection Connection, message string) (bool, error) { - err := connection.Write( - "250 Reset\n", - ) - connection.Reset() - if err != nil { - return false, err - } - - return false, nil -} diff --git a/smtp/server.go b/smtp/server.go index 5ace672..16a2727 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -30,14 +30,26 @@ type PlainListener struct { listener net.Listener } -func handle(connection Connection) { - log.Info().Msgf("New connection %v", connection.RemoteAddr()) +func handle(connection net.Conn) { + log.Info().Msgf( + "New connection %v. Starting session.", + connection.RemoteAddr(), + ) defer connection.Close() - err := connection.Chain() + + session := MakeSMTPSession(connection) + err := session.Run() if err != nil { - log.Error().Msgf("Failed to serve %v: %v", connection.RemoteAddr(), err) + log.Error().Msgf( + "Session %v exited with error: %v", + connection.RemoteAddr(), + err, + ) } else { - log.Info().Msgf("Successfully served %v", connection.RemoteAddr()) + log.Info().Msgf( + "Session %v exited successfully", + connection.RemoteAddr(), + ) } } @@ -60,7 +72,7 @@ func Run(host string, implicit_tls bool) error { return err } - go handle(NewConnection(connection)) + go handle(connection) } }