Slight reworking of config file. Connect to MTAs via port 25 and use certificate for StartTLS.

This commit is contained in:
Gnarwhal 2024-10-04 20:41:21 +00:00
parent e8979dc08e
commit 1750872280
Signed by: Gnarwhal
GPG key ID: 0989A73D8C421174
4 changed files with 67 additions and 77 deletions

View file

@ -24,39 +24,26 @@ import (
)
type Config struct {
General GeneralConfig
Plain PlainConfig
TLS TLSConfig
}
type GeneralConfig struct {
LogLevel string
Host string
PasswordHash string
}
type PlainConfig struct {
Enabled bool
Port string
}
type TLSConfig struct {
Enabled bool
Port string
Plain string
TLS string
CertPath string
PrivateKeyPath string
}
type GeneralConfig struct {
}
func LoadConfig(path string) (*Config, error) {
contents, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var config Config
config.Plain.Enabled = false
config.Plain.Port = "25"
config.TLS.Enabled = false
config.TLS.Port = "465"
config.Plain = "disabled"
config.TLS = "disabled"
err = json.Unmarshal(contents, &config)
if err != nil {
return nil, err

View file

@ -62,8 +62,8 @@ func main() {
log.Fatal().Msgf("%v", err)
}
if log_level == "" {
if config.General.LogLevel != "" {
log_level = config.General.LogLevel
if config.LogLevel != "" {
log_level = config.LogLevel
} else {
log_level = "info"
}
@ -83,34 +83,28 @@ func main() {
log.Info().Msgf("Starting diodemail v%v", Version)
log.Info().Msgf("Loaded config from: %v", config_path)
var plain_config *smtp.PlainConfig
if config.Plain.Enabled {
plain_config = &smtp.PlainConfig {
config.Plain.Port,
}
if config.CertPath == "" || config.PrivateKeyPath == "" {
log.Fatal().Msgf(
"Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')",
config.CertPath,
config.PrivateKeyPath,
)
}
var tls_config *smtp.TLSConfig
if config.TLS.Enabled {
certificate, err := tls.LoadX509KeyPair(
config.TLS.CertPath,
config.TLS.PrivateKeyPath,
config.CertPath,
config.PrivateKeyPath,
)
if err != nil {
log.Error().Msgf("Failed to load TLS config: %v", err)
} else {
tls_config = &smtp.TLSConfig {
config.TLS.Port,
tls.Config{Certificates: []tls.Certificate{certificate}},
}
}
log.Fatal().Msgf("Failed to load TLS config: %v", err)
}
err = smtp.Run(
config.General.Host,
config.General.PasswordHash,
plain_config,
tls_config,
config.Host,
config.PasswordHash,
config.Plain,
config.TLS,
tls.Config{Certificates: []tls.Certificate{certificate}},
)
if err != nil {
log.Fatal().Msgf("%v", err)

View file

@ -27,23 +27,14 @@ import (
"github.com/rs/zerolog/log"
)
type PlainConfig struct {
Port string
}
type TLSConfig struct {
Port string
TlsConfig tls.Config
}
func handle(connection net.Conn, host string, password_hash string) {
func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) {
log.Info().Msgf(
"New connection %v. Starting session.",
connection.RemoteAddr(),
)
defer connection.Close()
session := MakeSMTPSession(connection, host, password_hash)
session := MakeSMTPSession(connection, host, password_hash, tls_config)
err := session.Run()
if err != nil {
log.Error().Msgf(
@ -59,36 +50,51 @@ func handle(connection net.Conn, host string, password_hash string) {
}
}
func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error {
func Run(
host string,
password_hash string,
plain_port string,
tls_port string,
tls_config tls.Config,
) error {
var wait_group sync.WaitGroup
if plain_config != nil {
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_config.Port))
if plain_port != "disabled" {
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", plain_port))
if err != nil {
return err
}
log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host)
log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host)
wait_group.Add(1)
go Listen(wait_group, host, password_hash, listener)
go Listen(wait_group, host, password_hash, tls_config, listener)
}
if tls_config != nil {
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_config.Port), &tls_config.TlsConfig)
if tls_port != "disabled" {
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config)
if err != nil {
return err
}
log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host)
log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host)
wait_group.Add(1)
go Listen(wait_group, host, password_hash, listener)
go Listen(wait_group, host, password_hash, tls_config, listener)
}
wait_group.Wait()
return nil
}
func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) {
func Listen(
wait_group sync.WaitGroup,
host string,
password_hash string,
tls_config tls.Config,
listener net.Listener,
) {
defer wait_group.Done()
for {
connection, _ := listener.Accept()
go handle(connection, host, password_hash)
connection, err := listener.Accept()
if err != nil {
log.Error().Msgf("Failed to accept client: %v", err)
} else {
go handle(connection, host, password_hash, tls_config)
}
}
}

View file

@ -33,13 +33,14 @@ type SMTPSession struct {
connection net.Conn
host string
password_hash string
tls_config tls.Config
buffer [4096]byte
ReversePathBuffer *string
ForwardPathBuffer []string
}
func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession {
func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession {
return SMTPSession{
connection: connection,
host: host,
@ -161,11 +162,7 @@ func (self *SMTPSession) SendMail(data string) error {
} else {
smtp_hostname = mx[0].Host
}
tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:465", smtp_hostname), nil)
if err != nil {
return err
}
smtp_client, err := smtp.NewClient(tls_connection, "")
smtp_client, err := smtp.Dial(fmt.Sprintf("%v:25", smtp_hostname))
if err != nil {
return err
}
@ -175,6 +172,12 @@ func (self *SMTPSession) SendMail(data string) error {
if err != nil {
return err
}
tls_config := self.tls_config
tls_config.ServerName = smtp_hostname
err = smtp_client.StartTLS(&tls_config)
if err != nil {
return err
}
err = smtp_client.Mail(*self.ReversePathBuffer)
if err != nil {
return err