Added AUTH command and PLAIN sasl mechanism
This commit is contained in:
parent
3523f1cf4e
commit
86dc81451a
7 changed files with 75 additions and 24 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -21,3 +21,5 @@
|
|||
|
||||
# ...even if they are in subdirectories
|
||||
!*/
|
||||
|
||||
!hash.sh
|
||||
|
|
|
@ -24,12 +24,17 @@ import (
|
|||
)
|
||||
|
||||
type Config struct {
|
||||
LogLevel string
|
||||
Host string
|
||||
General GeneralConfig
|
||||
Plain PlainConfig
|
||||
TLS TLSConfig
|
||||
}
|
||||
|
||||
type GeneralConfig struct {
|
||||
LogLevel string
|
||||
Host string
|
||||
PasswordHash string
|
||||
}
|
||||
|
||||
type PlainConfig struct {
|
||||
Enabled bool
|
||||
Port string
|
||||
|
|
|
@ -62,8 +62,8 @@ func main() {
|
|||
log.Fatal().Msgf("%v", err)
|
||||
}
|
||||
if log_level == "" {
|
||||
if config.LogLevel != "" {
|
||||
log_level = config.LogLevel
|
||||
if config.General.LogLevel != "" {
|
||||
log_level = config.General.LogLevel
|
||||
} else {
|
||||
log_level = "info"
|
||||
}
|
||||
|
@ -107,7 +107,8 @@ func main() {
|
|||
}
|
||||
|
||||
err = smtp.Run(
|
||||
config.Host,
|
||||
config.General.Host,
|
||||
config.General.PasswordHash,
|
||||
plain_config,
|
||||
tls_config,
|
||||
)
|
||||
|
|
3
hash.sh
Executable file
3
hash.sh
Executable file
|
@ -0,0 +1,3 @@
|
|||
#!/bin/sh
|
||||
|
||||
printf "$@" | base64 | tr -d '\n' | sha256sum
|
|
@ -68,6 +68,7 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
|
|||
err := smtp_session.Write(
|
||||
fmt.Sprintf(
|
||||
"250-%v is shy" + "\r\n" +
|
||||
"250-AUTH PLAIN" + "\r\n" +
|
||||
"250-8BITMIME" + "\r\n" +
|
||||
"250 SMTPUTF8" + "\r\n",
|
||||
smtp_session.GetHost(),
|
||||
|
@ -80,6 +81,31 @@ func Ehlo(smtp_session *SMTPSession, message string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func Auth(smtp_session *SMTPSession, message string) (bool, error) {
|
||||
parts := strings.Split(message, " ")[1:]
|
||||
if parts[0] == "PLAIN" {
|
||||
var password string
|
||||
if len(parts) > 1 {
|
||||
password = parts[1][:len(parts[1]) - 1]
|
||||
} else {
|
||||
err := smtp_session.Write("334 \n")
|
||||
if err != nil {
|
||||
return true, nil
|
||||
}
|
||||
password, err = smtp_session.ReadUntil("\r\n")
|
||||
if err != nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if smtp_session.ValidatePassword(password) {
|
||||
smtp_session.Write("235 Authentication successful\n")
|
||||
} else {
|
||||
smtp_session.Write("535 Authentication credentials invalid\n")
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func MailFrom(smtp_session *SMTPSession, message string) (bool, error) {
|
||||
match := ReversePath.FindStringSubmatch(message)
|
||||
if match == nil {
|
||||
|
@ -131,7 +157,7 @@ func Data(smtp_session *SMTPSession, message string) (bool, error) {
|
|||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
data, err := smtp_session.Read("\r\n.\r\n")
|
||||
data, err := smtp_session.ReadUntil("\r\n.\r\n")
|
||||
if err != nil {
|
||||
smtp_session.Write(
|
||||
"550 Action not taken\n",
|
||||
|
|
|
@ -36,14 +36,14 @@ type TLSConfig struct {
|
|||
TlsConfig tls.Config
|
||||
}
|
||||
|
||||
func handle(connection net.Conn, host string) {
|
||||
func handle(connection net.Conn, host string, password_hash string) {
|
||||
log.Info().Msgf(
|
||||
"New connection %v. Starting session.",
|
||||
connection.RemoteAddr(),
|
||||
)
|
||||
defer connection.Close()
|
||||
|
||||
session := MakeSMTPSession(connection, host)
|
||||
session := MakeSMTPSession(connection, host, password_hash)
|
||||
err := session.Run()
|
||||
if err != nil {
|
||||
log.Error().Msgf(
|
||||
|
@ -59,7 +59,7 @@ func handle(connection net.Conn, host string) {
|
|||
}
|
||||
}
|
||||
|
||||
func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
|
||||
func Run(host string, password_hash string, plain_config *PlainConfig, tls_config *TLSConfig) error {
|
||||
var wait_group sync.WaitGroup
|
||||
if plain_config != nil {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%v:%v", host, plain_config.Port))
|
||||
|
@ -68,7 +68,7 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
|
|||
}
|
||||
log.Info().Msgf("Plain text server started on port %v for host %v", plain_config.Port, host)
|
||||
wait_group.Add(1)
|
||||
go Listen(wait_group, host, listener)
|
||||
go Listen(wait_group, host, password_hash, listener)
|
||||
}
|
||||
if tls_config != nil {
|
||||
listener, err := tls.Listen("tcp", fmt.Sprintf("%v:%v", host, tls_config.Port), &tls_config.TlsConfig)
|
||||
|
@ -77,18 +77,18 @@ func Run(host string, plain_config *PlainConfig, tls_config *TLSConfig) error {
|
|||
}
|
||||
log.Info().Msgf("TLS server started on port %v for host %v", tls_config.Port, host)
|
||||
wait_group.Add(1)
|
||||
go Listen(wait_group, host, listener)
|
||||
go Listen(wait_group, host, password_hash, listener)
|
||||
}
|
||||
wait_group.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Listen(wait_group sync.WaitGroup, host string, listener net.Listener) {
|
||||
func Listen(wait_group sync.WaitGroup, host string, password_hash string, listener net.Listener) {
|
||||
defer wait_group.Done()
|
||||
for {
|
||||
connection, _ := listener.Accept()
|
||||
|
||||
go handle(connection, host)
|
||||
go handle(connection, host, password_hash)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"crypto/tls"
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -31,16 +32,18 @@ import (
|
|||
type SMTPSession struct {
|
||||
connection net.Conn
|
||||
host string
|
||||
password_hash string
|
||||
buffer [4096]byte
|
||||
|
||||
ReversePathBuffer *string
|
||||
ForwardPathBuffer []string
|
||||
}
|
||||
|
||||
func MakeSMTPSession(connection net.Conn, host string) SMTPSession {
|
||||
func MakeSMTPSession(connection net.Conn, host string, password_hash string) SMTPSession {
|
||||
return SMTPSession{
|
||||
connection: connection,
|
||||
host: host,
|
||||
password_hash: password_hash,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,6 +55,7 @@ func (self *SMTPSession) Run() error {
|
|||
COMMANDS := []Command{
|
||||
Command{ "HELO", Helo },
|
||||
Command{ "EHLO", Ehlo },
|
||||
Command{ "AUTH", Auth },
|
||||
Command{ "MAIL FROM:", MailFrom },
|
||||
Command{ "RCPT TO:", RcptTo },
|
||||
Command{ "DATA", Data },
|
||||
|
@ -61,7 +65,7 @@ func (self *SMTPSession) Run() error {
|
|||
}
|
||||
quit := false
|
||||
for !quit {
|
||||
message, err := self.Read("\n")
|
||||
message, err := self.ReadUntil("\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -102,7 +106,7 @@ func (self *SMTPSession) TraceSession(direction string, message string) {
|
|||
|
||||
}
|
||||
|
||||
func (self *SMTPSession) Read(delimiter string) (string, error) {
|
||||
func (self *SMTPSession) ReadUntil(delimiter string) (string, error) {
|
||||
var message string
|
||||
for !strings.Contains(message, delimiter) {
|
||||
num_read, err := self.connection.Read(self.buffer[:])
|
||||
|
@ -116,6 +120,12 @@ func (self *SMTPSession) Read(delimiter string) (string, error) {
|
|||
return message[:len(message) - len(delimiter)], nil
|
||||
}
|
||||
|
||||
func (self *SMTPSession) ReadCount(num_bytes int) ([]byte, error) {
|
||||
bytes := make([]byte, num_bytes)
|
||||
_, err := self.connection.Read(bytes[:])
|
||||
return bytes, err
|
||||
}
|
||||
|
||||
func (self *SMTPSession) Write(message string) error {
|
||||
self.TraceSession("<- ", message)
|
||||
_, err := self.connection.Write([]byte(message))
|
||||
|
@ -126,6 +136,10 @@ func (self *SMTPSession) GetHost() string {
|
|||
return self.host
|
||||
}
|
||||
|
||||
func (self *SMTPSession) ValidatePassword(password string) bool {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(password))) == self.password_hash
|
||||
}
|
||||
|
||||
func (self *SMTPSession) MailFrom(from string) {
|
||||
self.Reset()
|
||||
self.ReversePathBuffer = &from
|
||||
|
|
Loading…
Reference in a new issue