Actual AUTH and also enforce EHLO/HELO first
This commit is contained in:
parent
f656866e7d
commit
5ce65e6be9
5 changed files with 133 additions and 26 deletions
|
@ -26,14 +26,24 @@ import (
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel string
|
LogLevel string
|
||||||
Host string
|
Host string
|
||||||
PasswordHash string
|
Ports PortConfig
|
||||||
Plain string
|
Certificate CertConfig
|
||||||
TLS string
|
Auth AuthConfig
|
||||||
CertPath string
|
|
||||||
PrivateKeyPath string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeneralConfig struct {
|
type PortConfig struct {
|
||||||
|
Plain string
|
||||||
|
TLS string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertConfig struct {
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
PasswordHash string
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
|
@ -42,8 +52,9 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var config Config
|
var config Config
|
||||||
config.Plain = "disabled"
|
config.Ports.Plain = "disabled"
|
||||||
config.TLS = "disabled"
|
config.Ports.TLS = "disabled"
|
||||||
|
config.Auth.Enabled = true
|
||||||
err = json.Unmarshal(contents, &config)
|
err = json.Unmarshal(contents, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -83,28 +83,33 @@ func main() {
|
||||||
log.Info().Msgf("Starting diodemail v%v", Version)
|
log.Info().Msgf("Starting diodemail v%v", Version)
|
||||||
log.Info().Msgf("Loaded config from: %v", config_path)
|
log.Info().Msgf("Loaded config from: %v", config_path)
|
||||||
|
|
||||||
if config.CertPath == "" || config.PrivateKeyPath == "" {
|
if config.Certificate.CertFile == "" || config.Certificate.KeyFile == "" {
|
||||||
log.Fatal().Msgf(
|
log.Fatal().Msgf(
|
||||||
"Must provide CertPath (got '%v') and PrivateKeyPath (got '%v')",
|
"Must provide CertFile (got '%v') and KeyFile (got '%v')",
|
||||||
config.CertPath,
|
config.Certificate.CertFile,
|
||||||
config.PrivateKeyPath,
|
config.Certificate.KeyFile,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
certificate, err := tls.LoadX509KeyPair(
|
certificate, err := tls.LoadX509KeyPair(
|
||||||
config.CertPath,
|
config.Certificate.CertFile,
|
||||||
config.PrivateKeyPath,
|
config.Certificate.KeyFile,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Msgf("Failed to load TLS config: %v", err)
|
log.Fatal().Msgf("Failed to load TLS config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.Auth.Enabled && config.Auth.PasswordHash == "" {
|
||||||
|
log.Fatal().Msgf("Authentication is enabled but no password hash was supplied")
|
||||||
|
}
|
||||||
|
|
||||||
err = smtp.Run(
|
err = smtp.Run(
|
||||||
config.Host,
|
config.Host,
|
||||||
config.PasswordHash,
|
config.Ports.Plain,
|
||||||
config.Plain,
|
config.Ports.TLS,
|
||||||
config.TLS,
|
|
||||||
tls.Config{Certificates: []tls.Certificate{certificate}},
|
tls.Config{Certificates: []tls.Certificate{certificate}},
|
||||||
|
config.Auth.Enabled,
|
||||||
|
config.Auth.PasswordHash,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Msgf("%v", err)
|
log.Fatal().Msgf("%v", err)
|
||||||
|
|
|
@ -61,27 +61,43 @@ func Helo(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
smtp_session.HasHelloed = true
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
|
func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
|
enable_auth := ""
|
||||||
|
if smtp_session.RequiresAuthentication() {
|
||||||
|
enable_auth = "250-AUTH PLAIN\r\n"
|
||||||
|
}
|
||||||
err := smtp_session.Write(
|
err := smtp_session.Write(
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
"250-%v is shy" + "\r\n" +
|
"250-%v is shy" + "\r\n" +
|
||||||
"250-AUTH PLAIN" + "\r\n" +
|
"%v" +
|
||||||
"250-8BITMIME" + "\r\n" +
|
"250-8BITMIME" + "\r\n" +
|
||||||
"250 SMTPUTF8" + "\r\n",
|
"250 SMTPUTF8" + "\r\n",
|
||||||
smtp_session.GetHost(),
|
smtp_session.GetHost(),
|
||||||
|
enable_auth,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
smtp_session.HasHelloed = true
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Auth(smtp_session *SMTPSession, message string) (bool, error) {
|
func Auth(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
|
if !smtp_session.HasHelloed {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"503 Must HELO/EHLO first\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
parts := strings.Split(message, " ")[1:]
|
parts := strings.Split(message, " ")[1:]
|
||||||
if parts[0] == "PLAIN" {
|
if parts[0] == "PLAIN" {
|
||||||
var password string
|
var password string
|
||||||
|
@ -107,6 +123,24 @@ func Auth(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
|
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
|
if !smtp_session.HasHelloed {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"503 Must HELO/EHLO first\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !smtp_session.AuthenticationSatisfied() {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"530 Authentication required\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
match := ReversePath.FindStringSubmatch(message)
|
match := ReversePath.FindStringSubmatch(message)
|
||||||
if match == nil {
|
if match == nil {
|
||||||
err := smtp_session.Write(
|
err := smtp_session.Write(
|
||||||
|
@ -129,6 +163,24 @@ func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RcptTo(smtp_session *SMTPSession, message string) (bool, error) {
|
func RcptTo(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
|
if !smtp_session.HasHelloed {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"503 Must HELO/EHLO first\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !smtp_session.AuthenticationSatisfied() {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"530 Authentication required\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
match := ForwardPath.FindStringSubmatch(message)
|
match := ForwardPath.FindStringSubmatch(message)
|
||||||
if match == nil {
|
if match == nil {
|
||||||
err := smtp_session.Write(
|
err := smtp_session.Write(
|
||||||
|
@ -151,6 +203,24 @@ func RcptTo(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Data(smtp_session *SMTPSession, message string) (bool, error) {
|
func Data(smtp_session *SMTPSession, message string) (bool, error) {
|
||||||
|
if !smtp_session.HasHelloed {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"503 Must HELO/EHLO first\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !smtp_session.AuthenticationSatisfied() {
|
||||||
|
err := smtp_session.Write(
|
||||||
|
"530 Authentication required\n",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
err := smtp_session.Write(
|
err := smtp_session.Write(
|
||||||
"354 Start Input\n",
|
"354 Start Input\n",
|
||||||
)
|
)
|
||||||
|
|
|
@ -27,14 +27,20 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handle(connection net.Conn, host string, password_hash string, tls_config tls.Config) {
|
func handle(
|
||||||
|
connection net.Conn,
|
||||||
|
host string,
|
||||||
|
tls_config tls.Config,
|
||||||
|
auth bool,
|
||||||
|
password_hash string,
|
||||||
|
) {
|
||||||
log.Info().Msgf(
|
log.Info().Msgf(
|
||||||
"New connection %v. Starting session.",
|
"New connection %v. Starting session.",
|
||||||
connection.RemoteAddr(),
|
connection.RemoteAddr(),
|
||||||
)
|
)
|
||||||
defer connection.Close()
|
defer connection.Close()
|
||||||
|
|
||||||
session := MakeSMTPSession(connection, host, password_hash, tls_config)
|
session := MakeSMTPSession(connection, host, tls_config, auth, password_hash)
|
||||||
err := session.Run()
|
err := session.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf(
|
log.Error().Msgf(
|
||||||
|
@ -52,10 +58,11 @@ func handle(connection net.Conn, host string, password_hash string, tls_config t
|
||||||
|
|
||||||
func Run(
|
func Run(
|
||||||
host string,
|
host string,
|
||||||
password_hash string,
|
|
||||||
plain_port string,
|
plain_port string,
|
||||||
tls_port string,
|
tls_port string,
|
||||||
tls_config tls.Config,
|
tls_config tls.Config,
|
||||||
|
auth bool,
|
||||||
|
password_hash string,
|
||||||
) error {
|
) error {
|
||||||
var wait_group sync.WaitGroup
|
var wait_group sync.WaitGroup
|
||||||
if plain_port != "disabled" {
|
if plain_port != "disabled" {
|
||||||
|
@ -65,7 +72,7 @@ func Run(
|
||||||
}
|
}
|
||||||
log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host)
|
log.Info().Msgf("Plain text server started on port %v for host %v", plain_port, host)
|
||||||
wait_group.Add(1)
|
wait_group.Add(1)
|
||||||
go Listen(wait_group, host, password_hash, tls_config, listener)
|
go Listen(wait_group, host, tls_config, auth, password_hash, listener)
|
||||||
}
|
}
|
||||||
if tls_port != "disabled" {
|
if tls_port != "disabled" {
|
||||||
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config)
|
listener, err := tls.Listen("tcp", fmt.Sprintf(":%v", tls_port), &tls_config)
|
||||||
|
@ -74,7 +81,7 @@ func Run(
|
||||||
}
|
}
|
||||||
log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host)
|
log.Info().Msgf("TLS server started on port %v for host %v", tls_port, host)
|
||||||
wait_group.Add(1)
|
wait_group.Add(1)
|
||||||
go Listen(wait_group, host, password_hash, tls_config, listener)
|
go Listen(wait_group, host, tls_config, auth, password_hash, listener)
|
||||||
}
|
}
|
||||||
wait_group.Wait()
|
wait_group.Wait()
|
||||||
|
|
||||||
|
@ -84,8 +91,9 @@ func Run(
|
||||||
func Listen(
|
func Listen(
|
||||||
wait_group sync.WaitGroup,
|
wait_group sync.WaitGroup,
|
||||||
host string,
|
host string,
|
||||||
password_hash string,
|
|
||||||
tls_config tls.Config,
|
tls_config tls.Config,
|
||||||
|
auth bool,
|
||||||
|
password_hash string,
|
||||||
listener net.Listener,
|
listener net.Listener,
|
||||||
) {
|
) {
|
||||||
defer wait_group.Done()
|
defer wait_group.Done()
|
||||||
|
@ -94,7 +102,7 @@ func Listen(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("Failed to accept client: %v", err)
|
log.Error().Msgf("Failed to accept client: %v", err)
|
||||||
} else {
|
} else {
|
||||||
go handle(connection, host, password_hash, tls_config)
|
go handle(connection, host, tls_config, auth, password_hash)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,16 +35,20 @@ type SMTPSession struct {
|
||||||
password_hash string
|
password_hash string
|
||||||
tls_config tls.Config
|
tls_config tls.Config
|
||||||
buffer [4096]byte
|
buffer [4096]byte
|
||||||
|
HasHelloed bool
|
||||||
|
requires_auth bool
|
||||||
|
is_authed bool
|
||||||
|
|
||||||
ReversePathBuffer *string
|
ReversePathBuffer *string
|
||||||
ForwardPathBuffer []string
|
ForwardPathBuffer []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeSMTPSession(connection net.Conn, host string, password_hash string, tls_config tls.Config) SMTPSession {
|
func MakeSMTPSession(connection net.Conn, host string, tls_config tls.Config, auth bool, password_hash string) SMTPSession {
|
||||||
return SMTPSession{
|
return SMTPSession{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
host: host,
|
host: host,
|
||||||
password_hash: password_hash,
|
password_hash: password_hash,
|
||||||
|
requires_auth: auth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,6 +141,15 @@ func (self *SMTPSession) GetHost() string {
|
||||||
return self.host
|
return self.host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (self *SMTPSession) RequiresAuthentication() bool {
|
||||||
|
return self.requires_auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *SMTPSession) AuthenticationSatisfied() bool {
|
||||||
|
fmt.Println(self.requires_auth, self.is_authed)
|
||||||
|
return !self.requires_auth || self.is_authed
|
||||||
|
}
|
||||||
|
|
||||||
func (self *SMTPSession) ValidatePassword(password string) bool {
|
func (self *SMTPSession) ValidatePassword(password string) bool {
|
||||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
|
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue