diff --git a/cmd/server/config.go b/cmd/server/config.go new file mode 100644 index 0000000..3fcd9eb --- /dev/null +++ b/cmd/server/config.go @@ -0,0 +1,59 @@ +/* diodemail - send-only smtp server + * Copyright (c) 2024 Gnarwhal + * + * This file is part of SSHare. + * + * SSHare is free software: you can redistribute it and/or modify it under the terms of + * the GNU General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * SSHare is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * SSHare. If not, see . + */ + +package main + +import ( + "os" + "encoding/json" +) + +type Config struct { + Host string + Plain PlainConfig + TLS TLSConfig +} + +type PlainConfig struct { + Enabled bool + Port string +} + +type TLSConfig struct { + Enabled bool + Port string + CertPath string + PrivateKeyPath string +} + +func LoadConfig(path string) (*Config, error) { + contents, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var config Config + config.Plain.Enabled = false + config.Plain.Port = "25" + config.TLS.Enabled = false + config.TLS.Port = "465" + err = json.Unmarshal(contents, &config) + if err != nil { + return nil, err + } + return &config, nil +} diff --git a/cmd/server/main.go b/cmd/server/main.go index 035e477..1530c09 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -20,15 +20,48 @@ package main import ( "log" + "flag" + "crypto/tls" "forge.monodon.me/Gnarwhal/diodemail/smtp" ) func main() { - err := smtp.Run( - "localhost", - "4650", - false, + var cert_path string + flag.StringVar(&cert_path, "config", "/etc/diodemail/config.json", "Path to config file") + flag.Parse() + config, err := LoadConfig(cert_path) + if err != nil { + log.Fatal(err) + } + + var plain_config *smtp.PlainConfig + if config.Plain.Enabled { + plain_config = &smtp.PlainConfig { + config.Plain.Port, + } + } + + var tls_config *smtp.TLSConfig + if config.TLS.Enabled { + certificate, err := tls.LoadX509KeyPair( + config.TLS.CertPath, + config.TLS.PrivateKeyPath, + ) + if err != nil { + log.Println(err) + } else { + tls_config = &smtp.TLSConfig { + config.TLS.Port, + tls.Config{Certificates: []tls.Certificate{certificate}}, + } + } + } + + err = smtp.Run( + config.Host, + plain_config, + tls_config, ) if err != nil { log.Fatal(err) diff --git a/smtp/server.go b/smtp/server.go index 1d07ad5..a1e997a 100644 --- a/smtp/server.go +++ b/smtp/server.go @@ -21,14 +21,21 @@ package smtp import ( "fmt" "net" + "crypto/tls" "os" + "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -type PlainListener struct { - listener net.Listener +type PlainConfig struct { + Port string +} + +type TLSConfig struct { + Port string + TlsConfig tls.Config } func handle(connection net.Conn, host string) { @@ -54,30 +61,43 @@ func handle(connection net.Conn, host string) { } } -func Run(host string, port string, implicit_tls bool) error { +func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error { log.Logger = zerolog. New(zerolog.ConsoleWriter{Out: os.Stderr}). With(). Timestamp(). Logger(). Level(zerolog.TraceLevel) - - listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, port)) - if err != nil { - return err - } - log.Info().Msgf("Server started on port %v for host %v", port, host) - for { - connection, err := listener.Accept() + var wait_group sync.WaitGroup + if plain_config != nil { + listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port)) if err != nil { return err } + log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host) + wait_group.Add(1) + go Listen(wait_group, host, listener) + } + if tls_config != nil { + listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig) + if err != nil { + return err + } + log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host) + wait_group.Add(1) + go Listen(wait_group, host, listener) + } + wait_group.Wait() + + return nil +} + +func Listen(wait_group sync.WaitGroup, host string, listener net.Listener) { + defer wait_group.Done() + for { + connection, _ := listener.Accept() go handle(connection, host) } } - -func (self PlainListener) Close() error { - return self.listener.Close() -} diff --git a/smtp/session.go b/smtp/session.go index 448ebbe..b4882b8 100644 --- a/smtp/session.go +++ b/smtp/session.go @@ -149,7 +149,7 @@ func (self *SMTPSession) SendMail(data string) error { } else { smtp_hostname = mx[0].Host } - tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:4560", smtp_hostname), nil) + tls_connection, err := tls.Dial("tcp", fmt.Sprintf("%v:465", smtp_hostname), nil) if err != nil { return err }