Slight reworking of config file. Connect to MTAs via port 25 and use certificate for StartTLS.

This commit is contained in:
Gnarwhal 2024-10-04 20:41:21 +00:00
parent e8979dc08e
commit 1750872280
Signed by: Gnarwhal
GPG key ID: 0989A73D8C421174
4 changed files with 67 additions and 77 deletions

View file

@ -24,39 +24,26 @@ import (
) )
type Config struct { type Config struct {
General GeneralConfig
Plain PlainConfig
TLS TLSConfig
}
type GeneralConfig struct {
LogLevel string LogLevel string
Host string Host string
PasswordHash string PasswordHash string
} Plain string
TLS string
type PlainConfig struct {
Enabled bool
Port string
}
type TLSConfig struct {
Enabled bool
Port string
CertPath string CertPath string
PrivateKeyPath string PrivateKeyPath string
} }
type GeneralConfig struct {
}
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (*Config, error) {
contents, err := os.ReadFile(path) contents, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var config Config var config Config
config.Plain.Enabled = false config.Plain = "disabled"
config.Plain.Port = "25" config.TLS = "disabled"
config.TLS.Enabled = false
config.TLS.Port = "465"
err = json.Unmarshal(contents, &config) err = json.Unmarshal(contents, &config)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -62,8 +62,8 @@ func main() {
log.Fatal().Msgf("%v", err) log.Fatal().Msgf("%v", err)
} }
if log_level == "" { if log_level == "" {
if config.General.LogLevel != "" { if config.LogLevel != "" {
log_level = config.General.LogLevel log_level = config.LogLevel
} else { } else {
log_level = "info" log_level = "info"
} }
@ -83,34 +83,28 @@ func main() {
log.Info().Msgf("Starting diodemail v%v", Version) log.Info().Msgf("Starting diodemail v%v", Version)
log.Info().Msgf("Loaded config from: %v", config_path) log.Info().Msgf("Loaded config from: %v", config_path)
var plain_config *smtp.PlainConfig if config.CertPath == "" || config.PrivateKeyPath == "" {
if config.Plain.Enabled { log.Fatal().Msgf(
plain_config = &smtp.PlainConfig { "Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')",
config.Plain.Port, config.CertPath,
} config.PrivateKeyPath,
)
} }
var tls_config *smtp.TLSConfig
if config.TLS.Enabled {
certificate, err := tls.LoadX509KeyPair( certificate, err := tls.LoadX509KeyPair(
config.TLS.CertPath, config.CertPath,
config.TLS.PrivateKeyPath, config.PrivateKeyPath,
) )
if err != nil { if err != nil {
log.Error().Msgf("Failed to load TLS config: %v", err) log.Fatal().Msgf("Failed to load TLS config: %v", err)
} else {
tls_config = &smtp.TLSConfig {
config.TLS.Port,
tls.Config{Certificates: []tls.Certificate{certificate}},
}
}
} }
err = smtp.Run( err = smtp.Run(
config.General.Host, config.Host,
config.General.PasswordHash, config.PasswordHash,
plain_config, config.Plain,
tls_config, config.TLS,
tls.Config{Certificates: []tls.Certificate{certificate}},
) )
if err != nil { if err != nil {
log.Fatal().Msgf("%v", err) log.Fatal().Msgf("%v", err)

View file

@ -27,23 +27,14 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type PlainConfig struct { func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) {
Port string
}
type TLSConfig struct {
Port string
TlsConfig tls.Config
}
func handle(connection net.Conn, host string, password_hash string) {
log.Info().Msgf( log.Info().Msgf(
"New connection %v. Starting session.", "New connection %v. Starting session.",
connection.RemoteAddr(), connection.RemoteAddr(),
) )
defer connection.Close() defer connection.Close()
session := MakeSMTPSession(connection, host, password_hash) session := MakeSMTPSession(connection, host, password_hash, tls_config)
err := session.Run() err := session.Run()
if err != nil { if err != nil {
log.Error().Msgf( log.Error().Msgf(
@ -59,36 +50,51 @@ func handle(connection net.Conn, host string, password_hash string) {
} }
} }
func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error { func Run(
host string,
password_hash string,
plain_port string,
tls_port string,
tls_config tls.Config,
) error {
var wait_group sync.WaitGroup var wait_group sync.WaitGroup
if plain_config != nil { if plain_port != "disabled" {
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_config.Port)) listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_port))
if err != nil { if err != nil {
return err return err
} }
log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host) log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, password_hash, listener) go Listen(wait_group, host, password_hash, tls_config, listener)
} }
if tls_config != nil { if tls_port != "disabled" {
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_config.Port), &tls_config.TlsConfig) listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config)
if err != nil { if err != nil {
return err return err
} }
log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host) log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host)
wait_group.Add(1) wait_group.Add(1)
go Listen(wait_group, host, password_hash, listener) go Listen(wait_group, host, password_hash, tls_config, listener)
} }
wait_group.Wait() wait_group.Wait()
return nil return nil
} }
func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) { func Listen(
wait_group sync.WaitGroup,
host string,
password_hash string,
tls_config tls.Config,
listener net.Listener,
) {
defer wait_group.Done() defer wait_group.Done()
for { for {
connection, _ := listener.Accept() connection, err := listener.Accept()
if err != nil {
go handle(connection, host, password_hash) log.Error().Msgf("Failed to accept client: %v", err)
} else {
go handle(connection, host, password_hash, tls_config)
}
} }
} }

View file

@ -33,13 +33,14 @@ type SMTPSession struct {
connection net.Conn connection net.Conn
host string host string
password_hash string password_hash string
tls_config tls.Config
buffer [4096]byte buffer [4096]byte
ReversePathBuffer *string ReversePathBuffer *string
ForwardPathBuffer []string ForwardPathBuffer []string
} }
func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession { func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession {
return SMTPSession{ return SMTPSession{
connection: connection, connection: connection,
host: host, host: host,
@ -161,11 +162,7 @@ func (self *SMTPSession) SendMail(data string) error {
} else { } else {
smtp_hostname = mx[0].Host smtp_hostname = mx[0].Host
} }
tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:465", smtp_hostname), nil) smtp_client, err := smtp.Dial(fmt.Sprintf("%v:25", smtp_hostname))
if err != nil {
return err
}
smtp_client, err := smtp.NewClient(tls_connection, "")
if err != nil { if err != nil {
return err return err
} }
@ -175,6 +172,12 @@ func (self *SMTPSession) SendMail(data string) error {
if err != nil { if err != nil {
return err return err
} }
tls_config := self.tls_config
tls_config.ServerName = smtp_hostname
err = smtp_client.StartTLS(&tls_config)
if err != nil {
return err
}
err = smtp_client.Mail(*self.ReversePathBuffer) err = smtp_client.Mail(*self.ReversePathBuffer)
if err != nil { if err != nil {
return err return err