diff --git a/smtp/commands.go b/smtp/commands.go
new file mode 100644
index 0000000..45c40a4
--- /dev/null
+++ b/smtp/commands.go
@@ -0,0 +1,168 @@
+/* diodemail - send-only smtp server
+ * Copyright (c) 2024 Gnarwhal
+ *
+ * This file is part of SSHare.
+ *
+ * SSHare is free software: you can redistribute it and/or modify it under the terms of
+ * the GNU General Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ *
+ * SSHare is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * SSHare. If not, see .
+ */
+
+package smtp
+
+import (
+ "fmt"
+ "strings"
+)
+
+type Command struct {
+ name string
+ Exec func(SMTPSession, string)(bool, error)
+}
+
+func (self Command) Name() string {
+ return self.name
+}
+
+func (self Command) Check(message string) bool {
+ return strings.HasPrefix(message, self.name)
+}
+
+func Greet(smtp_session SMTPSession) error {
+ err := smtp_session.Write(
+ "220 localhost ESMTPSession diodemail -- Service ready" + "\n",
+ )
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func Helo(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ fmt.Sprintf(
+ "250 %v is shy" + "\r\n",
+ smtp_session.GetAddress(),
+ ),
+ )
+ if err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func Ehlo(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ fmt.Sprintf(
+ "250 %v is shy" + "\r\n",
+ smtp_session.GetAddress(),
+ ),
+ )
+ if err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func MailFrom(smtp_session SMTPSession, message string) (bool, error) {
+ match := ReversePath.FindStringSubmatch(message)
+ if match == nil {
+ smtp_session.Write(
+ "501 Could not parse reverse-path",
+ )
+ } else if len(match) > 1 {
+ smtp_session.MailFrom(match[1])
+ err := smtp_session.Write(
+ "250 OK\n",
+ )
+ if err != nil {
+ return false, err
+ }
+ }
+
+ return false, nil
+}
+
+func RcptTo(smtp_session SMTPSession, message string) (bool, error) {
+ match := ForwardPath.FindStringSubmatch(message)
+ if match == nil {
+ smtp_session.Write(
+ "501 Could not parse forward-path",
+ )
+ } else {
+ smtp_session.AddRecipient(match[1])
+ err := smtp_session.Write(
+ "250 OK\n",
+ )
+ if err != nil {
+ return false, err
+ }
+ }
+
+ return false, nil
+}
+
+func Data(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ "354 Start Input\n",
+ )
+ if err != nil {
+ return false, err
+ }
+ data, err := smtp_session.Read("\r\n.\r\n")
+ smtp_session.SetMailBuffer(data)
+ err = smtp_session.Write(
+ "250 OK\n",
+ )
+ smtp_session.Reset()
+ if err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func Quit(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ "221 Goodbye :)\n",
+ )
+ if err != nil {
+ return true, err
+ }
+
+ return true, nil
+}
+
+func Noop(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ "250 OK\n",
+ )
+ if err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func Rset(smtp_session SMTPSession, message string) (bool, error) {
+ err := smtp_session.Write(
+ "250 Reset\n",
+ )
+ smtp_session.Reset()
+ if err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
diff --git a/smtp/connection.go b/smtp/connection.go
index 7fbee1d..a1dfdee 100644
--- a/smtp/connection.go
+++ b/smtp/connection.go
@@ -20,13 +20,13 @@ package smtp
import (
"net"
- // "fmt"
+ "fmt"
"strings"
"github.com/rs/zerolog/log"
)
-type Connection struct {
+type SMTPSession struct {
connection net.Conn
ReversePathBuffer *string
@@ -34,8 +34,8 @@ type Connection struct {
MailBuffer *string
}
-func NewConnection(connection net.Conn) Connection {
- return Connection{
+func MakeSMTPSession(connection net.Conn) SMTPSession {
+ return SMTPSession{
connection,
nil,
[]string{},
@@ -43,8 +43,62 @@ func NewConnection(connection net.Conn) Connection {
}
}
+func (self SMTPSession) Run() error {
+ err := Greet(self)
+ if err != nil {
+ return err
+ }
+ COMMANDS := []Command{
+ Command{ "HELO", Helo },
+ Command{ "EHLO", Ehlo },
+ Command{ "MAIL FROM:", MailFrom },
+ Command{ "RCPT TO:", RcptTo },
+ Command{ "DATA", Data },
+ Command{ "QUIT", Quit },
+ Command{ "NOOP", Noop },
+ Command{ "RSET", Rset },
+ }
+ quit := false
+ for !quit {
+ message, err := self.Read("\n")
+ if err != nil {
+ return err
+ }
+
+ command_found := false
+ for _, command := range COMMANDS {
+ if command.Check(message) {
+ command_found = true
+ quit, err = command.Exec(self, message[len(command.Name()):])
+ if err != nil {
+ return err
+ }
+ break
+ }
+ }
+ if !command_found {
+ self.Write("500 Unrecognized Command\n")
+ }
+ }
+ return nil
+}
+
+func (self SMTPSession) TraceSession(direction string, message string) {
+ lines := strings.Split(message, "\n")
+ for index, line := range lines {
+ if index != len(lines) - 1 || line != "" {
+ if index == 0 {
+ log.Trace().Msgf("%v %v %v", self.connection.RemoteAddr(), direction, line)
+ } else {
+ log.Trace().Msgf("%v %v", self.connection.RemoteAddr(), line)
+ }
+ }
+ }
+
+}
+
var buffer = [1024]byte{}
-func (self Connection) Read(delimiter string) (string, error) {
+func (self SMTPSession) Read(delimiter string) (string, error) {
var message string
for !strings.Contains(message, delimiter) {
num_read, err := self.connection.Read(buffer[:])
@@ -53,50 +107,36 @@ func (self Connection) Read(delimiter string) (string, error) {
}
message += string(buffer[:num_read])
}
-
- lines := strings.Split(message, "\n")
- for index, line := range lines {
- if index != len(lines) - 1 || line != "" {
- if index == 0 {
- log.Trace().Msgf("%v -> %v", self.RemoteAddr(), line)
- } else {
- log.Trace().Msgf("%v %v", self.RemoteAddr(), line)
- }
- }
- }
+ self.TraceSession(" ->", message)
return message[:len(message) - len(delimiter)], nil
}
-func (self Connection) Write(message string) error {
- lines := strings.Split(message, "\n")
- for index, line := range lines {
- if index != len(lines) - 1 || line != "" {
- if index == 0 {
- log.Trace().Msgf("%v <- %v", self.RemoteAddr(), line)
- } else {
- log.Trace().Msgf("%v %v", self.RemoteAddr(), line)
- }
- }
- }
+func (self SMTPSession) Write(message string) error {
+ self.TraceSession("<- ", message)
_, err := self.connection.Write([]byte(message))
return err
}
-func (self Connection) AddRecipient(recipient string) {
+func (self SMTPSession) GetAddress() string {
+ return fmt.Sprint(self.connection.LocalAddr())
+}
+
+func (self SMTPSession) MailFrom(from string) {
+ self.Reset()
+ self.ReversePathBuffer = &from
+}
+
+func (self SMTPSession) AddRecipient(recipient string) {
self.ForwardPathBuffer = append(self.ForwardPathBuffer, recipient)
}
-func (self Connection) Reset() {
+func (self SMTPSession) SetMailBuffer(data string) {
+ self.MailBuffer = &data
+}
+
+func (self SMTPSession) Reset() {
self.ReversePathBuffer = nil
self.ForwardPathBuffer = []string{}
self.MailBuffer = nil
}
-
-func (self Connection) Close() {
- self.connection.Close()
-}
-
-func (self Connection) RemoteAddr() net.Addr {
- return self.connection.RemoteAddr()
-}
diff --git a/smtp/handlers.go b/smtp/handlers.go
deleted file mode 100644
index 8ba5560..0000000
--- a/smtp/handlers.go
+++ /dev/null
@@ -1,214 +0,0 @@
-/* diodemail - send-only smtp server
- * Copyright (c) 2024 Gnarwhal
- *
- * This file is part of SSHare.
- *
- * SSHare is free software: you can redistribute it and/or modify it under the terms of
- * the GNU General Public License as published by the Free Software Foundation,
- * either version 3 of the License, or (at your option) any later version.
- *
- * SSHare is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY
- * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along with
- * SSHare. If not, see .
- */
-
-package smtp
-
-import (
- "fmt"
- "strings"
-)
-
-type Command struct {
- name string
- Exec func(Connection, string)(bool, error)
-}
-
-func (self Command) Name() string {
- return self.name
-}
-
-func (self Command) Check(message string) bool {
- return strings.HasPrefix(message, self.name)
-}
-
-func (self Connection) Chain() error {
- err := Greet(self)
- if err != nil {
- return err
- }
- COMMANDS := []Command{
- Command{ "HELO", Helo },
- Command{ "EHLO", Ehlo },
- Command{ "MAIL FROM:", MailFrom },
- Command{ "RCPT TO:", RcptTo },
- Command{ "DATA", Data },
- Command{ "QUIT", Quit },
- Command{ "NOOP", Noop },
- Command{ "RSET", Rset },
- }
- quit := false
- for !quit {
- message, err := self.Read("\n")
- if err != nil {
- return err
- }
-
- command_found := false
- for _, command := range COMMANDS {
- if command.Check(message) {
- command_found = true
- quit, err = command.Exec(self, message[len(command.Name()):])
- if err != nil {
- return err
- }
- break
- }
- }
- if !command_found {
- self.Write("500 Unrecognized Command\n")
- }
- }
- return nil
-}
-
-/* --- GREETING RESPONSE --- */
-
-func Greet(connection Connection) error {
- err := connection.Write(
- "220 localhost ESMTP diodemail -- Service ready" + "\n",
- )
- if err != nil {
- return err
- }
-
- return nil
-}
-
-/* --- HELO/EHLO RESPONSE --- */
-
-func Helo(connection Connection, message string) (bool, error) {
- err := connection.Write(
- fmt.Sprintf(
- "250 %v is shy" + "\r\n",
- connection.connection.LocalAddr(),
- ),
- )
- if err != nil {
- return false, err
- }
-
- return false, nil
-}
-
-func Ehlo(connection Connection, message string) (bool, error) {
- err := connection.Write(
- fmt.Sprintf(
- "250 %v is shy" + "\r\n",
- connection.connection.LocalAddr(),
- ),
- )
- if err != nil {
- return false, err
- }
-
- return false, nil
-}
-
-/* --- MAIL FROM RESPONSE --- */
-
-func MailFrom(connection Connection, message string) (bool, error) {
- match := ReversePath.FindStringSubmatch(message)
- if match == nil {
- connection.Write(
- "501 Could not parse reverse-path",
- )
- } else if len(match) > 1 {
- connection.ReversePathBuffer = &match[1]
- err := connection.Write(
- "250 OK\n",
- )
- if err != nil {
- return false, err
- }
- }
-
- return false, nil
-}
-
-func RcptTo(connection Connection, message string) (bool, error) {
- match := ForwardPath.FindStringSubmatch(message)
- if match == nil {
- connection.Write(
- "501 Could not parse forward-path",
- )
- } else {
- connection.AddRecipient(match[1])
- err := connection.Write(
- "250 OK\n",
- )
- if err != nil {
- return false, err
- }
- }
-
- return false, nil
-}
-
-func Data(connection Connection, message string) (bool, error) {
- err := connection.Write(
- "354 Start Input\n",
- )
- if err != nil {
- return false, err
- }
- data, err := connection.Read("\r\n.\r\n")
- connection.MailBuffer = &data
- err = connection.Write(
- "250 OK\n",
- )
- connection.Reset()
- if err != nil {
- return false, err
- }
-
- return false, nil
-}
-
-func Quit(connection Connection, message string) (bool, error) {
- err := connection.Write(
- "221 Goodbye :)\n",
- )
- if err != nil {
- return true, err
- }
-
- return true, nil
-}
-
-func Noop(connection Connection, message string) (bool, error) {
- err := connection.Write(
- "250 OK\n",
- )
- if err != nil {
- return false, err
- }
-
- return false, nil
-}
-
-func Rset(connection Connection, message string) (bool, error) {
- err := connection.Write(
- "250 Reset\n",
- )
- connection.Reset()
- if err != nil {
- return false, err
- }
-
- return false, nil
-}
diff --git a/smtp/server.go b/smtp/server.go
index 5ace672..16a2727 100644
--- a/smtp/server.go
+++ b/smtp/server.go
@@ -30,14 +30,26 @@ type PlainListener struct {
listener net.Listener
}
-func handle(connection Connection) {
- log.Info().Msgf("New connection %v", connection.RemoteAddr())
+func handle(connection net.Conn) {
+ log.Info().Msgf(
+ "New connection %v. Starting session.",
+ connection.RemoteAddr(),
+ )
defer connection.Close()
- err := connection.Chain()
+
+ session := MakeSMTPSession(connection)
+ err := session.Run()
if err != nil {
- log.Error().Msgf("Failed to serve %v: %v", connection.RemoteAddr(), err)
+ log.Error().Msgf(
+ "Session %v exited with error: %v",
+ connection.RemoteAddr(),
+ err,
+ )
} else {
- log.Info().Msgf("Successfully served %v", connection.RemoteAddr())
+ log.Info().Msgf(
+ "Session %v exited successfully",
+ connection.RemoteAddr(),
+ )
}
}
@@ -60,7 +72,7 @@ func Run(host string, implicit_tls bool) error {
return err
}
- go handle(NewConnection(connection))
+ go handle(connection)
}
}