diodemail/smtp/session.go
2024-10-03 15:12:20 +00:00

209 lines
4.8 KiB
Go

/* diodemail - send-only smtp server
* Copyright (c) 2024 Gnarwhal
*
* This file is part of diodemail.
*
* diodemail is free software: you can redistribute it and/or modify it under the terms of
* the GNU General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* diodemail is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along with
* diodemail. If not, see <https://www.gnu.org/licenses/>.
*/
package smtp
import (
"net"
"net/smtp"
"fmt"
"strings"
"crypto/tls"
"crypto/sha256"
"github.com/rs/zerolog/log"
)
// SMTPSession holds the state of one client connection: the socket itself,
// the server identity and credential used by the command handlers, and the
// envelope (sender and recipients) of the message currently being composed.
type SMTPSession struct {
// connection is the client socket that ReadUntil/ReadCount read from and
// Write writes to.
connection net.Conn
// host is returned by GetHost and is the name SendMail announces in its
// HELO/EHLO to the remote server.
host string
// password_hash is compared by ValidatePassword against the lowercase hex
// SHA-256 digest of a candidate password.
password_hash string
// buffer is scratch space that ReadUntil reads into.
buffer [4096]byte
// ReversePathBuffer is the envelope sender; set by MailFrom, nil after Reset.
ReversePathBuffer *string
// ForwardPathBuffer is the list of envelope recipients appended by
// AddRecipient and emptied by Reset. SendMail only uses the first entry.
ForwardPathBuffer []string
}
// MakeSMTPSession returns a session bound to connection. host and
// password_hash are stored as given; the scratch buffer and the envelope
// fields are left at their zero values.
func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession {
	var session SMTPSession
	session.connection = connection
	session.host = host
	session.password_hash = password_hash
	return session
}
// Run drives the session until it ends. It first calls Greet, then repeatedly
// reads one newline-terminated message and dispatches it to the first entry
// of the command table whose Check accepts it, passing the handler whatever
// follows the command name.
//
// A handler reports (quit, err). When quit is true the loop ends and err (nil
// or not) is returned. When quit is false a non-nil err is only logged and
// the session carries on. Messages matching no command are answered with a
// 500 reply.
//
// Run returns nil after a clean quit, or the first read/write error or fatal
// handler error otherwise.
func (self *SMTPSession) Run() error {
	err := Greet(self)
	if err != nil {
		return err
	}

	// Order matters: the first matching entry wins.
	commands := []Command{
		{"HELO", Helo},
		{"EHLO", Ehlo},
		{"AUTH", Auth},
		{"MAIL FROM:", MailFrom},
		{"RCPT TO:", RcptTo},
		{"DATA", Data},
		{"QUIT", Quit},
		{"NOOP", Noop},
		{"RSET", Rset},
	}

	quit := false
	for !quit {
		message, err := self.ReadUntil("\n")
		if err != nil {
			return err
		}

		command_found := false
		for _, command := range commands {
			if !command.Check(message) {
				continue
			}
			command_found = true
			quit, err = command.Exec(self, message[len(command.GetName()):])
			if err != nil {
				if quit {
					return err
				}
				// Non-fatal handler failure: record it and keep serving.
				log.Error().Msgf("%v: %v", self.connection.RemoteAddr(), err)
			}
			break
		}

		if !command_found {
			// A failed reply means the connection is unusable, so surface
			// the error instead of silently dropping it.
			if err := self.Write("500 Unrecognized Command\n"); err != nil {
				return err
			}
		}
	}
	return nil
}
// TraceSession writes message to the trace log, one log entry per line.
// Every entry is prefixed with the peer address; only the first line also
// carries direction. An empty final segment (the part after a trailing
// newline) is not logged.
func (self *SMTPSession) TraceSession(direction string, message string) {
	segments := strings.Split(message, "\n")
	final := len(segments) - 1
	for position, text := range segments {
		// Skip the empty remainder that follows a trailing newline.
		if position == final && text == "" {
			continue
		}
		if position > 0 {
			log.Trace().Msgf("%v %v", self.connection.RemoteAddr(), text)
			continue
		}
		log.Trace().Msgf("%v %v %v", self.connection.RemoteAddr(), direction, text)
	}
}
// ReadUntil reads from the connection until delimiter has been received and
// returns everything that precedes its first occurrence. The delimiter itself
// is not part of the result. Everything that was read is passed to
// TraceSession before returning.
//
// An empty delimiter yields "" without reading. If a read fails before the
// delimiter has been seen, "" and the error are returned.
//
// NOTE: the session keeps no read-ahead buffer, so any bytes that arrive in
// the same read after the delimiter are discarded.
func (self *SMTPSession) ReadUntil(delimiter string) (string, error) {
	if delimiter == "" {
		return "", nil
	}

	var received strings.Builder
	// Everything before search_from has already been scanned for the delimiter.
	search_from := 0
	for {
		num_read, err := self.connection.Read(self.buffer[:])
		// A Read may hand back data together with an error, so always
		// consume the bytes before looking at err.
		received.Write(self.buffer[:num_read])

		message := received.String()
		if offset := strings.Index(message[search_from:], delimiter); offset >= 0 {
			self.TraceSession(" ->", message)
			// Cut at the delimiter's actual position rather than assuming
			// it is the very last thing that was received.
			return message[:search_from+offset], nil
		}
		if err != nil {
			return "", err
		}

		// The delimiter may straddle two reads: resume the search just far
		// enough back to catch a partially received delimiter.
		if resume := len(message) - len(delimiter) + 1; resume > search_from {
			search_from = resume
		}
	}
}
// ReadCount reads exactly num_bytes bytes from the connection.
//
// A single Read may deliver fewer bytes than requested, so reading continues
// until the buffer is full. If the connection fails first, the partially
// filled buffer (still num_bytes long) is returned together with the error.
func (self *SMTPSession) ReadCount(num_bytes int) ([]byte, error) {
	data := make([]byte, num_bytes)
	for filled := 0; filled < num_bytes; {
		num_read, err := self.connection.Read(data[filled:])
		filled += num_read
		// An error that accompanies the final bytes is irrelevant: the
		// caller got everything it asked for.
		if err != nil && filled < num_bytes {
			return data, err
		}
	}
	return data, nil
}
// Write traces message as outgoing traffic and then sends it to the client,
// returning any error reported by the connection.
func (self *SMTPSession) Write(message string) error {
	self.TraceSession("<- ", message)
	payload := []byte(message)
	if _, err := self.connection.Write(payload); err != nil {
		return err
	}
	return nil
}
// GetHost returns the host name this session was created with.
func (self *SMTPSession) GetHost() string {
	host := self.host
	return host
}
// ValidatePassword reports whether the lowercase hex SHA-256 digest of
// password equals the session's stored password hash.
//
// NOTE(review): the comparison is a plain string equality, not a
// constant-time one.
func (self *SMTPSession) ValidatePassword(password string) bool {
	digest := sha256.Sum256([]byte(password))
	candidate := fmt.Sprintf("%x", digest)
	return candidate == self.password_hash
}
// MailFrom starts a new envelope: it resets the session's sender and
// recipient state and then records from as the reverse path.
func (self *SMTPSession) MailFrom(from string) {
	self.Reset()
	reverse_path := &from
	self.ReversePathBuffer = reverse_path
}
// AddRecipient appends recipient to the envelope's list of forward paths.
func (self *SMTPSession) AddRecipient(recipient string) {
	recipients := append(self.ForwardPathBuffer, recipient)
	self.ForwardPathBuffer = recipients
}
// SendMail relays data to the mail server responsible for the first recipient
// in ForwardPathBuffer, using ReversePathBuffer as the envelope sender.
//
// The recipient's domain is extracted with the Domain pattern and resolved
// through its MX records. If the domain publishes no MX records, the domain
// itself is contacted. The connection is made with implicit TLS on port 465.
//
// Only the first recipient is delivered to; further entries in
// ForwardPathBuffer are not used.
//
// An error is returned if no sender or recipient has been set, if the
// recipient contains no recognizable domain, or if any step of the lookup,
// connection or SMTP exchange fails.
func (self *SMTPSession) SendMail(data string) error {
	// Validate the envelope up front instead of panicking on a nil sender,
	// an empty recipient list or a recipient the pattern does not match.
	if self.ReversePathBuffer == nil {
		return fmt.Errorf("send mail: no reverse path has been set")
	}
	if len(self.ForwardPathBuffer) == 0 {
		return fmt.Errorf("send mail: no recipients have been added")
	}
	recipient := self.ForwardPathBuffer[0]

	domain_match := Domain.FindStringSubmatch(recipient)
	if len(domain_match) < 2 {
		return fmt.Errorf("send mail: no domain found in recipient %q", recipient)
	}
	mx_hostname := domain_match[1]

	mx, err := net.LookupMX(mx_hostname)
	if err != nil && len(mx) == 0 {
		// The resolver reports "no MX records" as a not-found error. That
		// case must fall through to contacting the domain directly; any
		// other lookup failure is fatal.
		dns_err, is_dns_err := err.(*net.DNSError)
		if !is_dns_err || !dns_err.IsNotFound {
			return err
		}
	}

	smtp_hostname := mx_hostname
	if len(mx) > 0 {
		// LookupMX returns records sorted by preference.
		smtp_hostname = mx[0].Host
	}

	tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:465", smtp_hostname), nil)
	if err != nil {
		return err
	}
	smtp_client, err := smtp.NewClient(tls_connection, "")
	if err != nil {
		return err
	}
	defer smtp_client.Close()

	if err := smtp_client.Hello(self.host); err != nil {
		return err
	}
	if err := smtp_client.Mail(*self.ReversePathBuffer); err != nil {
		return err
	}
	if err := smtp_client.Rcpt(recipient); err != nil {
		return err
	}

	writer, err := smtp_client.Data()
	if err != nil {
		return err
	}
	if _, err := writer.Write([]byte(data)); err != nil {
		return err
	}
	if err := writer.Close(); err != nil {
		return err
	}

	return smtp_client.Quit()
}
// Reset discards the envelope being built: the recipient list becomes an
// empty (non-nil) slice and the reverse path is cleared.
func (self *SMTPSession) Reset() {
	self.ForwardPathBuffer = make([]string, 0)
	self.ReversePathBuffer = nil
}