package client
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/textproto"
"strconv"
"strings"
"github.com/uponusolutions/go-sasl"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/internal/textsmtp"
)
// dial opens a plain TCP connection to the SMTP server at addr, which must
// include a port (for example "mail.example.com:smtp"). The connection
// attempt is bounded by the client's dial timeout and by ctx.
func (c *Client) dial(ctx context.Context, addr string) (net.Conn, error) {
	d := &net.Dialer{Timeout: c.dialTimeout}
	return d.DialContext(ctx, "tcp", addr)
}
// dialTLS opens an implicit-TLS connection to the SMTP server at addr, which
// must include a port (for example "mail.example.com:smtps"). The client's
// TLS config is passed to tls.Dialer as-is; a nil config is equivalent to a
// zero tls.Config.
func (c *Client) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
	netDialer := &net.Dialer{Timeout: c.dialTimeout}
	d := &tls.Dialer{NetDialer: netDialer, Config: c.tlsConfig}
	return d.DialContext(ctx, "tcp", addr)
}
// setConn sets the underlying network connection for the client and
// (re)initializes the text protocol layer on top of it.
//
// c.conn always holds the raw connection so that timeouts and
// TLSConnectionState operate on it. When a debug writer is configured, the
// text layer reads and writes through a stream that additionally copies all
// traffic to c.debug.
func (c *Client) setConn(conn net.Conn) {
	c.conn = conn
	if c.debug != nil {
		// Build a fresh text layer over the tee'd stream and stop here: handing
		// the raw conn to Replace afterwards would swap the tee'd stream back
		// out and silently disable debug logging.
		c.text = textsmtp.NewTextproto(struct {
			io.Reader
			io.Writer
			io.Closer
		}{
			io.TeeReader(conn, c.debug),
			io.MultiWriter(conn, c.debug),
			conn,
		}, c.readerSize, c.writerSize, c.maxLineLength)
		return
	}
	if c.text != nil {
		// Reuse the existing text layer (e.g. after STARTTLS upgraded conn).
		c.text.Replace(conn)
		return
	}
	c.text = textsmtp.NewTextproto(conn, c.readerSize, c.writerSize, c.maxLineLength)
}
// Close closes the connection without sending QUIT. It is a no-op when the
// client is not connected; afterwards the client reports as not connected.
func (c *Client) Close() error {
	if c.conn != nil {
		closeErr := c.text.Close()
		c.conn = nil
		return closeErr
	}
	return nil
}
// greet reads the server's 220 greeting and records the server name (the
// first word of the greeting) in c.connName.
//
// If an error occurs the connection is closed and connName is left untouched.
func (c *Client) greet() error {
	// Initial greeting timeout. RFC 5321 recommends 5 minutes.
	timeout := smtp.Timeout(c.conn, c.commandTimeout)
	defer timeout()
	_, msg, err := c.readResponse(220)
	if err != nil {
		_ = c.Close()
		return err
	}
	// The name ends at the first space; a multi-line greeting may instead end
	// the first line right after the name, so stop at a line break too.
	if idx := strings.IndexAny(msg, " \r\n"); idx >= 0 {
		msg = msg[:idx]
	}
	c.connName = msg
	return nil
}
// hello runs the greeting exchange: EHLO first, falling back to HELO when the
// server rejects EHLO as unknown (500) or not implemented (502).
//
// If an error occurs the connection is closed.
func (c *Client) hello() error {
	err := c.ehlo()
	// Named "status" rather than "smtp" so the smtp package is not shadowed.
	var status *smtp.Status
	if err != nil && errors.As(err, &status) && (status.Code == 500 || status.Code == 502) {
		// The server doesn't support EHLO, fallback to HELO
		err = c.helo()
	}
	if err != nil {
		_ = c.Close()
	}
	return err
}
func (c *Client) readResponse(expectCode int) (int, string, error) {
code, msg, err := c.text.ReadResponse(expectCode)
if protoErr, ok := err.(*textproto.Error); ok {
err = toSMTPErr(protoErr)
}
return code, msg, err
}
// cmd sends one command (built from format and args) and reads its reply,
// bounded by the client's command timeout. Server error replies come back
// as *smtp.Status (see readResponse); a failure to send is returned as-is
// with a zero code.
func (c *Client) cmd(expectCode int, format string, args ...any) (int, string, error) {
	release := smtp.Timeout(c.conn, c.commandTimeout)
	defer release()

	id, sendErr := c.text.Cmd(format, args...)
	if sendErr != nil {
		return 0, "", sendErr
	}
	c.text.StartResponse(id)
	defer c.text.EndResponse(id)
	return c.readResponse(expectCode)
}
// helo sends the HELO greeting. It is only used when the server does not
// understand EHLO; since HELO advertises no extensions, any previously
// learned extension list is discarded first.
func (c *Client) helo() error {
	c.ext = nil
	if _, _, err := c.cmd(250, "HELO %s", c.localName); err != nil {
		return err
	}
	return nil
}
// ehlo sends the EHLO (extended hello) greeting and records the extensions
// the server advertises in c.ext, keyed by upper-cased keyword with the rest
// of the line as the parameter string. It should be the preferred greeting
// for servers that support it.
func (c *Client) ehlo() error {
	_, msg, err := c.cmd(250, "EHLO %s", c.localName)
	if err != nil {
		return err
	}
	// The first reply line is the server greeting; every following line
	// advertises one extension as "KEYWORD [params]".
	lines := strings.Split(msg, "\n")
	ext := make(map[string]string, len(lines))
	for _, line := range lines[1:] {
		keyword, params, _ := strings.Cut(line, " ")
		// EHLO keywords are case-insensitive (RFC 5321 4.1.1.1); normalize so
		// lookups such as c.ext["AUTH"] and Extension match any server casing.
		ext[strings.ToUpper(keyword)] = params
	}
	c.ext = ext
	return nil
}
// startTLS sends the STARTTLS command, performs the TLS handshake, and then
// repeats the EHLO/HELO exchange over the encrypted connection. Only servers
// that advertise the STARTTLS extension support this function.
//
// When the client has no TLS config, a config with only ServerName set is
// used; a config without ServerName is cloned and given serverName.
//
// If the server returns an error, it will be of type *smtp.Status.
// If an error occurs the connection is closed.
func (c *Client) startTLS(serverName string) error {
	if _, _, err := c.cmd(220, "STARTTLS"); err != nil {
		_ = c.Quit()
		return err
	}
	config := c.tlsConfig
	if config == nil {
		config = &tls.Config{
			ServerName: serverName,
		}
	} else if config.ServerName == "" && serverName != "" {
		// Make a copy to avoid polluting the caller's config
		config = config.Clone()
		config.ServerName = serverName
	}
	conn := tls.Client(c.conn, config)
	// Bound only the handshake; the EHLO below applies its own command timeout.
	release := smtp.Timeout(conn, c.tlsHandshakeTimeout)
	err := conn.Handshake()
	release()
	if err != nil {
		_ = c.Close()
		return err
	}
	c.setConn(conn)
	// RFC 3207: knowledge obtained before TLS (e.g. the extension list) must be
	// discarded; hello repopulates it from the encrypted session.
	c.ext = nil
	return c.hello()
}
// TLSConnectionState returns the TLS state of the current connection and
// true when the connection is a *tls.Conn (implicit TLS or a completed
// STARTTLS). Otherwise it returns a zero state and false.
func (c *Client) TLSConnectionState() (tls.ConnectionState, bool) {
	if tc, ok := c.conn.(*tls.Conn); ok {
		return tc.ConnectionState(), true
	}
	return tls.ConnectionState{}, false
}
// Verify asks the server (VRFY) whether addr is a valid address. A nil
// return means the server confirmed it; a non-nil return does not
// necessarily mean the address is invalid, as many servers refuse to verify
// addresses for security reasons.
//
// SMTPUTF8 is appended when the server supports it unless opts disables it;
// with UTF8Force an unsupporting server yields an error before anything is
// sent. If the server returns an error, it will be of type *smtp.Status.
func (c *Client) Verify(addr string, opts *VrfyOptions) error {
	if err := validateLine(addr); err != nil {
		return err
	}

	var sb strings.Builder
	sb.Grow(2048)
	sb.WriteString("VRFY ")
	sb.WriteString(addr)

	// UTF-8 is preferred unless explicitly disabled.
	wantUTF8 := opts == nil || opts.UTF8 != UTF8Disabled
	forceUTF8 := opts != nil && opts.UTF8 == UTF8Force
	if wantUTF8 {
		_, serverUTF8 := c.ext["SMTPUTF8"]
		switch {
		case serverUTF8:
			sb.WriteString(" SMTPUTF8")
		case forceUTF8:
			return errors.New("smtp: server does not support SMTPUTF8")
		}
	}

	_, _, err := c.cmd(250, "%s", sb.String())
	return err
}
// Auth authenticates the client with the given SASL mechanism (RFC 4954).
// Only servers that advertise the AUTH extension support this function.
//
// Every 334 challenge is answered, including with an empty line when the
// mechanism has nothing to send. Auth only returns nil after the server
// replied 235. A client-side failure mid-exchange (undecodable challenge or a
// mechanism error) cancels the exchange with "*". A final server failure
// reply is returned as *smtp.Status.
func (c *Client) Auth(a sasl.Client) error {
	encoding := base64.StdEncoding
	mech, resp, err := a.Start()
	if err != nil {
		return err
	}
	// An empty initial response is sent as "="; a nil one is omitted.
	var resp64 []byte
	if len(resp) > 0 {
		resp64 = make([]byte, encoding.EncodedLen(len(resp)))
		encoding.Encode(resp64, resp)
	} else if resp != nil {
		resp64 = []byte{'='}
	}
	code, msg64, err := c.cmd(0, "%s", strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64)))
	for err == nil {
		switch code {
		case 235:
			// Authentication succeeded (the trailing text is not base64).
			return nil
		case 334:
			// Server challenge; handled below.
		default:
			// The server already ended the exchange with a failure reply, so
			// there is nothing to cancel.
			return toSMTPErr(&textproto.Error{Code: code, Msg: msg64})
		}
		// Declared separately so err below is the loop's err, not a shadow.
		var challenge []byte
		challenge, err = encoding.DecodeString(msg64)
		if err == nil {
			resp, err = a.Next(challenge)
		}
		if err != nil {
			// abort the AUTH exchange so the server leaves the AUTH state
			_, _, _ = c.cmd(501, "*")
			return err
		}
		// A challenge must always be answered; an empty response is an empty line.
		code, msg64, err = c.cmd(0, "%s", encoding.EncodeToString(resp))
	}
	return err
}
// Mail issues a MAIL command to the server using the provided email address.
// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME
// parameter.
// This initiates a mail transaction and is followed by one or more Rcpt calls.
//
// If opts is not nil, MAIL arguments provided in the structure will be added
// to the command. Handling of unsupported options depends on the extension:
// REQUIRETLS and a forced SMTPUTF8 fail with an error, while SIZE, DSN, AUTH
// and XOORG parameters are silently omitted.
//
// If the server returns an error, it will be of type *smtp.Status.
func (c *Client) Mail(from string, opts *MailOptions) error {
	if err := validateLine(from); err != nil {
		return err
	}
	var sb strings.Builder
	// A high enough power of 2 than 510+14+26+11+9+9+39+500
	sb.Grow(2048)
	fmt.Fprintf(&sb, "MAIL FROM:<%s>", from)
	if _, ok := c.ext["8BITMIME"]; ok {
		sb.WriteString(" BODY=8BITMIME")
	}
	if _, ok := c.ext["SIZE"]; ok && opts != nil && opts.Size != 0 {
		fmt.Fprintf(&sb, " SIZE=%v", opts.Size)
	}
	if opts != nil && opts.RequireTLS {
		if _, ok := c.ext["REQUIRETLS"]; !ok {
			return errors.New("smtp: server does not support REQUIRETLS")
		}
		sb.WriteString(" REQUIRETLS")
	}
	// By default utf8 is preferred
	if opts == nil || opts.UTF8 != UTF8Disabled {
		if _, ok := c.ext["SMTPUTF8"]; ok {
			sb.WriteString(" SMTPUTF8")
		} else if opts != nil && opts.UTF8 == UTF8Force {
			return errors.New("smtp: server does not support SMTPUTF8")
		}
	}
	if _, ok := c.ext["DSN"]; ok && opts != nil {
		switch opts.Return {
		case smtp.DSNReturnFull, smtp.DSNReturnHeaders:
			fmt.Fprintf(&sb, " RET=%s", string(opts.Return))
		case "":
			// This space is intentionally left blank
		default:
			return errors.New("smtp: Unknown RET parameter value")
		}
		if opts.EnvelopeID != "" {
			if !textsmtp.IsPrintableASCII(opts.EnvelopeID) {
				return errors.New("smtp: Malformed ENVID parameter value")
			}
			fmt.Fprintf(&sb, " ENVID=%s", encodeXtext(opts.EnvelopeID))
		}
	}
	if opts != nil && opts.Auth != nil {
		if _, ok := c.ext["AUTH"]; ok {
			if *opts.Auth == "" {
				// RFC 4954 section 5: an empty identity is spelled AUTH=<>;
				// xtext-encoding "" would produce the invalid "AUTH=".
				sb.WriteString(" AUTH=<>")
			} else {
				fmt.Fprintf(&sb, " AUTH=%s", encodeXtext(*opts.Auth))
			}
		}
		// We can safely discard parameter if server does not support AUTH.
	}
	if opts != nil && opts.XOORG != nil {
		if _, ok := c.ext["XOORG"]; ok {
			fmt.Fprintf(&sb, " XOORG=%s", encodeXtext(*opts.XOORG))
		}
		// We can safely discard parameter if server does not support XOORG.
	}
	_, _, err := c.cmd(250, "%s", sb.String())
	return err
}
// Rcpt issues a RCPT command for the given address. It must follow a
// successful Mail call and may be followed by Data or another Rcpt.
//
// When the server advertises DSN and opts is non-nil, NOTIFY and ORCPT
// parameters are added from opts; without DSN support they are dropped.
// Malformed DSN options are reported before anything is sent.
//
// If the server returns an error, it will be of type *smtp.Status.
func (c *Client) Rcpt(to string, opts *smtp.RcptOptions) error {
	if err := validateLine(to); err != nil {
		return err
	}

	var sb strings.Builder
	// A high enough power of 2 than 510+29+501
	sb.Grow(2048)
	sb.WriteString("RCPT TO:<")
	sb.WriteString(to)
	sb.WriteString(">")

	_, dsn := c.ext["DSN"]
	if dsn && opts != nil {
		if len(opts.Notify) > 0 {
			if textsmtp.CheckNotifySet(opts.Notify) != nil {
				return errors.New("smtp: Malformed NOTIFY parameter value")
			}
			sb.WriteString(" NOTIFY=")
			for idx, n := range opts.Notify {
				if idx > 0 {
					sb.WriteString(",")
				}
				sb.WriteString(string(n))
			}
		}

		if orcpt := opts.OriginalRecipient; orcpt != "" {
			var encoded string
			switch opts.OriginalRecipientType {
			case smtp.DSNAddressTypeRFC822:
				if !textsmtp.IsPrintableASCII(orcpt) {
					return errors.New("smtp: Illegal address")
				}
				encoded = encodeXtext(orcpt)
			case smtp.DSNAddressTypeUTF8:
				if _, utf8OK := c.ext["SMTPUTF8"]; utf8OK {
					encoded = encodeUTF8AddrUnitext(orcpt)
				} else {
					encoded = encodeUTF8AddrXtext(orcpt)
				}
			default:
				return errors.New("smtp: Unknown address type")
			}
			sb.WriteString(" ORCPT=")
			sb.WriteString(string(opts.OriginalRecipientType))
			sb.WriteString(";")
			sb.WriteString(encoded)
		}
	}

	// The two-digit expect code 25 is deliberate: presumably, as with
	// net/textproto, it accepts any 25x reply, so 251 ("user not local; will
	// forward") also counts as success. TODO confirm against textsmtp.
	_, _, err := c.cmd(25, "%s", sb.String())
	return err
}
// DataCloser is the writer returned by Data. Writes go to the dot-encoded
// message body; closing it terminates the body and reads the server's final
// reply. It is an io.WriteCloser with the additional CloseWithResponse
// method, which also exposes that reply.
type DataCloser struct {
	c *Client
	io.WriteCloser
	closed bool
}

// CloseWithResponse terminates the message body, then waits (bounded by the
// submission timeout) for the server's final reply and returns its code and
// text. Server rejections are returned as *smtp.Status.
//
// It may only be called once; later calls (including after a failed first
// attempt) return an error without touching the connection.
func (d *DataCloser) CloseWithResponse() (code int, msg string, err error) {
	if d.closed {
		return 0, "", errors.New("smtp: data writer closed twice")
	}
	// Mark closed before any I/O so that a retry after a failed close cannot
	// terminate the DATA stream a second time.
	d.closed = true
	if closeErr := d.WriteCloser.Close(); closeErr != nil {
		return 0, "", closeErr
	}
	timeout := smtp.Timeout(d.c.conn, d.c.submissionTimeout)
	defer timeout()
	return d.c.readResponse(250)
}

// Close terminates the message body and returns the server's verdict as an
// error; see CloseWithResponse.
func (d *DataCloser) Close() error {
	_, _, err := d.CloseWithResponse()
	return err
}
// Data issues a DATA command and, once the server answers 354, returns a
// writer for the message headers and body. The caller must close the writer
// before calling any other method on c. Data must be preceded by one or more
// successful Rcpt calls.
//
// If the server returns an error, it will be of type *smtp.Status.
func (c *Client) Data() (*DataCloser, error) {
	if _, _, err := c.cmd(354, "DATA"); err != nil {
		return nil, err
	}
	body := textsmtp.NewDotWriter(c.text.W)
	return &DataCloser{c: c, WriteCloser: body}, nil
}
// Extension reports whether the server advertised the named extension, and
// if so the parameter string that followed the keyword in the EHLO reply.
// The lookup upper-cases ext, so callers may use any case.
func (c *Client) Extension(ext string) (bool, string) {
	param, found := c.ext[strings.ToUpper(ext)]
	return found, param
}
// SupportsAuth reports whether the server advertised AUTH and listed mech
// among its mechanisms. Mechanism names are compared case-insensitively.
func (c *Client) SupportsAuth(mech string) bool {
	mechs, ok := c.ext["AUTH"]
	if !ok {
		return false
	}
	// Fields (rather than Split on a single space) ignores repeated or other
	// whitespace, so no empty pseudo-mechanism can match an empty mech.
	for _, m := range strings.Fields(mechs) {
		if strings.EqualFold(m, mech) {
			return true
		}
	}
	return false
}
// MaxMessageSize returns the maximum message size the server advertised via
// the SIZE extension; 0 means unlimited.
//
// ok is false when the server sent no SIZE value or the value is not a
// non-negative integer.
func (c *Client) MaxMessageSize() (size int, ok bool) {
	raw, present := c.ext["SIZE"]
	if !present || raw == "" {
		return 0, false
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n < 0 {
		return 0, false
	}
	return n, true
}
// Reset sends RSET, aborting the current mail transaction while keeping the
// connection open.
func (c *Client) Reset() error {
	_, _, err := c.cmd(250, "RSET")
	return err
}
// Noop sends NOOP. It has no effect on the session and only checks that the
// server still answers with 250.
func (c *Client) Noop() error {
	if _, _, err := c.cmd(250, "NOOP"); err != nil {
		return err
	}
	return nil
}
// Quit sends QUIT and closes the connection. The connection is closed even
// when QUIT fails; in that case the QUIT error is returned, otherwise the
// result of closing. It is a no-op when the client is not connected.
func (c *Client) Quit() error {
	if c.conn == nil {
		return nil
	}
	_, _, quitErr := c.cmd(221, "QUIT")
	closeErr := c.Close()
	if quitErr != nil {
		return quitErr
	}
	return closeErr
}
package client
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"net"
"strings"
"time"
"github.com/uponusolutions/go-sasl"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/internal/textsmtp"
)
// Security describes how the connection to the server is established.
type Security int32

// Connection security modes. Values follow declaration order via iota
// (0, 1, 2, 3); do not reorder.
const (
	// SecurityPreferStartTLS tries to use STARTTLS but falls back to plain.
	SecurityPreferStartTLS Security = iota
	// SecurityPlain is always just a plain connection.
	SecurityPlain
	// SecurityTLS does an implicit TLS connection.
	SecurityTLS
	// SecurityStartTLS always does STARTTLS and fails if it is unavailable.
	SecurityStartTLS
)
// UTF8 describes how the SMTPUTF8 extension is used.
type UTF8 int32

// SMTPUTF8 modes. Values follow declaration order via iota (0, 1, 2); do not
// reorder.
const (
	// UTF8Prefer uses SMTPUTF8 if the server supports it.
	UTF8Prefer UTF8 = iota
	// UTF8Force always uses SMTPUTF8 and fails if the server lacks it.
	UTF8Force
	// UTF8Disabled never uses SMTPUTF8.
	UTF8Disabled
)
// MailOptions contains parameters for the MAIL command (see Client.Mail).
type MailOptions struct {
	// Size of the body. Can be 0 if not specified by client. Sent as SIZE=
	// only when non-zero and the server advertises SIZE.
	Size int64
	// TLS is required for the message transmission.
	//
	// The message should be rejected if it can't be transmitted
	// with TLS. Mail returns an error if the server does not advertise
	// REQUIRETLS.
	RequireTLS bool
	// The message envelope or message header contains UTF-8-encoded strings.
	// This flag is set by SMTPUTF8-aware (RFC 6531) client. The zero value
	// (UTF8Prefer) uses SMTPUTF8 when the server advertises it.
	UTF8 UTF8
	// Value of RET= argument, FULL or HDRS. Only sent when the server
	// advertises DSN; any other non-empty value makes Mail return an error.
	Return smtp.DSNReturn
	// Envelope identifier set by the client. Sent xtext-encoded as ENVID=
	// when the server advertises DSN; must be printable ASCII.
	EnvelopeID string
	// Accepted Domain from Exchange Online, e.g. from OutgoingConnector.
	// Sent as XOORG= only when the server advertises XOORG.
	XOORG *string
	// The authorization identity asserted by the message sender in decoded
	// form with angle brackets stripped.
	//
	// nil value indicates missing AUTH, non-nil empty string indicates
	// AUTH=<>. Only sent when the server advertises AUTH.
	//
	// Defined in RFC 4954.
	Auth *string
}

// VrfyOptions contains parameters for the VRFY command (see Client.Verify).
type VrfyOptions struct {
	// The message envelope or message header contains UTF-8-encoded strings.
	// This flag is set by SMTPUTF8-aware (RFC 6531) client. The zero value
	// (UTF8Prefer) uses SMTPUTF8 when the server advertises it.
	UTF8 UTF8
}
// Client is an SMTP client.
// It sends one or more mails to a SMTP server over a single connection.
// Create it with New and establish the session with Connect.
// TODO: Add context support.
type Client struct {
	serverAddresses    [][]string // Priority groups of "host:port" addresses, tried in order by Connect.
	serverAddressIndex int        // Offset of the first address to try within each group.
	tlsConfig          *tls.Config
	saslClient         sasl.Client // Used by Connect to authenticate when the server advertises AUTH.
	// keep a reference to the connection so it can be used to create a TLS
	// connection later; nil when not connected.
	conn        net.Conn
	connAddress string // Address ("host:port") of the server currently connected to.
	connName    string // Server name taken from the first word of its greeting.
	text        *textsmtp.Textproto
	ext         map[string]string // Extensions advertised in the last EHLO reply, keyed by upper-case keyword.
	localName   string            // the name to use in HELO/EHLO/LHLO
	// Time to wait for tls handshake to succeed.
	tlsHandshakeTimeout time.Duration
	// Time to wait for dial to succeed.
	dialTimeout time.Duration
	// Time to wait for command responses (this includes 3xx reply to DATA).
	commandTimeout time.Duration
	// Time to wait for responses after final dot.
	submissionTimeout time.Duration
	// Max line length, defaults to 2000
	maxLineLength int
	// Reader buffer size of the text protocol layer, defaults to 4096
	readerSize int
	// Writer buffer size of the text protocol layer, defaults to 4096
	writerSize int
	// Logger for all network activity; when non-nil all traffic is copied here.
	debug io.Writer
	// Defines how the connection is secured (plain, TLS, STARTTLS).
	security Security
	// Options applied to every MAIL command issued by SendMail.
	mailOptions MailOptions
}
// New returns a new, unconnected SMTP client configured by opts.
// Without options it targets 127.0.0.1:25, prefers STARTTLS, greets as
// "localhost" and uses a default tls.Config. Call Connect to open a session.
func New(opts ...Option) *Client {
	c := new(Client)
	c.serverAddresses = [][]string{{"127.0.0.1:25"}}
	c.security = SecurityPreferStartTLS
	c.localName = "localhost"

	// As recommended by RFC 5321. For the DATA command's 3xx reply RFC 5321
	// recommends a slightly shorter timeout, but we do not differentiate.
	c.commandTimeout = 5 * time.Minute
	// 10 minutes + 2 minute buffer in case the server is doing transparent
	// forwarding and also follows recommended timeouts.
	c.submissionTimeout = 12 * time.Minute
	// Very generous handshake and dial limits.
	c.tlsHandshakeTimeout = 30 * time.Second
	c.dialTimeout = 30 * time.Second

	// Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6).
	c.maxLineLength = 2000
	// Buffer sizes of the text protocol layer.
	c.readerSize = 4096
	c.writerSize = 4096

	for _, apply := range opts {
		apply(c)
	}
	return c
}
// Option configures a Client; pass options to New.
type Option func(c *Client)

// WithServerAddresses configures a single group of "host:port" server
// addresses. Connect tries them in order, starting at the offset set with
// WithServerAddressIndex.
func WithServerAddresses(addrs ...string) Option {
	group := addrs
	return func(c *Client) { c.serverAddresses = [][]string{group} }
}

// WithServerAddressesPrio configures prioritized groups of "host:port"
// server addresses. Connect tries every address of one group before moving
// on to the next group.
func WithServerAddressesPrio(addrs ...[]string) Option {
	groups := addrs
	return func(c *Client) { c.serverAddresses = groups }
}

// WithServerAddressIndex sets the offset, within each address group, of the
// first server Connect tries.
func WithServerAddressIndex(index int) Option {
	start := index
	return func(c *Client) { c.serverAddressIndex = start }
}

// WithMailOptions sets the options used for the MAIL command of every
// message sent with SendMail.
func WithMailOptions(mailOptions MailOptions) Option {
	mo := mailOptions
	return func(c *Client) { c.mailOptions = mo }
}
// WithSubmissionTimeout sets how long to wait for the server's reply after
// the final "." of a message (default 12 minutes).
func WithSubmissionTimeout(submissionTimeout time.Duration) Option {
	d := submissionTimeout
	return func(c *Client) { c.submissionTimeout = d }
}

// WithCommandTimeout sets how long to wait for the reply to a command,
// including the greeting (default 5 minutes).
func WithCommandTimeout(commandTimeout time.Duration) Option {
	d := commandTimeout
	return func(c *Client) { c.commandTimeout = d }
}

// WithDialTimeout sets how long establishing the TCP connection may take
// (default 30 seconds).
func WithDialTimeout(dialTimeout time.Duration) Option {
	d := dialTimeout
	return func(c *Client) { c.dialTimeout = d }
}

// WithTlsHandshakeTimeout sets how long the STARTTLS handshake may take
// (default 30 seconds).
func WithTlsHandshakeTimeout(tlsHandshakeTimeout time.Duration) Option {
	d := tlsHandshakeTimeout
	return func(c *Client) { c.tlsHandshakeTimeout = d }
}
// WithLocalName sets the name the client announces in HELO/EHLO
// (default "localhost"). Connect rejects names containing CR or LF.
func WithLocalName(localName string) Option {
	return func(c *Client) {
		c.localName = localName
	}
}

// WithTLSConfig sets the TLS config used for implicit TLS and STARTTLS.
// For STARTTLS, a config without ServerName is cloned and given the server's
// host name.
func WithTLSConfig(cfg *tls.Config) Option {
	return func(c *Client) {
		c.tlsConfig = cfg
	}
}

// WithSecurity sets how the connection is secured: plain, implicit TLS,
// required STARTTLS, or STARTTLS when available (the default).
func WithSecurity(security Security) Option {
	return func(c *Client) {
		c.security = security
	}
}

// WithSASLClient sets the SASL client Connect uses to authenticate when the
// server advertises AUTH.
func WithSASLClient(cl sasl.Client) Option {
	return func(c *Client) {
		c.saslClient = cl
	}
}
// WithMaxLineLength sets the maximum line length of the text protocol layer
// (default 2000).
func WithMaxLineLength(maxLineLength int) Option {
	return func(c *Client) {
		c.maxLineLength = maxLineLength
	}
}

// WithReaderSize sets the read buffer size of the text protocol layer
// (default 4096).
func WithReaderSize(readerSize int) Option {
	return func(c *Client) {
		c.readerSize = readerSize
	}
}

// WithWriterSize sets the write buffer size of the text protocol layer
// (default 4096).
func WithWriterSize(writerSize int) Option {
	return func(c *Client) {
		c.writerSize = writerSize
	}
}
// ServerAddresses returns the configured groups of server addresses.
func (c *Client) ServerAddresses() [][]string { return c.serverAddresses }

// ServerAddress returns the "host:port" of the server Connect last
// connected to successfully.
func (c *Client) ServerAddress() string { return c.connAddress }

// ServerName returns the server name taken from the greeting of the current
// connection.
func (c *Client) ServerName() string { return c.connName }
// Connect connects to one of the available SMTP server.
// When server supports auth and clients SaslClient is set, auth is called.
// Security is enforced like configured (Plain, TLS, StartTLS or PreferStartTLS)
// If an error occures, the connection is closed if open.
func (c *Client) Connect(ctx context.Context) error {
// verify if local name is valid
if strings.ContainsAny(c.localName, "\n\r") {
return errors.New("smtp: the local name must not contain CR or LF")
}
var err error
for i := 0; i < len(c.serverAddresses); i++ {
for p := 0; p < len(c.serverAddresses[i]); p++ {
// use c.serverAddressIndex
address := c.serverAddresses[i][(p+c.serverAddressIndex)%len(c.serverAddresses[i])]
err = c.connectAddress(ctx, address)
if err == nil {
c.connAddress = address
return nil
}
}
}
return err
}
// connectAddress connects to the SMTP server at addr ("host:port"), reads the
// greeting, runs EHLO/HELO, upgrades with STARTTLS according to the
// configured security mode and finally authenticates if possible.
// If an error occurs, the connection is closed if open.
func (c *Client) connectAddress(ctx context.Context, addr string) error {
	var (
		conn net.Conn
		err  error
	)
	switch c.security {
	case SecurityPlain, SecurityStartTLS, SecurityPreferStartTLS:
		conn, err = c.dial(ctx, addr)
	case SecurityTLS:
		conn, err = c.dialTLS(ctx, addr)
	default:
		// Without this, conn stays nil and greet would dereference it.
		return errors.New("smtp: unknown security mode")
	}
	if err != nil {
		return err
	}
	c.setConn(conn)
	if err = c.greet(); err != nil {
		return err
	}
	if err = c.hello(); err != nil {
		return err
	}
	if c.security == SecurityStartTLS || c.security == SecurityPreferStartTLS {
		if ok, _ := c.Extension("STARTTLS"); ok {
			serverName, _, _ := net.SplitHostPort(addr)
			if err = c.startTLS(serverName); err != nil {
				return err
			}
		} else if c.security == SecurityStartTLS {
			_ = c.Quit()
			return errors.New("smtp: server doesn't support STARTTLS")
		}
	}
	return c.auth()
}
// auth authenticates with the configured SASL client when one is set and the
// server advertises AUTH; otherwise it does nothing. On failure the session
// is ended with QUIT and the authentication error is returned.
func (c *Client) auth() error {
	if c.saslClient == nil {
		return nil
	}
	if supported, _ := c.Extension("AUTH"); !supported {
		return nil
	}
	if err := c.Auth(c.saslClient); err != nil {
		_ = c.Quit()
		return err
	}
	return nil
}
func (c *Client) prepare(from string, rcpt []string) (*DataCloser, error) {
if c.conn == nil {
return nil, errors.New("client is nil or not connected")
}
if len(rcpt) < 1 {
return nil, errors.New("no recipients")
}
// MAIL FROM:
if err := c.Mail(from, &c.mailOptions); err != nil {
return nil, err
}
// RCPT TO:
for _, addr := range rcpt {
if err := c.Rcpt(addr, &smtp.RcptOptions{}); err != nil {
return nil, err
}
}
// DATA
w, err := c.Data()
if err != nil {
return nil, err
}
return w, nil
}
// SendMail uses the existing connection to send one message from address
// from to the addresses in rcpt, with the body read from in, and returns the
// server's final reply to the message.
//
// This function does not connect, start TLS or authenticate; call Connect
// first, which handles all of that.
//
// The addresses in rcpt are the SMTP RCPT addresses.
//
// in should be an RFC 822-style email with headers first, a blank line, and
// then the message body. The lines should be CRLF terminated. The headers
// should usually include fields such as "From", "To", "Subject", and "Cc".
// Sending "Bcc" messages is accomplished by including an email address in
// rcpt but not in the headers.
//
// If copying the body fails, the connection is closed: SMTP cannot abort a
// message body mid-DATA, and terminating it would make the server accept a
// truncated message.
func (c *Client) SendMail(from string, rcpt []string, in io.Reader) (code int, msg string, err error) {
	w, err := c.prepare(from, rcpt)
	if err != nil {
		return 0, "", err
	}
	if _, err = io.Copy(w, in); err != nil {
		// Drop the connection so later calls fail cleanly ("not connected")
		// instead of writing commands into the unfinished message body.
		_ = c.Close()
		return 0, "", err
	}
	return w.CloseWithResponse()
}
// SetXOORG sets the XOORG value sent with MAIL FROM on later SendMail calls;
// nil removes it. It is only sent to servers that advertise XOORG.
func (c *Client) SetXOORG(xoorg *string) {
	opts := &c.mailOptions
	opts.XOORG = xoorg
}
// Send implements the enmime.Sender interface by sending msg via SendMail
// and discarding the server's final reply text.
func (c *Client) Send(from string, rcpt []string, msg []byte) error {
	if _, _, err := c.SendMail(from, rcpt, bytes.NewReader(msg)); err != nil {
		return err
	}
	return nil
}
package client
import (
"errors"
"net/textproto"
"strconv"
"strings"
"github.com/uponusolutions/go-smtp"
)
func parseEnhancedCode(s string) (smtp.EnhancedCode, error) {
parts := strings.Split(s, ".")
if len(parts) != 3 {
return smtp.EnhancedCode{}, errors.New("wrong amount of enhanced code parts")
}
code := smtp.EnhancedCode{}
for i, part := range parts {
num, err := strconv.Atoi(part)
if err != nil {
return code, err
}
code[i] = num
}
return code, nil
}
// toSMTPErr converts a textproto.Error into *smtp.Status. When the message
// starts with a valid enhanced status code followed by a space, the code is
// parsed into EnhancedCode and stripped from the message, including from the
// start of every continuation line (RFC 2034 repeats it per line).
// Otherwise the message is kept verbatim.
func toSMTPErr(protoErr *textproto.Error) *smtp.Status {
	status := &smtp.Status{Code: protoErr.Code, Message: protoErr.Msg}

	prefix, rest, found := strings.Cut(protoErr.Msg, " ")
	if !found {
		return status
	}
	enhanced, err := parseEnhancedCode(prefix)
	if err != nil {
		return status
	}
	status.EnhancedCode = enhanced
	status.Message = strings.ReplaceAll(rest, "\n"+prefix+" ", "\n")
	return status
}
// validateLine rejects values containing CR or LF, which would otherwise let
// a caller inject extra SMTP commands into a single command line.
func validateLine(line string) error {
	if strings.IndexAny(line, "\r\n") < 0 {
		return nil
	}
	return errors.New("smtp: a line must not contain CR or LF")
}
// encodeXtext encodes raw as RFC 3461 xtext. Printable US-ASCII other than
// space, '+' and '=' is copied; every other byte becomes "+" followed by
// exactly two upper-case hex digits. The encoding works on bytes, so
// non-ASCII text is escaped as its UTF-8 bytes (e.g. "é" -> "+C3+A9"), and
// bytes below 0x10 keep their leading zero (tab -> "+09").
func encodeXtext(raw string) string {
	const hexDigits = "0123456789ABCDEF"
	var out strings.Builder
	out.Grow(len(raw))
	for i := 0; i < len(raw); i++ {
		ch := raw[i]
		if ch >= '!' && ch <= '~' && ch != '+' && ch != '=' {
			// printable non-space US-ASCII except '+' and '='
			out.WriteByte(ch)
			continue
		}
		out.WriteByte('+')
		out.WriteByte(hexDigits[ch>>4])
		out.WriteByte(hexDigits[ch&0x0F])
	}
	return out.String()
}
// encodeUTF8AddrUnitext encodes raw string to the utf-8-addr-unitext form in
// RFC 6533. Printable US-ASCII other than space, '+', '=' and '\' (QCHAR) and
// all non-ASCII characters are copied as-is. Other ASCII characters become
// "\x{HH}" with at least two upper-case hex digits. '\' must be escaped
// because it would otherwise be ambiguous with the "\x{" escape.
func encodeUTF8AddrUnitext(raw string) string {
	var out strings.Builder
	out.Grow(len(raw))
	for _, ch := range raw {
		switch {
		case ch >= '!' && ch <= '~' && ch != '+' && ch != '=' && ch != '\\':
			// QCHAR: printable non-space US-ASCII except '+', '=' and '\'
			out.WriteRune(ch)
		case ch <= '\x7F':
			// other ASCII: CTLs, space and specials
			out.WriteString(`\x{`)
			hex := strings.ToUpper(strconv.FormatInt(int64(ch), 16))
			if len(hex) < 2 {
				// HEXPOINT forms are at least two digits (e.g. "01").
				out.WriteByte('0')
			}
			out.WriteString(hex)
			out.WriteByte('}')
		default:
			// UTF-8 non-ASCII
			out.WriteRune(ch)
		}
	}
	return out.String()
}
// encodeUTF8AddrXtext encodes raw string to the 7-bit utf-8-addr-xtext form
// in RFC 6533. Printable US-ASCII other than space, '+', '=' and '\' (QCHAR)
// is copied. Every other character, including non-ASCII ones, becomes
// "\x{HEX}" with its upper-case code point, padded to at least two digits.
// '\' must be escaped because it would otherwise be ambiguous with the "\x{"
// escape.
func encodeUTF8AddrXtext(raw string) string {
	var out strings.Builder
	out.Grow(len(raw))
	for _, ch := range raw {
		if ch >= '!' && ch <= '~' && ch != '+' && ch != '=' && ch != '\\' {
			// QCHAR: printable non-space US-ASCII except '+', '=' and '\'
			out.WriteRune(ch)
			continue
		}
		out.WriteString(`\x{`)
		hex := strings.ToUpper(strconv.FormatInt(int64(ch), 16))
		if len(hex) < 2 {
			// HEXPOINT forms are at least two digits (e.g. "01").
			out.WriteByte('0')
		}
		out.WriteString(hex)
		out.WriteByte('}')
	}
	return out.String()
}
package smtp
import (
"fmt"
)
// EnhancedCode is the SMTP enhanced status code, stored as its three numeric
// components (e.g. {5, 7, 8} for "5.7.8").
type EnhancedCode [3]int

// Status specifies the error code, enhanced error code (if any) and
// message returned by the server. *Status implements error, so replies can
// be returned as errors and inspected with errors.As.
type Status struct {
	Code         int          // Three-digit SMTP reply code, e.g. 250 or 535.
	EnhancedCode EnhancedCode // See NoEnhancedCode and EnhancedCodeNotSet for special values.
	Message      string       // Human-readable reply text.
}
// NoEnhancedCode is used to indicate that enhanced error code should not be
// included in response.
//
// Note that RFC 2034 requires an enhanced code to be included in all 2xx, 4xx
// and 5xx responses. This constant is exported for use by extensions, you
// should probably use EnhancedCodeNotSet instead.
var NoEnhancedCode = EnhancedCode{-1, -1, -1}

// EnhancedCodeNotSet is the zero value of EnhancedCode, used to indicate
// that the backend failed to provide an enhanced status code. X.0.0 will
// be used (X is derived from the basic reply code).
var EnhancedCodeNotSet = EnhancedCode{0, 0, 0}
// NewStatus returns a *Status built from the given reply code, enhanced
// status code and message text.
func NewStatus(code int, enhCode EnhancedCode, msg string) *Status {
	s := Status{Code: code, EnhancedCode: enhCode, Message: msg}
	return &s
}
// Error formats the status as "SMTP error NNN", followed by ": message" when
// a message is present. The enhanced code is not included.
func (err *Status) Error() string {
	if err.Message == "" {
		return fmt.Sprintf("SMTP error %03d", err.Code)
	}
	return fmt.Sprintf("SMTP error %03d: %s", err.Code, err.Message)
}
// Positive reports whether the reply code is in the 2xx (success) class.
func (err *Status) Positive() bool {
	return err.Code >= 200 && err.Code <= 299
}

// Temporary reports whether the reply code is in the 4xx (transient
// failure) class.
func (err *Status) Temporary() bool {
	return err.Code >= 400 && err.Code <= 499
}

// Permanent reports whether the reply code is in the 5xx (permanent
// failure) class.
func (err *Status) Permanent() bool {
	return err.Code >= 500 && err.Code <= 599
}
// Predefined statuses. The code, enhanced code and message of each value are
// what gets reported, so treat them as read-only.
var (
	// Reset is returned by Reader passed to Data function if client does not
	// send another BDAT command and instead issues RSET command.
	Reset = &Status{
		Code:         250,
		EnhancedCode: EnhancedCode{2, 0, 0},
		Message:      "Session reset",
	}
	// VRFY is the default reply to VRFY: the user cannot be verified, but
	// the message will be accepted.
	VRFY = &Status{
		Code:         252,
		EnhancedCode: EnhancedCode{2, 5, 0},
		Message:      "Cannot VRFY user, but will accept message",
	}
	// Noop is the default reply to NOOP.
	Noop = &Status{
		Code:         250,
		EnhancedCode: EnhancedCode{2, 0, 0},
		Message:      "I have successfully done nothing",
	}
	// Quit is returned by Reader passed to Data function if client does not
	// send another BDAT command and instead issues QUIT command.
	Quit = &Status{
		Code:         221,
		EnhancedCode: EnhancedCode{2, 0, 0},
		Message:      "Bye",
	}
	// ErrConnection is returned if a connection error occurs.
	ErrConnection = &Status{
		Code:         421,
		EnhancedCode: EnhancedCode{4, 4, 0},
		Message:      "Connection error, sorry",
	}
	// ErrDataTooLarge is returned if the maximum message size is exceeded.
	ErrDataTooLarge = &Status{
		Code:         552,
		EnhancedCode: EnhancedCode{5, 3, 4},
		Message:      "Maximum message size exceeded",
	}
	// ErrAuthFailed is returned if the authentication failed.
	ErrAuthFailed = &Status{
		Code:         535,
		EnhancedCode: EnhancedCode{5, 7, 8},
		Message:      "Authentication failed",
	}
	// ErrAuthRequired is returned if authentication is required before the
	// command can be used.
	ErrAuthRequired = &Status{
		Code:         502,
		EnhancedCode: EnhancedCode{5, 7, 0},
		Message:      "Please authenticate first",
	}
	// ErrAuthUnsupported is returned if authentication is not supported.
	ErrAuthUnsupported = &Status{
		Code:         502,
		EnhancedCode: EnhancedCode{5, 7, 0},
		Message:      "Authentication not supported",
	}
	// ErrAuthUnknownMechanism is returned if the requested authentication
	// mechanism is not supported.
	ErrAuthUnknownMechanism = &Status{
		Code:         504,
		EnhancedCode: EnhancedCode{5, 7, 4},
		Message:      "Unsupported authentication mechanism",
	}
	// ErrNoRecipients is returned if no recipients are set.
	// NOTE(review): RFC 5321 suggests 503 (bad sequence) or 554 for DATA
	// without a valid RCPT; 502 is kept for compatibility — confirm intent.
	ErrNoRecipients = &Status{
		Code:         502,
		EnhancedCode: EnhancedCode{5, 5, 1},
		Message:      "Missing RCPT TO command.",
	}
)
package main
import (
"context"
"fmt"
"log/slog"
"net/smtp"
"os"
"os/signal"
"github.com/uponusolutions/go-smtp/tester"
)
// main starts the tester SMTP server on a local listener, sends one message
// to it with net/smtp, and prints what the tester backend stored.
func main() {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer cancel()

	s := tester.Standard()
	listen, err := s.Listen()
	if err != nil {
		// Without a listener there is nothing to serve; continuing would use
		// the (presumably nil) listener below.
		slog.Error("error listen server", slog.Any("error", err))
		cancel()
		os.Exit(1)
	}
	addr := listen.Addr().String()

	go func() {
		if err := s.Serve(ctx, listen); err != nil {
			// slog messages are not format strings; the error goes in an attribute.
			slog.Error("error serving smtp", slog.Any("error", err))
		}
	}()
	defer func() {
		if err := s.Close(); err != nil {
			slog.Error("error closing server", slog.Any("error", err))
		}
	}()

	// Send email.
	from := "alice@i.com"
	to := []string{"bob@e.com", "mal@b.com"}
	msg := []byte("Test\r\n")
	if err := smtp.SendMail(addr, nil, from, to, msg); err != nil {
		panic(err)
	}

	// Lookup email.
	m, found := tester.GetBackend(s).Load(from, to)
	fmt.Printf("Found %t, mail %+v\n", found, m)
}
package benchmark
import (
"bytes"
"context"
"crypto/tls"
"embed"
"io"
"log/slog"
"github.com/uponusolutions/go-sasl"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/client"
"github.com/uponusolutions/go-smtp/server"
)
// embedFSTestadata embeds the files under testdata/ into the benchmark binary.
//
//go:embed testdata/*
var embedFSTestadata embed.FS
// message is what the benchmark session records for one mail transaction.
type message struct {
	From     string               // sender given to MAIL
	To       []string             // recipients in RCPT order
	RcptOpts []*smtp.RcptOptions  // RCPT options, index-aligned with To
	Data     []byte               // full body read in Data
	Opts     *smtp.MailOptions    // options given to MAIL
}
// backend is the server.Backend used by the benchmarks; it holds no state.
type backend struct{}

// NewSession creates a fresh session bound to this backend and passes the
// incoming context through unchanged.
func (be *backend) NewSession(ctx context.Context, _ *server.Conn) (context.Context, server.Session, error) {
	sess := &session{backend: be}
	return ctx, sess, nil
}
// session is the per-connection server.Session used by the benchmarks. It
// accepts everything and keeps only the message currently being received.
type session struct {
	backend *backend
	msg     *message
}

// Logger returns nil, which makes the server fall back to the default
// *slog.Logger.
func (*session) Logger(context.Context) *slog.Logger {
	return nil
}

// AuthMechanisms advertises PLAIN only.
func (*session) AuthMechanisms(context.Context) []string {
	return []string{"PLAIN"}
}

// Auth returns neither a SASL server nor an error.
// NOTE(review): PLAIN is advertised but no sasl.Server is provided; presumably
// the benchmarks never authenticate — confirm.
func (*session) Auth(context.Context, string) (sasl.Server, error) {
	return nil, nil
}

// Reset drops the current message, starts an empty one and returns the
// context unchanged.
func (s *session) Reset(ctx context.Context, _ bool) (context.Context, error) {
	s.msg = new(message)
	return ctx, nil
}

// Close has nothing to release.
func (*session) Close(context.Context, error) {}

// STARTTLS accepts the server's default TLS configuration as is.
func (*session) STARTTLS(_ context.Context, cfg *tls.Config) (*tls.Config, error) {
	return cfg, nil
}

// Verify accepts every address.
func (*session) Verify(context.Context, string, *smtp.VrfyOptions) error {
	return nil
}

// Mail starts a new message with the given sender and options.
func (s *session) Mail(ctx context.Context, from string, opts *smtp.MailOptions) error {
	_, _ = s.Reset(ctx, false) // Reset above never returns an error.
	s.msg.From, s.msg.Opts = from, opts
	return nil
}

// Rcpt appends a recipient together with its options to the current message.
func (s *session) Rcpt(_ context.Context, to string, opts *smtp.RcptOptions) error {
	m := s.msg
	m.To = append(m.To, to)
	m.RcptOpts = append(m.RcptOpts, opts)
	return nil
}

// Data reads the whole body into the current message and returns an empty
// queue id.
func (s *session) Data(_ context.Context, r func() io.Reader) (string, error) {
	body, err := io.ReadAll(r())
	if err != nil {
		return "", err
	}
	s.msg.Data = body
	// Storing messages on the backend is disabled:
	// s.backend.messages = append(s.backend.messages, s.msg)
	return "", nil
}
// testServer starts a server on an ephemeral 127.0.0.1 port and serves it in
// a background goroutine.
//
// bei is the backend to use; nil means a new one is created. opts are applied
// after the defaults, so they can override them. Despite its name, the
// returned port is the full listener address ("127.0.0.1:NNNNN").
func testServer(bei *backend, opts ...server.Option) (be *backend, s *server.Server, port string, err error) {
	be = bei
	if be == nil {
		be = new(backend)
	}

	all := append([]server.Option{
		server.WithAddr("127.0.0.1:0"),
		server.WithBackend(be),
		server.WithMaxLineLength(2000),
		server.WithHostname("localhost"),
	}, opts...)
	s = server.New(all...)

	listener, err := s.Listen()
	if err != nil {
		return nil, nil, "", err
	}

	serveCtx := context.Background()
	go func() {
		// nolint: revive
		_ = s.Serve(serveCtx, listener)
	}()

	return be, s, listener.Addr().String(), nil
}
// sendMailCon sends data over the already connected client c, from a fixed
// sender to two fixed recipients, and returns the SendMail error.
func sendMailCon(c *client.Client, data []byte) error {
	const sender = "alice@internal.com"
	rcpts := []string{"bob@external.com", "tim@external.com"}
	_, _, err := c.SendMail(sender, rcpts, bytes.NewReader(data))
	return err
}
// sendMail connects to addr over plain SMTP, sends data with sendMailCon and
// quits.
//
// Connect and send errors are returned to the caller; previously they were
// swallowed (nil was returned), which hid failing benchmark runs. On a send
// failure the connection is closed before returning.
func sendMail(addr string, data []byte) error {
	c := client.New(
		client.WithServerAddresses(addr),
		client.WithSecurity(client.SecurityPlain),
		client.WithMailOptions(client.MailOptions{Size: int64(len(data))}),
	)

	if err := c.Connect(context.Background()); err != nil {
		return err
	}

	if err := sendMailCon(c, data); err != nil {
		// The send error is the one worth reporting; closing is best effort.
		_ = c.Close()
		return err
	}

	return c.Quit()
}
package limit
import (
"errors"
"time"
)
// ErrRatelimit is returned by Take when the limit is reached and strict mode
// is enabled.
var ErrRatelimit = errors.New("rate limit occurred")

// RatelimitConfig configures a rate limit.
type RatelimitConfig struct {
	// Rate is the number of calls Take lets through immediately per window.
	Rate int
	// Duration is the length of one counting window.
	Duration time.Duration
	// Strict makes Take return ErrRatelimit instead of sleeping when the
	// budget of the current window is exhausted.
	Strict bool
}

// Ratelimit is used e.g. to limit the calls to a function.
// It does no locking; callers sharing one Ratelimit must synchronize.
type Ratelimit struct {
	start  time.Time        // start of the current window
	count  int              // calls counted in the current window
	config *RatelimitConfig // limit settings (shared, not copied)
}
// New returns a Ratelimit that uses config and whose first window starts now.
// config is kept by reference, not copied.
func New(config *RatelimitConfig) *Ratelimit {
	rl := new(Ratelimit)
	rl.config = config
	rl.start = time.Now()
	return rl
}
// Take records one call and returns once that call is allowed.
//
// Calls are counted in fixed windows of config.Duration. Up to config.Rate
// calls per window return immediately. When the budget of the current window
// is exhausted, Take returns ErrRatelimit in strict mode; otherwise it sleeps
// until the window has elapsed and then starts a new window.
//
// A window that has already elapsed is restarted before the call is counted.
// Previously the window was only restarted after the budget overflowed, so
// calls made in an old window were charged against the new one and more than
// Rate calls could pass within a single Duration.
//
// Take does no locking; callers sharing a Ratelimit must synchronize.
func (c *Ratelimit) Take() error {
	now := time.Now()
	if now.Sub(c.start) >= c.config.Duration {
		c.start = now
		c.count = 0
	}

	c.count++
	if c.count <= c.config.Rate {
		return nil
	}

	if c.config.Strict {
		return ErrRatelimit
	}

	// Still inside the current window (checked above), so this is positive.
	time.Sleep(c.config.Duration - now.Sub(c.start))
	c.start = time.Now()
	c.count = 1
	return nil
}
package parse
import (
"errors"
"fmt"
"strings"
)
// CutPrefixFold is a case-insensitive variant of strings.CutPrefix: if s
// starts with prefix (compared with strings.EqualFold), it returns the rest of
// s and true.
//
// Unlike strings.CutPrefix, it returns "" (not s) when there is no match.
func CutPrefixFold(s, prefix string) (string, bool) {
	n := len(prefix)
	if len(s) >= n && strings.EqualFold(s[:n], prefix) {
		return s[n:], true
	}
	return "", false
}
// Cmd splits one command line into its upper-cased four-letter verb and the
// trimmed argument text.
//
// A trailing "\r\n" is ignored. Any line starting with "STARTTLS"
// (case-insensitive) yields "STARTTLS" with no argument. An empty line yields
// empty strings and no error. Otherwise the verb must be exactly four
// characters and, if an argument follows, be separated from it by a space.
func Cmd(line string) (cmd string, arg string, err error) {
	line = strings.TrimRight(line, "\r\n")

	if strings.HasPrefix(strings.ToUpper(line), "STARTTLS") {
		return "STARTTLS", "", nil
	}

	switch n := len(line); {
	case n == 0:
		return "", "", nil
	case n < 4:
		return "", "", fmt.Errorf("command too short: %q", line)
	case n == 4:
		return strings.ToUpper(line), "", nil
	case n == 5, line[4] != ' ':
		// Either a lone trailing character or no space after the verb.
		return "", "", fmt.Errorf("mangled command: %q", line)
	}

	return strings.ToUpper(line[:4]), strings.TrimSpace(line[5:]), nil
}
// Args parses the parameter text that follows a command, e.g.
//
//	" BODY=8BITMIME SIZE=1024 SMTPUTF8"
//
// into a map. Tokens are separated by any whitespace (leading whitespace is
// allowed but not required). Each token is either KEY or KEY=VALUE; keys are
// upper-cased, values are kept as is, and a bare KEY maps to "". A token
// containing more than one '=' is an error. The returned map is never nil on
// success.
func Args(s string) (map[string]string, error) {
	out := make(map[string]string)
	for _, field := range strings.Fields(s) {
		parts := strings.Split(field, "=")
		if len(parts) > 2 {
			return nil, fmt.Errorf("failed to parse arg string: %q", field)
		}
		key := strings.ToUpper(parts[0])
		if len(parts) == 2 {
			out[key] = parts[1]
		} else {
			out[key] = ""
		}
	}
	return out, nil
}
// HelloArgument extracts the domain from a HELO/EHLO argument: everything up
// to the first space. An empty domain (including an argument that starts with
// a space) is an error.
func HelloArgument(arg string) (string, error) {
	domain, _, _ := strings.Cut(arg, " ")
	if domain == "" {
		return "", errors.New("invalid domain")
	}
	return domain, nil
}
// Parser is a cursor over the arguments of the MAIL and RCPT commands
// (RFC 5321 section 4.1.2). S holds the input that has not been consumed yet;
// every successful parse step advances it.
type Parser struct {
	S string
}

// peekByte returns the next unread byte without consuming it.
// The boolean is false when S is exhausted.
func (p *Parser) peekByte() (byte, bool) {
	if p.S == "" {
		return 0, false
	}
	return p.S[0], true
}

// readByte consumes and returns the next byte.
// The boolean is false (and nothing is consumed) when S is exhausted.
func (p *Parser) readByte() (byte, bool) {
	if p.S == "" {
		return 0, false
	}
	ch := p.S[0]
	p.S = p.S[1:]
	return ch, true
}

// acceptByte consumes the next byte only if it equals ch and reports whether
// it did so.
func (p *Parser) acceptByte(ch byte) bool {
	if next, ok := p.peekByte(); ok && next == ch {
		p.S = p.S[1:]
		return true
	}
	return false
}

// expectByte is like acceptByte but returns an error naming what was found
// instead of ch (or EOF).
func (p *Parser) expectByte(ch byte) error {
	if p.acceptByte(ch) {
		return nil
	}
	if p.S == "" {
		return fmt.Errorf("expected '%v', got EOF", string(ch))
	}
	return fmt.Errorf("expected '%v', got '%v'", string(ch), string(p.S[0]))
}

// ReversePath parses the argument of MAIL FROM: either the null reverse-path
// "<>" (returned as "") or a regular path as accepted by Path.
func (p *Parser) ReversePath() (string, error) {
	if rest, ok := strings.CutPrefix(p.S, "<>"); ok {
		p.S = rest
		return "", nil
	}
	return p.Path()
}

// Path parses a path such as "<user@example.com>" and returns the mailbox.
// The angle brackets are optional, but a closing '>' is required when the
// opening '<' is present. A source route ("@a,@b:") is skipped up to and
// including its ':'.
func (p *Parser) Path() (string, error) {
	bracketed := p.acceptByte('<')

	if p.acceptByte('@') {
		_, rest, found := strings.Cut(p.S, ":")
		if !found {
			return "", errors.New("malformed a-d-l")
		}
		p.S = rest
	}

	mailbox, err := p.Mailbox()
	if err != nil {
		return "", fmt.Errorf("in mailbox: %v", err)
	}
	if !bracketed {
		return mailbox, nil
	}
	if err := p.expectByte('>'); err != nil {
		return "", err
	}
	return mailbox, nil
}

// Mailbox parses "local-part@domain". The domain is taken verbatim up to the
// first space, tab, '>' or the end of input. A quoted local-part is returned
// unquoted and unescaped.
func (p *Parser) Mailbox() (string, error) {
	local, err := p.localPart()
	switch {
	case err != nil:
		return "", fmt.Errorf("in local-part: %v", err)
	case local == "":
		return "", errors.New("local-part is empty")
	}

	if err := p.expectByte('@'); err != nil {
		return "", err
	}

	end := strings.IndexAny(p.S, " \t>")
	if end < 0 {
		end = len(p.S)
	}
	addr := local + "@" + p.S[:end]
	p.S = p.S[end:]

	// Also rejects a domain that itself ends in '@' (e.g. "a@b@").
	if strings.HasSuffix(addr, "@") {
		return "", errors.New("domain is empty")
	}
	return addr, nil
}

// localPart parses either a quoted-string (the quotes are stripped and
// backslash escapes resolved) or a dot-string, which ends at '@' or at the
// end of input and may not contain specials or whitespace.
func (p *Parser) localPart() (string, error) {
	const specials = "()<>[]:;\\,\" \t"

	var out strings.Builder

	if !p.acceptByte('"') {
		for {
			ch, ok := p.peekByte()
			if !ok || ch == '@' {
				return out.String(), nil
			}
			if strings.IndexByte(specials, ch) >= 0 {
				return "", errors.New("malformed dot-string")
			}
			p.S = p.S[1:]
			out.WriteByte(ch)
		}
	}

	for {
		ch, ok := p.readByte()
		if ok && ch == '"' {
			return out.String(), nil
		}
		if ok && ch == '\\' {
			ch, ok = p.readByte()
		}
		if !ok {
			return "", errors.New("malformed quoted-string")
		}
		out.WriteByte(ch)
	}
}
package textsmtp
import (
"bufio"
"bytes"
"io"
"github.com/uponusolutions/go-smtp"
)
// crlfdot is the "\r\n." sequence dotReader.Read scans for: it precedes either
// a dot-stuffed line or the end-of-data marker.
var crlfdot = []byte{'\r', '\n', '.'}

// dotReader decodes a dot-encoded SMTP DATA stream; see Read.
type dotReader struct {
	r       *bufio.Reader // buffered source of the dot-encoded stream
	state   int           // Read's state machine (beginning of line, split CR, EOF)
	limited bool          // whether the n budget is enforced
	n       int64         // Maximum bytes remaining
}
// NewDotReader returns a reader that decodes the dot-encoded DATA stream
// buffered in reader. A positive maxMessageBytes caps the number of decoded
// bytes; zero or a negative value means no limit.
func NewDotReader(reader *bufio.Reader, maxMessageBytes int64) io.Reader {
	limited := maxMessageBytes > 0
	dr := &dotReader{r: reader, limited: limited}
	if limited {
		dr.n = maxMessageBytes
	}
	return dr
}
// Read implements io.Reader over the dot-encoded DATA stream. It copies the
// payload into b, drops the extra '.' that follows each "\r\n" in the stream,
// and returns io.EOF once the "\r\n.\r\n" terminator is reached. The "\r\n"
// before the final dot is returned as data; the ".\r\n" is consumed.
//
// If the stream ends before the terminator, io.ErrUnexpectedEOF is returned.
// When a size limit is configured, each read is capped to the remaining
// budget and smtp.ErrDataTooLarge is returned once the budget is used up.
//
// NOTE(review): unstuffing and terminator detection only match "\r\n.", so a
// '.' at the very beginning of the stream is not examined here; presumably the
// caller arranges for the preceding CRLF to still be buffered — confirm.
// NOTE(review): the stateCR branch writes b[0] unconditionally, which assumes
// len(b) > 0 — confirm callers never pass an empty slice.
func (r *dotReader) Read(b []byte) (int, error) {
	// Run data through a simple state machine to
	// elide leading dots and detect End-of-Data (<CR><LF>.<CR><LF>) line.
	const (
		stateBeginLine = iota // beginning of line; initial state; must be zero
		stateCR               // wrote the \r of a "\r\n." split across two reads
		stateEOF              // reached .\r\n end marker line
	)
	if r.limited {
		if r.n <= 0 {
			return 0, smtp.ErrDataTooLarge
		}
		if int64(len(b)) > r.n {
			b = b[0:r.n]
		}
	}
	var n int       // data written to b
	var skipped int // peeked bytes to discard beyond n (elided dots, terminator); one less while a returned '\r' is held back
	// IMPORTANT: do not block waiting for len(b) bytes. Only when nothing is
	// buffered do we wait, and then for up to 5 bytes (enough for "\r\n.\r\n").
	if r.r.Buffered() == 0 {
		_, _ = r.r.Peek(5)
	}
	// Peek len(b) bytes, capped at what is already buffered, but never fewer
	// than 5. NOTE(review): when 0 < Buffered() < 5 this can still block until
	// 5 bytes are available.
	c, err := r.r.Peek(max(min(len(b), r.r.Buffered()), 5))
	// write \n: the previous call returned the '\r' of "\r\n." but left it
	// buffered, so c starts with "\r\n." and c[3:5] follow the dot.
	if r.state == stateCR {
		b[0] = '\n'
		n++
		skipped += 2 // the held-back '\r' and the '.'
		r.state = stateBeginLine
		if c[3] == '\r' && c[4] == '\n' {
			r.state = stateEOF
			skipped += 2 // skip the \r\n after the dot
		} else {
			b = b[1:]
			c = c[3:]
		}
	}
	for r.state != stateEOF {
		i := bytes.Index(c, crlfdot)
		// no full \r\n. found
		if i == -1 {
			if err != io.EOF && len(c) > 1 && c[len(c)-2] == '\r' && c[len(c)-1] == '\n' {
				// ends with \r\n, write everything before
				// (a '.' may follow in data not yet peeked)
				n += copy(b, c[:len(c)-2])
			} else if err != io.EOF && len(c) > 0 && c[len(c)-1] == '\r' {
				// ends with \r, write everything before
				n += copy(b, c[:len(c)-1])
			} else {
				n += copy(b, c)
			}
			break
		} else if len(c)-1 < i+4 {
			// i is \r, \n.\r\n needs to be accessible
			// not enough bytes to check for \r\n.\r\n, write everything before
			if i > 0 {
				n += copy(b, c[:i])
			}
			break
		}
		p := copy(b, c[:i+2])
		n += p
		// b was too small
		if p < i+2 {
			// we only wrote \r
			if i+2-p == 1 {
				r.state = stateCR // next time we want to write \n
				skipped--         // prevent \r from being discarded
			}
			break
		}
		// the end \r\n.\r\n
		if c[i+3] == '\r' && c[i+4] == '\n' {
			r.state = stateEOF
			skipped += 3 // skip .\r\n
			break
		}
		skipped++ // . isn't written
		b = b[i+2:]
		c = c[i+3:]
	}
	// n + skipped never exceeds what was peeked, so Discard cannot fail short
	_, _ = r.r.Discard(n + skipped)
	if err == io.EOF && r.state != stateEOF {
		err = io.ErrUnexpectedEOF
	} else if err == nil && r.state == stateEOF {
		err = io.EOF
	}
	if r.limited {
		r.n -= int64(n)
	}
	return n, err
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on the modifications from
// https://github.com/go-textproto/textproto/blob/v0/writer.go
package textsmtp
import (
"bufio"
"bytes"
"io"
)
var (
	// crnl is the SMTP line terminator.
	crnl = []byte{'\r', '\n'}
	// dotcrnl is the end-of-data marker line written by Close.
	dotcrnl = []byte{'.', '\r', '\n'}
)

// NewDotWriter returns a writer that can be used to write a dot-encoding to w.
// It takes care of inserting leading dots when necessary (on every line,
// including the first one), translating line-ending \n into \r\n, and adding
// the final .\r\n line when the DotWriter is closed. Close also flushes w. The
// caller should close the DotWriter before the next call to a method on w.
//
// See the documentation for Reader's DotReader method for details about dot-encoding.
func NewDotWriter(writer *bufio.Writer) io.WriteCloser {
	return &dotWriter{
		W: writer,
	}
}

// dotWriter is the io.WriteCloser returned by NewDotWriter. state records the
// position within the current line across Write calls, so line starts and
// "\r\n" pairs split between two calls are still recognised.
type dotWriter struct {
	W     *bufio.Writer
	state int
}

const (
	wstateBegin     = iota // starting state; also the beginning of the first line
	wstateBeginLine        // beginning of line
	wstateCR               // wrote \r (possibly at end of line)
	wstateData             // writing data in middle of line
)

// Write dot-encodes b into W. The returned n counts bytes of b consumed, not
// bytes written to W (which may include inserted dots and carriage returns).
func (d *dotWriter) Write(b []byte) (n int, err error) {
	var (
		i    int
		p    []byte
		pLen int
		bw   = d.W
	)
	for len(b) > 0 {
		// Split off the next line including its '\n' or, without a '\n', the
		// rest of b (b becomes nil in that case).
		i = bytes.IndexByte(b, '\n')
		if i >= 0 {
			p, b = b[:i+1], b[i+1:]
		} else {
			p, b = b, nil
		}
		pLen = len(p)

		// Dot-stuffing (RFC 5321 section 4.5.2): every line that starts with
		// '.', the first line of the message included, gets an extra '.'.
		// Without the wstateBegin case a body starting with ".\r\n" would end
		// the DATA transfer early.
		if (d.state == wstateBegin || d.state == wstateBeginLine) && p[0] == '.' {
			if err = bw.WriteByte('.'); err != nil {
				return n, err
			}
		}

		if b == nil {
			// no end of line found in p
			if p[pLen-1] == '\r' {
				// p ends with \r
				d.state = wstateCR
			} else {
				// just write it down
				d.state = wstateData
			}
			if _, err = bw.Write(p); err != nil {
				return n, err
			}
		} else if d.state == wstateCR && pLen == 1 {
			// if b isn't nil and pLen is 1, then it must be a \n
			// as \r was sent before, just complete the crnl
			d.state = wstateBeginLine
			if err = bw.WriteByte('\n'); err != nil {
				return n, err
			}
		} else {
			// line is ending
			d.state = wstateBeginLine
			if pLen >= 2 && p[pLen-2] == '\r' {
				// fastpath if line ending is correct \r\n
				if _, err = bw.Write(p); err != nil {
					return n, err
				}
			} else {
				// data + crnl
				if _, err = bw.Write(p[:pLen-1]); err != nil {
					return n, err
				}
				if _, err = bw.Write(crnl); err != nil {
					return n, err
				}
			}
		}
		n += pLen
	}
	return n, err
}

// Close terminates the current line if needed, writes the ".\r\n" end marker
// and flushes W.
func (d *dotWriter) Close() error {
	bw := d.W
	switch d.state {
	default:
		if err := bw.WriteByte('\r'); err != nil {
			return err
		}
		fallthrough
	case wstateCR:
		// normally \r gets ignored if no \n follows, but at closing we just take it as a line break
		// same behavior as original textproto
		if err := bw.WriteByte('\n'); err != nil {
			return err
		}
		fallthrough
	case wstateBeginLine:
		if _, err := bw.Write(dotcrnl); err != nil {
			return err
		}
	}
	return bw.Flush()
}
package textsmtp
import (
"bufio"
"errors"
"fmt"
"io"
"net/textproto"
"strconv"
"strings"
)
// Textproto is used as a wrapper around a connection to read and write to it
// line by line. Reads are subject to an optional line-length limit, and the
// embedded textproto.Pipeline sequences requests and responses (see Cmd).
type Textproto struct {
	// R buffers reads from conn.
	R *bufio.Reader
	// W buffers writes to conn; PrintfLineAndFlush flushes it.
	W *bufio.Writer
	// conn is the underlying connection; Replace swaps it and Close closes it.
	conn io.ReadWriteCloser
	// maxLineLength caps the length of a line read by ReadLine; values <= 0
	// disable the check.
	maxLineLength int
	// lineLengthExceeded is set once a line was longer than maxLineLength;
	// from then on every line read fails with ErrTooLongLine.
	lineLengthExceeded bool
	textproto.Pipeline
}
// NewTextproto wraps conn with a buffered reader and writer.
//
// readerSize and writerSize are the buffer sizes; 0 selects 4096.
// maxLineLength limits the length of lines read (values <= 0 disable the
// limit).
func NewTextproto(
	conn io.ReadWriteCloser,
	readerSize int,
	writerSize int,
	maxLineLength int,
) *Textproto {
	const defaultBufferSize = 4096
	if readerSize == 0 {
		readerSize = defaultBufferSize
	}
	if writerSize == 0 {
		writerSize = defaultBufferSize
	}

	t := &Textproto{conn: conn, maxLineLength: maxLineLength}
	t.R = bufio.NewReaderSize(conn, readerSize)
	t.W = bufio.NewWriterSize(conn, writerSize)
	return t
}
// ErrTooLongLine occurs if the smtp line is too long. It is returned by the
// line readers when a line exceeds the configured maximum length, and on every
// later read through the same Textproto.
var ErrTooLongLine = errors.New("smtp: too long a line in input stream")
// Cmd sends one command line once it is its turn in the pipeline. The line is
// format formatted with args and terminated by \r\n, and it is flushed
// immediately. The returned id is meant for StartResponse and EndResponse.
// On error, 0 is returned as id.
//
// For example, a client might run a HELP command that returns a dot-body
// by using:
//
//	id, err := c.Cmd("HELP")
//	if err != nil {
//		return nil, err
//	}
//
//	c.StartResponse(id)
//	defer c.EndResponse(id)
//
//	if _, _, err = c.ReadCodeLine(110); err != nil {
//		return nil, err
//	}
//	text, err := c.ReadDotBytes()
//	if err != nil {
//		return nil, err
//	}
//	return c.ReadCodeLine(250)
func (t *Textproto) Cmd(format string, args ...any) (id uint, err error) {
	id = t.Next()
	t.StartRequest(id)
	sendErr := t.PrintfLineAndFlush(format, args...)
	t.EndRequest(id)

	if sendErr != nil {
		return 0, sendErr
	}
	return id, nil
}
// PrintfLine writes format formatted with args, followed by \r\n, into the
// write buffer. Nothing is flushed.
func (t *Textproto) PrintfLine(format string, args ...any) error {
	_, err := fmt.Fprintf(t.W, format, args...)
	if err == nil {
		_, err = t.W.Write(crnl)
	}
	return err
}

// PrintfLineAndFlush is PrintfLine followed by a flush of the write buffer.
func (t *Textproto) PrintfLineAndFlush(format string, args ...any) error {
	if err := t.PrintfLine(format, args...); err != nil {
		return err
	}
	return t.W.Flush()
}
// ReadResponse reads a multi-line response of the form:
//
//	code-message line 1
//	code-message line 2
//	...
//	code message line n
//
// where code is a three-digit status code. The first line starts with the
// code and a hyphen. The response is terminated by a line that starts
// with the same code followed by a space. Each line in message is
// separated by a newline (\n).
//
// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for
// details of another form of response accepted:
//
//	code-message line 1
//	message line 2
//	...
//	code message line n
//
// If the prefix of the status does not match the digits in expectCode,
// ReadResponse returns with err set to &Error{code, message}.
// For example, if expectCode is 31, an error will be returned if
// the status is not in the range [310,319].
//
// An expectCode <= 0 disables the check of the status code.
//
// NOTE(review): any continuation line that does not parse as "<code> ..." with
// the first line's code keeps the response open, so a peer that never sends a
// matching final line keeps this reading until ReadLine fails.
func (t *Textproto) ReadResponse(expectCode int) (code int, message string, err error) {
	code, continued, message, err := t.readCodeLine(expectCode)
	multi := continued
	for continued {
		// err is deliberately shadowed here: a continuation line that does not
		// parse is not an error, it is folded into message verbatim. Only the
		// first line's err (e.g. an unexpected code) survives the loop.
		line, err := t.ReadLine()
		if err != nil {
			return 0, "", err
		}
		var code2 int
		var moreMessage string
		code2, continued, moreMessage, err = parseCodeLine(line, 0)
		if err != nil || code2 != code {
			message += "\n" + strings.TrimRight(line, "\r\n")
			continued = true
			continue
		}
		message += "\n" + moreMessage
	}
	if err != nil && multi && message != "" {
		// replace one line error message with all lines (full message)
		err = &textproto.Error{Code: code, Msg: message}
	}
	return code, message, err
}
// ReadCodeLine reads a single-line response and checks its code against
// expectCode (see parseCodeLine). A continuation line ("250-...") is reported
// as a textproto.ProtocolError unless another error already occurred.
func (t *Textproto) ReadCodeLine(expectCode int) (int, string, error) {
	code, continued, message, err := t.readCodeLine(expectCode)
	if continued && err == nil {
		err = textproto.ProtocolError("unexpected multi-line response: " + message)
	}
	return code, message, err
}

// readCodeLine reads one line and parses it with parseCodeLine. A read error
// is returned with zero values for the other results.
func (t *Textproto) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
	var line string
	if line, err = t.ReadLine(); err != nil {
		return 0, false, "", err
	}
	return parseCodeLine(line, expectCode)
}
// parseCodeLine splits a response line "NNN message" or "NNN-message" into
// its code, the continuation flag ('-') and the message.
//
// A line shorter than four bytes, or without ' '/'-' at index 3, is a
// textproto.ProtocolError, as is a code that does not parse or is below 100.
// expectCode selects how much of the code must match: 1-9 the first digit,
// 10-99 the first two digits, 100-999 all three; any other value disables the
// check. A mismatch is reported as *textproto.Error with the parsed values.
func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
	if len(line) < 4 || (line[3] != ' ' && line[3] != '-') {
		return 0, false, "", textproto.ProtocolError("short response: " + line)
	}

	continued = line[3] == '-'
	if code, err = strconv.Atoi(line[:3]); err != nil || code < 100 {
		return code, continued, "", textproto.ProtocolError("invalid response code: " + line)
	}

	message = line[4:]

	var mismatch bool
	switch {
	case expectCode >= 1 && expectCode < 10:
		mismatch = code/100 != expectCode
	case expectCode >= 10 && expectCode < 100:
		mismatch = code/10 != expectCode
	case expectCode >= 100 && expectCode < 1000:
		mismatch = code != expectCode
	}
	if mismatch {
		err = &textproto.Error{Code: code, Msg: message}
	}
	return code, continued, message, err
}
// ReadLine reads one line without its trailing \n or \r\n, subject to the
// configured maximum line length (ErrTooLongLine).
func (t *Textproto) ReadLine() (string, error) {
	raw, err := t.readLineSlice()
	return string(raw), err
}
// readLineSlice returns the next line without its line ending, assembling it
// from several bufio.Reader.ReadLine fragments if needed.
//
// If the line is longer than maxLineLength (when positive), the Textproto is
// marked as poisoned and this and every later call return ErrTooLongLine. When
// the line fits in a single fragment, the returned slice aliases the reader's
// buffer and is only valid until the next read.
func (t *Textproto) readLineSlice() ([]byte, error) {
	// If the line limit was exceeded once, the connection shouldn't be used anymore.
	if t.lineLengthExceeded {
		return nil, ErrTooLongLine
	}

	var buf []byte
	for {
		part, isPrefix, err := t.R.ReadLine()
		switch {
		case err != nil:
			return nil, err
		case t.maxLineLength > 0 && len(buf)+len(part) > t.maxLineLength:
			t.lineLengthExceeded = true
			return nil, ErrTooLongLine
		case buf == nil && !isPrefix:
			// Avoid the copy if the first call produced a full line.
			return part, nil
		}

		buf = append(buf, part...)
		if !isPrefix {
			return buf, nil
		}
	}
}
// Replace makes t use conn from now on. Both buffers are reset, which
// discards any unread input and unflushed output of the previous connection.
func (t *Textproto) Replace(conn io.ReadWriteCloser) {
	t.R.Reset(conn)
	t.W.Reset(conn)
	t.conn = conn
}

// Close closes the current underlying connection (unflushed output is not
// written).
func (t *Textproto) Close() error {
	conn := t.conn
	return conn.Close()
}
package textsmtp
import (
"errors"
"github.com/uponusolutions/go-smtp"
)
// IsPrintableASCII reports whether every byte of val is printable ASCII
// (' ' through '~'). The empty string is considered printable.
func IsPrintableASCII(val string) bool {
	for i := 0; i < len(val); i++ {
		if c := val[i]; c < ' ' || c > '~' {
			return false
		}
	}
	return true
}
// CheckNotifySet validates the values of a DSN NOTIFY parameter. The set must
// be non-empty and contain only NEVER, DELAY, FAILURE or SUCCESS, each at most
// once, and NEVER may not be combined with anything else.
func CheckNotifySet(values []smtp.DSNNotify) error {
	if len(values) == 0 {
		return errors.New("malformed NOTIFY parameter value")
	}

	seen := make(map[smtp.DSNNotify]struct{}, len(values))
	for _, v := range values {
		known := v == smtp.DSNNotifyNever || v == smtp.DSNNotifyDelayed ||
			v == smtp.DSNNotifyFailure || v == smtp.DSNNotifySuccess
		_, duplicate := seen[v]
		if !known || duplicate {
			return errors.New("malformed NOTIFY parameter value")
		}
		seen[v] = struct{}{}
	}

	if _, hasNever := seen[smtp.DSNNotifyNever]; hasNever && len(seen) > 1 {
		return errors.New("malformed NOTIFY parameter value")
	}
	return nil
}
package smtp
import (
"regexp"
)
/*
Based on https://github.com/moisseev/rspamd/blob/master/rules/misc.lua
Detect PRVS/BATV addresses to avoid FORGED_SENDER
https://en.wikipedia.org/wiki/Bounce_Address_Tag_Validation

Signature syntax:

	prvs=TAG=USER@example.com       BATV draft (https://tools.ietf.org/html/draft-levine-smtp-batv-01)
	prvs=USER=TAG@example.com
	btv1==TAG==USER@example.com     Barracuda appliance
	msprvs1=TAG=USER@example.com    Sparkpost email delivery service

NOTE(review): the BATV pattern always keeps what follows the second '=', so
for the "prvs=USER=TAG" variant it yields TAG@example.com — confirm intended.
*/
const (
	// regexpBATV captures the address after a prvs=/msprvs1=/btv1== tag.
	regexpBATV = `^(?:(?:prvs|msprvs1)=[^=]+=|btv1==[^=]+==)([^@]+@(?:[^@]+))$`
	// regexpSRS captures the local part before "+SRS=" and the final domain.
	regexpSRS = `^([^+]+)\+SRS=[^=]+=[^=]+=[^=]+=[^@]+@([^@]+)$`
)

var (
	compiledRegexpBATV = regexp.MustCompile(regexpBATV)
	compiledRegexpSRS  = regexp.MustCompile(regexpSRS)
)

// ParseBATV strips a BATV/PRVS tag from src and returns the plain address.
// When BATV extraction is not possible/needed src is returned.
func ParseBATV(src string) string {
	if m := compiledRegexpBATV.FindStringSubmatch(src); len(m) == 2 {
		return m[1]
	}
	return src
}

// ParseSRS rewrites an SRS address (as used by Exchange Online) to
// "<local part before +SRS>@<domain>".
// When SRS extraction is not possible/needed src is returned.
func ParseSRS(src string) string {
	if m := compiledRegexpSRS.FindStringSubmatch(src); len(m) == 3 {
		return m[1] + "@" + m[2]
	}
	return src
}

// ParseSender applies ParseSRS first and then ParseBATV to its result.
func ParseSender(src string) string {
	withoutSRS := ParseSRS(src)
	return ParseBATV(withoutSRS)
}
package server
import (
"context"
"crypto/tls"
"io"
"log/slog"
"github.com/uponusolutions/go-sasl"
"github.com/uponusolutions/go-smtp"
)
// Backend is an SMTP server backend.
type Backend interface {
	// NewSession creates the Session for connection c.
	// The returned context is used in that session.
	NewSession(ctx context.Context, c *Conn) (context.Context, Session, error)
}
// BackendFunc adapts an ordinary function to the Backend interface.
type BackendFunc func(ctx context.Context, c *Conn) (context.Context, Session, error)

// NewSession implements Backend by calling f(ctx, c) and returning its
// results unchanged; the returned context is used in the session.
func (f BackendFunc) NewSession(ctx context.Context, c *Conn) (context.Context, Session, error) {
	sessCtx, sess, err := f(ctx, c)
	return sessCtx, sess, err
}
// Session is used by servers to respond to an SMTP client.
//
// The methods are called when the remote client issues the matching command.
type Session interface {
	// Reset discards the currently processed message.
	// The returned context replaces the context used in the current session.
	// upgrade is true when the reset is called after a TLS upgrade.
	Reset(ctx context.Context, upgrade bool) (context.Context, error)

	// Close frees all resources associated with the session.
	// err is set if an error occurred during the session or connection.
	// Close is always called after the session is done.
	Close(ctx context.Context, err error)

	// Logger returns the logger to use when an error occurs inside a session.
	// If no logger is returned, the default *slog.Logger is used.
	Logger(ctx context.Context) *slog.Logger

	// Mail sets the return path for the currently processed message.
	Mail(ctx context.Context, from string, opts *smtp.MailOptions) error

	// Rcpt adds a recipient for the currently processed message.
	Rcpt(ctx context.Context, to string, opts *smtp.RcptOptions) error

	// Verify checks the validity of an email address on the server.
	// If the error is nil, SMTP code 252 is sent.
	// If the error is an SMTP status, that status is sent.
	// Otherwise an internal server error is returned and the connection is closed.
	Verify(ctx context.Context, addr string, opts *smtp.VrfyOptions) error

	// Data sets the currently processed message contents and sends it.
	// If r is called, the returned reader must be consumed completely before
	// Data returns.
	// The returned queueid is not required to be unique.
	// NOTE(review): the original wording was "must not be unique", presumably
	// meaning "need not be unique" — confirm.
	Data(ctx context.Context, r func() io.Reader) (queueid string, err error)

	// AuthMechanisms returns the valid auth mechanisms.
	// Nil or an empty list means no authentication mechanism is allowed.
	AuthMechanisms(ctx context.Context) []string

	// Auth returns a matching sasl server for the given mech.
	Auth(ctx context.Context, mech string) (sasl.Server, error)

	// STARTTLS returns a valid *tls.Config.
	// It is called with the default tls config, and the returned tls config is
	// used in the tls upgrade.
	// If the tls.Config is nil or an error is returned, the tls upgrade is
	// aborted and the connection closed.
	// The *tls.Config received must not be changed.
	STARTTLS(ctx context.Context, tls *tls.Config) (*tls.Config, error)
}
package server
import (
"io"
"strconv"
"strings"
"github.com/uponusolutions/go-smtp"
)
// bdat is an io.Reader that presents the payload of consecutive BDAT chunks
// as one message stream (see Read).
type bdat struct {
	size            int64                          // bytes still to read from the current chunk
	last            bool                           // the current chunk was announced with LAST
	bytesReceived   int64                          // payload bytes read so far over all chunks
	maxMessageBytes int64                          // message size limit; 0 disables the check
	input           io.Reader                      // connection the chunk payload is read from
	chunk           io.Reader                      // LimitReader over input for the current chunk
	nextCommand     func() (string, string, error) // reads the next command line between chunks
}
// bdatArg parses the argument of "BDAT <size> [LAST]".
//
// It returns the chunk size and whether LAST was given. The size must be an
// unsigned decimal that fits into 32 bits, and it may only be 0 together with
// LAST. Every failure is a 501 status, and the returned bool is then true.
func bdatArg(arg string) (int64, bool, error) {
	fields := strings.Fields(arg)
	switch {
	case len(fields) == 0:
		return 0, true, smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Missing chunk size argument")
	case len(fields) > 2:
		return 0, true, smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Too many arguments")
	}

	last := len(fields) == 2
	if last && !strings.EqualFold(fields[1], "LAST") {
		return 0, true, smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
	}

	// ParseUint rather than Atoi so negative sizes are rejected.
	size, err := strconv.ParseUint(fields[0], 10, 32)
	if err != nil || (size == 0 && !last) {
		return 0, true, smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Malformed size argument")
	}
	return int64(size), last, nil
}
// Read returns the payload of the current BDAT chunk. When a chunk is used up
// and it was not the LAST one, Read reads the next command itself: RSET and
// QUIT end the stream with smtp.Reset / smtp.Quit, BDAT starts the next chunk,
// and anything else is a 501 status. io.EOF is returned after the LAST chunk
// has been read completely.
//
// A message that would exceed maxMessageBytes (when non-zero) fails with a 552
// status. If the connection ends before a chunk's announced size has arrived,
// smtp.ErrConnection is returned. This includes the LAST chunk; previously a
// truncated LAST chunk was reported as io.EOF, so the message looked complete.
func (d *bdat) Read(b []byte) (int, error) {
	if d.size == 0 {
		if d.last {
			return 0, io.EOF
		}
		d.chunk = nil
		cmd, arg, err := d.nextCommand()
		if err != nil {
			if err == io.EOF {
				return 0, smtp.ErrConnection
			}
			return 0, err
		}
		switch cmd {
		case "RSET":
			return 0, smtp.Reset
		case "QUIT":
			return 0, smtp.Quit
		case "BDAT":
			d.size, d.last, err = bdatArg(arg)
			if err != nil {
				return 0, err
			}
			if d.last && d.size == 0 {
				return 0, io.EOF
			}
		default:
			return 0, smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "BDAT command expected")
		}
	}

	if d.maxMessageBytes != 0 && d.bytesReceived+d.size > d.maxMessageBytes {
		return 0, smtp.NewStatus(552, smtp.EnhancedCode{5, 3, 4}, "Max message size exceeded")
	}

	if d.chunk == nil {
		d.chunk = io.LimitReader(d.input, d.size)
	}

	n, err := d.chunk.Read(b)
	d.bytesReceived += int64(n)
	d.size -= int64(n)

	if err == io.EOF {
		switch {
		case d.size > 0:
			// The LimitReader only reports EOF early when input itself ended:
			// the stream broke in the middle of the chunk, LAST or not.
			err = smtp.ErrConnection
		case !d.last:
			// Only this chunk is done; more BDAT commands may follow.
			err = nil
		}
	}
	return n, err
}
package server
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/internal/parse"
"github.com/uponusolutions/go-smtp/internal/textsmtp"
)
// state is the protocol state of a server connection; handle uses it to pick
// the command handler.
type state int32

const (
	// stateInit: nothing greeted yet; handled like stateUpgrade.
	stateInit state = 0
	// stateUpgrade is handled like stateInit.
	// NOTE(review): presumably entered after a STARTTLS upgrade — confirm.
	stateUpgrade state = 1
	// stateEnforceAuthentication: greeted, but AUTH is required before mail.
	stateEnforceAuthentication state = 2
	// stateEnforceSecureConnection: greeted over plain text, but STARTTLS is required.
	stateEnforceSecureConnection state = 3
	// stateGreeted: greeted and allowed to start a mail transaction.
	stateGreeted state = 4
	// stateMail accepts RCPT, DATA and BDAT.
	// NOTE(review): presumably entered after a successful MAIL — confirm.
	stateMail state = 5
)
// Conn is a connection inside a smtp server.
type Conn struct {
	ctx        context.Context
	conn       net.Conn // underlying network connection (a *tls.Conn once TLS is active)
	state      state    // protocol state; selects the command handler in handle
	text       *textsmtp.Textproto
	server     *Server
	session    Session    // cleared by Close after Session.Close was called
	binarymime bool       // NOTE(review): presumably set when BODY=BINARYMIME was requested — confirm
	helo       string     // set in helo / ehlo
	mechanisms []string   // set in helo / ehlo; allowed AUTH mechanisms
	recipients int        // count recipients
	didAuth    bool       // NOTE(review): presumably true after a successful AUTH — confirm
}
// run sends the greeting and then processes commands until an error ends the
// connection (QUIT, for example).
//
// An *smtp.Status returned by a handler is written to the client and the loop
// continues, except for code 221 (closing the transmission channel after QUIT),
// which is returned. Any other error is returned as is.
func (c *Conn) run() error {
	c.greet()
	for {
		cmd, arg, err := c.nextCommand()
		if err != nil {
			return err
		}

		err = c.handle(cmd, arg)
		if err == nil {
			continue
		}

		status, isStatus := err.(*smtp.Status)
		if !isStatus {
			return err
		}
		if status.Code == 221 {
			return status
		}
		// ToDo: close connection on repeated errors (e.g. authentication tries)
		c.writeStatus(status)
	}
}
// nextCommand reads one line from the client and splits it with parse.Cmd.
func (c *Conn) nextCommand() (cmd string, arg string, err error) {
	line, readErr := c.readLine()
	if readErr != nil {
		return "", "", readErr
	}
	return parse.Cmd(line)
}
// handle dispatches a command to the handler for the current protocol state.
// An empty command is a 500 syntax error; the command is upper-cased first.
func (c *Conn) handle(cmd string, arg string) error {
	if cmd == "" {
		return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 2}, "Error: bad syntax")
	}
	cmd = strings.ToUpper(cmd)

	switch c.state {
	case stateInit, stateUpgrade:
		return c.handleStateInit(cmd, arg)
	case stateEnforceSecureConnection:
		return c.handleStateEnforceSecureConnection(cmd, arg)
	case stateEnforceAuthentication:
		return c.handleStateEnforceAuthentication(cmd, arg)
	case stateGreeted:
		return c.handleStateGreeted(cmd, arg)
	case stateMail:
		return c.handleStateMail(cmd, arg)
	default:
		return fmt.Errorf("unsupported state %d, how?", c.state)
	}
}
// handleStateInit handles commands before the client has greeted: only
// HELO/EHLO, NOOP, VRFY, RSET and QUIT are accepted.
func (c *Conn) handleStateInit(cmd string, arg string) error {
	switch cmd {
	case "EHLO", "HELO":
		return c.handleGreet(cmd == "EHLO", arg)
	case "QUIT":
		return smtp.Quit
	case "NOOP":
		return smtp.Noop
	case "RSET": // Reset session
		return c.handleRSET()
	case "VRFY":
		return c.handleVrfy(arg)
	}
	return c.commandUnknown(cmd)
}
// handleStateEnforceAuthentication handles commands while AUTH is required
// before any mail transaction. Commands outside the list below get a 530.
func (c *Conn) handleStateEnforceAuthentication(cmd string, arg string) error {
	switch cmd {
	case "AUTH":
		// there is always a mechanism, as it is an enforce authentication precondition
		return c.handleAuth(arg)
	case "STARTTLS":
		return c.handleStartTLS()
	case "EHLO", "HELO":
		return c.handleGreet(cmd == "EHLO", arg)
	case "VRFY":
		return c.handleVrfy(arg)
	case "RSET": // Reset session
		return c.handleRSET()
	case "NOOP":
		return smtp.Noop
	case "QUIT":
		return smtp.Quit
	}
	return smtp.NewStatus(530, smtp.EnhancedCode{5, 7, 0}, "Authentication required")
}
// handleStateGreeted handles commands after a successful greeting, outside a
// mail transaction. AUTH is only accepted when mechanisms are offered.
func (c *Conn) handleStateGreeted(cmd string, arg string) error {
	switch cmd {
	case "MAIL":
		return c.handleMail(arg)
	case "AUTH":
		if len(c.mechanisms) == 0 {
			return smtp.ErrAuthUnsupported
		}
		return c.handleAuth(arg)
	case "STARTTLS":
		return c.handleStartTLS()
	case "EHLO", "HELO":
		return c.handleGreet(cmd == "EHLO", arg)
	case "VRFY":
		return c.handleVrfy(arg)
	case "RSET": // Reset session
		return c.handleRSET()
	case "NOOP":
		return smtp.Noop
	case "QUIT":
		return smtp.Quit
	}
	return c.commandUnknown(cmd)
}
// handleStateMail handles commands inside a mail transaction: recipients,
// message data (DATA, or BDAT when CHUNKING is enabled) and the usual
// housekeeping commands.
func (c *Conn) handleStateMail(cmd string, arg string) error {
	switch cmd {
	case "RCPT":
		return c.handleRcpt(arg)
	case "DATA":
		return c.handleData(arg)
	case "BDAT":
		if c.server.enableCHUNKING {
			return c.handleBdat(arg)
		}
		return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "CHUNKING is not implemented")
	case "STARTTLS":
		return c.handleStartTLS()
	case "EHLO", "HELO":
		return c.handleGreet(cmd == "EHLO", arg)
	case "VRFY":
		return c.handleVrfy(arg)
	case "RSET": // Reset session
		return c.handleRSET()
	case "NOOP":
		return smtp.Noop
	case "QUIT":
		return smtp.Quit
	}
	return c.commandUnknown(cmd)
}
// handleStateEnforceSecureConnection handles commands while a TLS upgrade is
// required. Commands outside the list below get a 530.
func (c *Conn) handleStateEnforceSecureConnection(cmd string, arg string) error {
	switch cmd {
	case "STARTTLS":
		return c.handleStartTLS()
	case "EHLO", "HELO":
		return c.handleGreet(cmd == "EHLO", arg)
	case "VRFY":
		return c.handleVrfy(arg)
	case "NOOP":
		return smtp.Noop
	case "QUIT":
		return smtp.Quit
	}
	return smtp.NewStatus(530, smtp.EnhancedCode{5, 7, 0}, "Must issue a STARTTLS command first")
}
// commandUnknown builds the 502 reply for a command that is not accepted in
// the current state; the message names the command and the state number.
func (c *Conn) commandUnknown(cmd string) *smtp.Status {
	text := fmt.Sprintf("%s command unknown, state %d", cmd, c.state)
	return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 1}, text)
}
// Server returns the server that accepted this connection.
func (c *Conn) Server() *Server {
	srv := c.server
	return srv
}
// Close closes the network connection and then the session.
//
// err is the reason the connection ends (nil for a clean end). A failure of
// the network Close is added to it, with errors.Join when both are set. A
// non-nil result is logged and handed to Session.Close. The session field is
// cleared afterwards.
func (c *Conn) Close(err error) {
	c.logger().DebugContext(c.ctx, "connection is closing")

	if closeErr := c.conn.Close(); closeErr != nil {
		if err != nil {
			err = errors.Join(err, closeErr)
		} else {
			err = closeErr
		}
	}

	if err != nil {
		c.logger().ErrorContext(c.ctx, "close error", slog.Any("err", err))
	}

	if c.session == nil {
		return
	}
	c.session.Close(c.ctx, err)
	c.session = nil
}
// TLSConnectionState reports the TLS state of the connection. The boolean is
// false, and the state is the zero value, when the connection is not a
// *tls.Conn.
func (c *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
	if tc, isTLS := c.conn.(*tls.Conn); isTLS {
		return tc.ConnectionState(), true
	}
	return state, false
}
// IsTLS reports whether the connection is a *tls.Conn, i.e. encrypted either
// via implicit TLS or after a successful STARTTLS.
func (c *Conn) IsTLS() bool {
	switch c.conn.(type) {
	case *tls.Conn:
		return true
	default:
		return false
	}
}
// Hostname returns the domain/address the client sent with its last HELO or
// EHLO command (empty before the first greeting).
func (c *Conn) Hostname() string {
	name := c.helo
	return name
}
// Mechanisms returns the auth mechanisms the session offered in the most
// recent EHLO reply; AUTH only accepts mechanisms from this list. The returned
// slice is the connection's own, not a copy.
func (c *Conn) Mechanisms() []string {
	mechs := c.mechanisms
	return mechs
}
// Conn returns the current underlying network connection; after STARTTLS this
// is the *tls.Conn wrapping the original socket.
func (c *Conn) Conn() net.Conn {
	nc := c.conn
	return nc
}
// handleRSET aborts the current transaction via reset and, on success, writes
// the smtp.Reset reply itself (so nil is returned to the caller).
func (c *Conn) handleRSET() error {
	if err := c.reset(); err != nil {
		return err
	}
	c.writeStatus(smtp.Reset)
	return nil
}
// handleGreet processes HELO (enhanced == false) and EHLO (enhanced == true).
//
// The client's domain is stored in c.helo so the session can read it via
// Conn.Hostname. A greeting issued later in the session resets the current
// transaction first. The resulting state depends on server policy:
//   - stateEnforceSecureConnection when a secure connection is required and
//     the connection is not TLS yet,
//   - stateEnforceAuthentication when authentication is required and the
//     client has not authenticated yet,
//   - stateGreeted otherwise.
//
// HELO receives a plain 250 reply; EHLO receives the capability list. When
// authentication is currently required but the session offers no mechanism,
// EHLO fails with 451 because the client could never leave that state.
func (c *Conn) handleGreet(enhanced bool, arg string) error {
	domain, err := parse.HelloArgument(arg)
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, "Domain/address argument required for HELO")
	}
	// Store the client-provided name so the session can access it via
	// Conn.Hostname.
	c.helo = domain
	// RFC 5321: "An EHLO command MAY be issued by a client later in the session"
	// RFC 5321: "... the SMTP server MUST clear all buffers
	// and reset the state exactly as if a RSET command has been issued."
	if c.state != stateInit && c.state != stateEnforceSecureConnection && c.state != stateEnforceAuthentication {
		if err := c.reset(); err != nil {
			return err
		}
	}
	switch {
	case c.server.enforceSecureConnection && !c.IsTLS():
		c.state = stateEnforceSecureConnection
	case c.server.enforceAuthentication && !c.didAuth:
		// didAuth survives a repeated greeting (reset only revokes it after
		// STARTTLS) and handleAuth refuses a second AUTH, so an already
		// authenticated client must not be sent back to this state.
		c.state = stateEnforceAuthentication
	default:
		c.state = stateGreeted
	}
	if !enhanced {
		return smtp.NewStatus(250, smtp.EnhancedCode{2, 0, 0}, fmt.Sprintf("Hello %s", domain))
	}
	var caps strings.Builder
	caps.Grow(512)
	caps.WriteString("Hello ")
	caps.WriteString(domain)
	caps.WriteString("\nPIPELINING\n8BITMIME\nENHANCEDSTATUSCODES")
	if c.server.enableCHUNKING {
		caps.WriteString("\nCHUNKING")
	}
	isTLS := c.IsTLS()
	if !isTLS && c.server.tlsConfig != nil {
		caps.WriteString("\nSTARTTLS")
	}
	c.mechanisms = c.session.AuthMechanisms(c.ctx)
	if len(c.mechanisms) > 0 {
		caps.WriteString("\nAUTH")
		for _, name := range c.mechanisms {
			caps.WriteByte(' ')
			caps.WriteString(name)
		}
	} else if c.state == stateEnforceAuthentication {
		// Authentication is required right now but no mechanism is offered,
		// so no authentication can happen => deadlock. In
		// stateEnforceSecureConnection the client must STARTTLS first and
		// mechanisms may only be offered afterwards, so it is not checked there.
		return c.newStatusError(451, smtp.EnhancedCode{4, 0, 0}, "No auth mechanism available but authentication enforced", nil)
	}
	if c.server.enableSMTPUTF8 {
		caps.WriteString("\nSMTPUTF8")
	}
	if isTLS && c.server.enableREQUIRETLS {
		caps.WriteString("\nREQUIRETLS")
	}
	if c.server.enableBINARYMIME {
		caps.WriteString("\nBINARYMIME")
	}
	if c.server.enableDSN {
		caps.WriteString("\nDSN")
	}
	if c.server.enableXOORG {
		caps.WriteString("\nXOORG")
	}
	if c.server.maxMessageBytes > 0 {
		_, _ = fmt.Fprintf(&caps, "\nSIZE %v", c.server.maxMessageBytes)
	} else {
		caps.WriteString("\nSIZE")
	}
	if c.server.maxRecipients > 0 {
		_, _ = fmt.Fprintf(&caps, "\nLIMITS RCPTMAX=%v", c.server.maxRecipients)
	}
	return smtp.NewStatus(250, smtp.NoEnhancedCode, caps.String())
}
// handleError reports the error that ended the command loop to the client
// where that still makes sense, and then closes the connection:
//
//   - io.EOF (peer hung up) or net.ErrClosed (connection closed locally, e.g.
//     by Server.Close): nothing is written.
//   - a net.Error reporting a timeout: 421 4.4.2 is written.
//   - an *smtp.Status: the status is written; code 221 (the QUIT reply)
//     closes without an error, any other code closes with one.
//   - textsmtp.ErrTooLongLine: 500 5.4.0 is written.
//   - anything else: smtp.ErrConnection is written.
//
// Errors are matched with errors.Is / errors.As, so wrapped errors (e.g. from
// fmt.Errorf with %w) are classified the same as their bare counterparts.
func (c *Conn) handleError(err error) {
	if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
		c.Close(fmt.Errorf("connection closed unexpectedly: %w", err))
		return
	}
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		c.writeResponse(421, smtp.EnhancedCode{4, 4, 2}, "Idle timeout, bye bye")
		c.Close(fmt.Errorf("idle timeout: %w", err))
		return
	}
	var status *smtp.Status
	if errors.As(err, &status) {
		c.writeResponse(status.Code, status.EnhancedCode, status.Message)
		if status.Code == 221 {
			c.Close(nil)
		} else {
			c.Close(fmt.Errorf("smtp error: %w", err))
		}
		return
	}
	if errors.Is(err, textsmtp.ErrTooLongLine) {
		c.writeResponse(500, smtp.EnhancedCode{5, 4, 0}, "Too long line")
		c.Close(errors.New("line too long"))
		return
	}
	c.writeStatus(smtp.ErrConnection)
	c.Close(fmt.Errorf("unknown error: %w", err))
}
// logger returns the session's logger, falling back to slog.Default when
// there is no session (not yet created, or already closed) or the session
// returns a nil logger.
func (c *Conn) logger() *slog.Logger {
	if c.session != nil {
		if l := c.session.Logger(c.ctx); l != nil {
			return l
		}
	}
	return slog.Default()
}
// READY state -> waiting for MAIL
//
// handleMail processes MAIL FROM. It parses the reverse-path and the ESMTP
// parameters, validates every parameter against the server's enabled
// extensions (any unknown parameter is rejected with 500), and passes the
// resulting MailOptions to the session. When the session returns nil or a
// positive *smtp.Status, the connection enters stateMail. A non-Status error
// from the session is logged and mapped to 451 4.0.0.
//
// NOTE(review): args is iterated as a map, so when several parameters are
// invalid, which error is reported is not deterministic.
// nolint: revive
func (c *Conn) handleMail(arg string) error {
	arg, ok := parse.CutPrefixFold(arg, "FROM:")
	if !ok {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:<address>")
	}
	p := parse.Parser{S: strings.TrimSpace(arg)}
	from, err := p.ReversePath()
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:<address>")
	}
	// p.S now holds whatever follows the reverse-path: the ESMTP parameters.
	args, err := parse.Args(p.S)
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unable to parse MAIL ESMTP parameters")
	}
	opts := &smtp.MailOptions{}
	// Cleared for every MAIL; set again below only for BODY=BINARYMIME.
	// handleData refuses DATA while it is set.
	c.binarymime = false
	// This is where the Conn may put BODY=8BITMIME, but we already
	// read the DATA as bytes, so it does not affect our processing.
	for key, value := range args {
		switch key {
		case "SIZE":
			// bitSize 32: values that do not fit in 32 bits are reported as
			// unparsable rather than as too large.
			size, err := strconv.ParseUint(value, 10, 32)
			if err != nil {
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unable to parse SIZE as an integer")
			}
			if c.server.maxMessageBytes > 0 && int64(size) > c.server.maxMessageBytes {
				return smtp.NewStatus(552, smtp.EnhancedCode{5, 3, 4}, "Max message size exceeded")
			}
			opts.Size = int64(size)
		case "XOORG":
			// The value is decoded before the enable check, so a malformed
			// value yields 500 even when XOORG is disabled.
			value, err := decodeXtext(value)
			if err != nil || value == "" {
				return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 4}, "Malformed XOORG parameter value")
			}
			if !c.server.enableXOORG {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "EnableXOORG is not implemented")
			}
			opts.XOORG = &value
		case "SMTPUTF8":
			if !c.server.enableSMTPUTF8 {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "SMTPUTF8 is not implemented")
			}
			opts.UTF8 = true
		case "REQUIRETLS":
			if !c.server.enableREQUIRETLS {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "REQUIRETLS is not implemented")
			}
			opts.RequireTLS = true
		case "BODY":
			// Matched case-insensitively; stored upper-cased.
			value = strings.ToUpper(value)
			switch smtp.BodyType(value) {
			case smtp.BodyBinaryMIME:
				if !c.server.enableBINARYMIME {
					return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "BINARYMIME is not implemented")
				}
				c.binarymime = true
			case smtp.Body7Bit, smtp.Body8BitMIME:
				// This space is intentionally left blank
			default:
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unknown BODY value")
			}
			opts.Body = smtp.BodyType(value)
		case "RET":
			// DSN return type; matched case-insensitively, stored upper-cased.
			if !c.server.enableDSN {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "RET is not implemented")
			}
			value = strings.ToUpper(value)
			switch smtp.DSNReturn(value) {
			case smtp.DSNReturnFull, smtp.DSNReturnHeaders:
				// This space is intentionally left blank
			default:
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unknown RET value")
			}
			opts.Return = smtp.DSNReturn(value)
		case "ENVID":
			// DSN envelope ID: xtext-encoded; the decoded value must be
			// non-empty printable ASCII.
			if !c.server.enableDSN {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "ENVID is not implemented")
			}
			value, err := decodeXtext(value)
			if err != nil || value == "" || !textsmtp.IsPrintableASCII(value) {
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Malformed ENVID parameter value")
			}
			opts.EnvelopeID = value
		case "AUTH":
			// xtext-encoded mailbox, or "<>" for "unknown/none", which is
			// stored as an empty (but non-nil) Auth value.
			value, err := decodeXtext(value)
			if err != nil || value == "" {
				return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 4}, "Malformed AUTH parameter value")
			}
			if value == "<>" {
				value = ""
			} else {
				p := parse.Parser{S: value}
				value, err = p.Mailbox()
				// Trailing text after the mailbox is rejected as well.
				if err != nil || p.S != "" {
					return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 4}, "Malformed AUTH parameter mailbox")
				}
			}
			opts.Auth = &value
		default:
			return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 4}, "Unknown MAIL FROM argument")
		}
	}
	if err := c.session.Mail(c.ctx, from, opts); err != nil {
		if smtpErr, ok := err.(*smtp.Status); ok {
			// a positive response also counts as a success
			if smtpErr.Positive() {
				c.state = stateMail
			}
			return smtpErr
		}
		return c.newStatusError(451, smtp.EnhancedCode{4, 0, 0}, "Mail not accepted", err)
	}
	c.state = stateMail
	return smtp.NewStatus(250, smtp.EnhancedCode{2, 0, 0}, fmt.Sprintf("Roger, accepting mail from <%v>", from))
}
// handleRcpt processes RCPT TO while a mail transaction is open (MAIL state
// -> waiting for RCPTs followed by DATA).
//
// The forward-path is parsed first, then the per-transaction recipient limit
// is checked, then the ESMTP parameters are validated. Only NOTIFY and ORCPT
// are accepted, and both require DSN support. The recipient counter grows
// when the session accepts the recipient, either by returning nil or a
// positive *smtp.Status.
func (c *Conn) handleRcpt(arg string) error {
	const syntaxHint = "Was expecting RCPT arg syntax of TO:<address>"

	rest, found := parse.CutPrefixFold(arg, "TO:")
	if !found {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, syntaxHint)
	}
	parser := parse.Parser{S: strings.TrimSpace(rest)}
	recipient, err := parser.Path()
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, syntaxHint)
	}
	if limit := c.server.maxRecipients; limit > 0 && c.recipients >= limit {
		return smtp.NewStatus(452, smtp.EnhancedCode{4, 5, 3}, fmt.Sprintf("Maximum limit of %v recipients reached", limit))
	}
	params, err := parse.Args(parser.S)
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unable to parse RCPT ESMTP parameters")
	}

	opts := &smtp.RcptOptions{}
	for name, value := range params {
		switch name {
		case "NOTIFY":
			if !c.server.enableDSN {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "NOTIFY is not implemented")
			}
			items := strings.Split(value, ",")
			notify := make([]smtp.DSNNotify, 0, len(items))
			for _, item := range items {
				notify = append(notify, smtp.DSNNotify(strings.ToUpper(item)))
			}
			if checkErr := textsmtp.CheckNotifySet(notify); checkErr != nil {
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Malformed NOTIFY parameter value")
			}
			opts.Notify = notify
		case "ORCPT":
			if !c.server.enableDSN {
				return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "ORCPT is not implemented")
			}
			addrType, addr, decodeErr := decodeTypedAddress(value)
			if decodeErr != nil || addr == "" {
				return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Malformed ORCPT parameter value")
			}
			opts.OriginalRecipientType = addrType
			opts.OriginalRecipient = addr
		default:
			return smtp.NewStatus(500, smtp.EnhancedCode{5, 5, 4}, "Unknown RCPT TO argument")
		}
	}

	if rcptErr := c.session.Rcpt(c.ctx, recipient, opts); rcptErr != nil {
		status, isStatus := rcptErr.(*smtp.Status)
		if !isStatus {
			return c.newStatusError(451, smtp.EnhancedCode{4, 0, 0}, "Recipient not accepted", rcptErr)
		}
		// a positive response also counts as a success
		if status.Positive() {
			c.recipients++
		}
		return status
	}
	c.recipients++
	return smtp.NewStatus(250, smtp.EnhancedCode{2, 0, 0}, fmt.Sprintf("I'll make sure <%v> gets this", recipient))
}
// handleVrfy processes VRFY <address> [SMTPUTF8]. Only the SMTPUTF8 parameter
// is interpreted (and refused with 504 when SMTPUTF8 is disabled); other
// parameters are ignored. A nil result from the session yields the default
// smtp.VRFY reply.
func (c *Conn) handleVrfy(arg string) error {
	parser := parse.Parser{S: strings.TrimSpace(arg)}
	address, err := parser.Path()
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 2}, "Was expecting <address>")
	}
	params, err := parse.Args(parser.S)
	if err != nil {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "Unable to parse VRFY ESMTP parameters")
	}
	opts := &smtp.VrfyOptions{}
	for name := range params {
		if name != "SMTPUTF8" {
			continue
		}
		if !c.server.enableSMTPUTF8 {
			return smtp.NewStatus(504, smtp.EnhancedCode{5, 5, 4}, "SMTPUTF8 is not implemented")
		}
		opts.UTF8 = true
	}
	if result := c.session.Verify(c.ctx, address, opts); result != nil {
		return result
	}
	return smtp.VRFY
}
// handleAuth runs a SASL exchange for the AUTH command.
//
// arg holds the mechanism name, optionally followed by a base64 initial
// response ("=" stands for an empty one). The mechanism must be one of those
// offered in the last EHLO reply (c.mechanisms). Server challenges are sent as
// 334 replies, and a client line of "*" cancels the exchange with 501. After
// success, further AUTH commands are refused with 503, and a connection held
// in stateEnforceAuthentication moves on to stateGreeted.
func (c *Conn) handleAuth(arg string) error {
	if c.didAuth {
		return smtp.NewStatus(503, smtp.EnhancedCode{5, 5, 1}, "Already authenticated")
	}
	fields := strings.Fields(arg)
	if len(fields) == 0 {
		return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 4}, "Missing parameter")
	}
	mechanism := strings.ToUpper(fields[0])
	// Is mechanism allowed?
	if !slices.Contains(c.mechanisms, mechanism) {
		return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 4}, "Invalid mechanism")
	}

	// Without an initial response this stays nil; "=" decodes to an empty,
	// non-nil slice.
	var response []byte
	if len(fields) > 1 {
		initial, err := decodeSASLResponse(fields[1])
		if err != nil {
			return smtp.NewStatus(454, smtp.EnhancedCode{4, 7, 0}, "Invalid base64 data")
		}
		response = initial
	}

	handler, err := c.session.Auth(c.ctx, mechanism)
	if err != nil {
		return c.newStatusError(454, smtp.EnhancedCode{4, 7, 0}, "Authentication failed", err)
	}
	if handler == nil {
		return c.newStatusError(451, smtp.EnhancedCode{4, 0, 0}, "No auth handler received, but mechanism seems valid.", err)
	}

	for {
		challenge, finished, nextErr := handler.Next(response)
		if nextErr != nil {
			return c.newStatusError(454, smtp.EnhancedCode{4, 7, 0}, "Authentication failed", nextErr)
		}
		if finished {
			break
		}
		var encoded string
		if len(challenge) > 0 {
			encoded = base64.StdEncoding.EncodeToString(challenge)
		}
		c.writeResponse(334, smtp.NoEnhancedCode, encoded)

		line, readErr := c.readLine()
		if readErr != nil {
			return readErr
		}
		if line == "*" {
			// https://tools.ietf.org/html/rfc4954#page-4
			return smtp.NewStatus(501, smtp.EnhancedCode{5, 0, 0}, "Negotiation cancelled")
		}
		decoded, decodeErr := decodeSASLResponse(line)
		if decodeErr != nil {
			return smtp.NewStatus(454, smtp.EnhancedCode{4, 7, 0}, "Invalid base64 data")
		}
		response = decoded
	}

	c.didAuth = true
	if c.state == stateEnforceAuthentication {
		c.state = stateGreeted
	}
	return smtp.NewStatus(235, smtp.EnhancedCode{2, 0, 0}, "Authentication succeeded")
}
// handleStartTLS upgrades a plaintext connection to TLS.
//
// The session may substitute its own *tls.Config via STARTTLS. After the 220
// reply, the server-side handshake runs bound to the connection context. On
// success, the connection and the text reader/writer switch to the TLS
// connection, and the state becomes stateUpgrade: the next HELO/EHLO resets
// the session and revokes any earlier authentication.
func (c *Conn) handleStartTLS() error {
	if c.IsTLS() {
		return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 1}, "Already running in TLS")
	}
	if c.server.tlsConfig == nil {
		return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 1}, "TLS not supported")
	}

	// allow the session to change tlsConfig
	cfg, err := c.session.STARTTLS(c.ctx, c.server.tlsConfig)
	switch {
	case err != nil:
		return c.newStatusError(451, smtp.EnhancedCode{4, 0, 0}, "TLS config retrieval failed", err)
	case cfg == nil:
		return smtp.NewStatus(451, smtp.EnhancedCode{4, 0, 0}, "TLS config retrieval nil returned")
	}

	c.writeResponse(220, smtp.EnhancedCode{2, 0, 0}, "Ready to start TLS")

	upgraded := tls.Server(c.conn, cfg)
	if hsErr := upgraded.HandshakeContext(c.ctx); hsErr != nil {
		c.logger().ErrorContext(c.ctx, "handleStartTLS", slog.Any("err", hsErr))
		return smtp.NewStatus(550, smtp.EnhancedCode{5, 0, 0}, "Handshake error")
	}
	c.conn = upgraded
	c.text.Replace(upgraded)
	c.state = stateUpgrade // same as StateInit but calls logout/reset on ehlo/helo
	return nil
}
// DATA
//
// handleData processes the DATA command in MAIL state. It requires at least
// one accepted recipient, rejects arguments, and refuses DATA when the
// transaction declared BODY=BINARYMIME.
//
// The message body is exposed to the session lazily. The 354 intermediate
// reply is sent only when the session first asks for the reader, and any
// later request returns that same reader. Whatever the session leaves unread
// is drained before replying, so the remainder of the body is never
// interpreted as SMTP commands.
func (c *Conn) handleData(arg string) error {
	// at least a single recipient needs to be set
	if c.recipients == 0 {
		return smtp.ErrNoRecipients
	}
	if arg != "" {
		return smtp.NewStatus(501, smtp.EnhancedCode{5, 5, 4}, "DATA command should not have any arguments")
	}
	if c.binarymime {
		return smtp.NewStatus(502, smtp.EnhancedCode{5, 5, 1}, "DATA not allowed for BINARYMIME messages")
	}
	var r io.Reader
	rstart := func() io.Reader {
		if r != nil {
			return r
		}
		// We have recipients, go to accept data
		c.writeResponse(354, smtp.NoEnhancedCode, "Go ahead. End your data with <CR><LF>.<CR><LF>")
		// Plain assignment (not :=) so the outer r records the reader: later
		// calls must not resend 354, and the drain below must see it.
		r = textsmtp.NewDotReader(c.text.R, c.server.maxMessageBytes)
		return r
	}
	uuid, err := c.session.Data(c.ctx, rstart)
	if err != nil {
		// An error that isn't an *smtp.Status always terminates the connection.
		// For an *smtp.Status the connection continues, so the rest of the
		// body must be read to the end first.
		if _, ok := err.(*smtp.Status); ok && r != nil {
			_, _ = io.Copy(io.Discard, r)
		}
		return err
	}
	// Make sure all the data has been consumed
	if r != nil {
		_, _ = io.Copy(io.Discard, r)
	}
	if err = c.reset(); err != nil {
		return err
	}
	return c.accepted(uuid)
}
// handleBdat processes a BDAT chunk (CHUNKING) in MAIL state.
//
// bdatArg parses the chunk size and the LAST flag from arg. The session sees
// the message as a single io.Reader (the bdat value), fed from c.text.R.
// nextCommand acknowledges a chunk with "250 2.0.0 Continue" and reads the
// next command. NOTE(review): presumably the bdat reader calls nextCommand
// when a non-LAST chunk is exhausted; confirm against the bdat type.
//
// If the session fails with an *smtp.Status, the status is written here, the
// remaining data is discarded, and the transaction is reset. reset's result
// is returned, so normally nothing further is written. Any other error is
// returned unchanged and terminates the connection.
func (c *Conn) handleBdat(arg string) error {
	// at least a single recipient needs to be set
	if c.recipients == 0 {
		return smtp.ErrNoRecipients
	}
	// Set once the session has failed; makes nextCommand stop pulling
	// further commands while the current chunk is being discarded.
	closed := false
	size, last, err := bdatArg(arg)
	if err != nil {
		return err
	}
	data := &bdat{
		maxMessageBytes: c.server.maxMessageBytes,
		size:            size,
		last:            last,
		bytesReceived:   0,
		input:           c.text.R,
		nextCommand: func() (string, string, error) {
			// if bdat is closed (error occurred)
			if closed {
				return "", "", io.EOF
			}
			c.writeResponse(250, smtp.EnhancedCode{2, 0, 0}, "Continue")
			return c.nextCommand()
		},
	}
	// The session always receives the same reader.
	queueid, err := c.session.Data(c.ctx, func() io.Reader {
		return data
	})
	if err != nil {
		if smtpErr, ok := err.(*smtp.Status); ok {
			// write down error
			c.writeStatus(smtpErr)
			// read anything left to continue after this failure, ignore any read error
			// https://www.rfc-editor.org/rfc/rfc3030.html
			// If a 5XX or 4XX code is received by the sender-SMTP in response to a BDAT
			// chunk, the transaction should be considered failed and the sender-
			// SMTP MUST NOT send any additional BDAT segments. If the receiver-
			// SMTP has declared support for command pipelining [PIPE], the receiver
			// SMTP MUST be prepared to accept and discard additional BDAT chunks
			// already in the pipeline after the failed BDAT.
			closed = true
			_, _ = io.Copy(io.Discard, data)
			return c.reset()
		}
		// an error which isn't a SMTPStatus error will always terminate the connection
		return err
	}
	// Make sure all the data has been consumed
	_, _ = io.Copy(io.Discard, data)
	if err = c.reset(); err != nil {
		return err
	}
	return c.accepted(queueid)
}
// accepted builds the final 250 2.0.0 reply for a delivered message,
// mentioning the queue ID when the session returned one. Queue IDs longer
// than 977 bytes are cut to 974 bytes plus "..." to keep the reply line
// below the 1000-byte limit.
func (*Conn) accepted(queueid string) *smtp.Status {
	const maxQueueIDLen = 977
	msg := "OK: queued"
	if queueid != "" {
		if len(queueid) > maxQueueIDLen {
			queueid = queueid[:maxQueueIDLen-3] + "..."
		}
		msg += " as " + queueid
	}
	return smtp.NewStatus(250, smtp.EnhancedCode{2, 0, 0}, msg)
}
// greet writes the initial 220 service-ready banner, announcing the server's
// hostname and ESMTP support.
func (c *Conn) greet() {
	banner := fmt.Sprintf("%v ESMTP Service Ready", c.server.hostname)
	c.writeResponse(220, smtp.NoEnhancedCode, banner)
}
// writeStatus writes status to the client as a (possibly multi-line) reply.
func (c *Conn) writeStatus(status *smtp.Status) {
	code, enh, msg := status.Code, status.EnhancedCode, status.Message
	c.writeResponse(code, enh, msg)
}
// writeResponse writes an SMTP reply and flushes it.
//
// text may contain '\n'. Every line except the last is written as a
// "code-line" continuation line, and the last line carries the enhanced code,
// if any. An unset enhanced code becomes X.0.0 for 2xx/4xx/5xx replies and is
// omitted otherwise. The write deadline is refreshed first when a write
// timeout is configured. Write errors are currently ignored.
func (c *Conn) writeResponse(code int, enhCode smtp.EnhancedCode, text string) {
	c.logger().DebugContext(c.ctx, "write", slog.Int("code", code), slog.Any("enhCode", enhCode), slog.Any("text", text))

	// TODO: error handling
	if timeout := c.server.writeTimeout; timeout != 0 {
		_ = c.conn.SetWriteDeadline(time.Now().Add(timeout))
	}

	// All responses must include an enhanced code, if it is missing - use
	// a generic code X.0.0.
	if enhCode == smtp.EnhancedCodeNotSet {
		switch class := code / 100; class {
		case 2, 4, 5:
			enhCode = smtp.EnhancedCode{class, 0, 0}
		default:
			enhCode = smtp.NoEnhancedCode
		}
	}

	lines := strings.Split(text, "\n")
	final := lines[len(lines)-1]
	for _, line := range lines[:len(lines)-1] {
		_ = c.text.PrintfLine("%d-%v", code, line)
	}
	if enhCode != smtp.NoEnhancedCode {
		_ = c.text.PrintfLineAndFlush("%d %v.%v.%v %v", code, enhCode[0], enhCode[1], enhCode[2], final)
		return
	}
	_ = c.text.PrintfLineAndFlush("%d %v", code, final)
}
// newStatusError converts err into a reply. An *smtp.Status is passed through
// unchanged. Anything else is logged under msg and replaced by a new status
// built from code, enhCode and msg, so internal error details are not sent to
// the client.
func (c *Conn) newStatusError(code int, enhCode smtp.EnhancedCode, msg string, err error) *smtp.Status {
	status, isStatus := err.(*smtp.Status)
	if isStatus {
		return status
	}
	c.logger().ErrorContext(c.ctx, msg, slog.Any("err", err))
	return smtp.NewStatus(code, enhCode, msg)
}
// readLine reads one line of client input. When a read timeout is configured,
// the read deadline is refreshed first. Successfully read lines are
// debug-logged.
func (c *Conn) readLine() (string, error) {
	if timeout := c.server.readTimeout; timeout != 0 {
		_ = c.conn.SetReadDeadline(time.Now().Add(timeout))
	}
	line, err := c.text.ReadLine()
	if err != nil {
		return line, err
	}
	c.logger().DebugContext(c.ctx, "read", slog.String("line", line))
	return line, nil
}
// reset ends the current mail transaction.
//
// A connection in stateMail returns to stateGreeted, and the recipient count
// is cleared. After a STARTTLS upgrade (stateUpgrade) authentication is
// revoked as well; that is the only case in which it is. The session's Reset
// is told whether this is the upgrade case and returns the context to use
// from now on.
func (c *Conn) reset() error {
	switch c.state {
	case stateMail:
		c.state = stateGreeted
	case stateUpgrade:
		c.didAuth = false
	}
	c.recipients = 0

	var err error
	c.ctx, err = c.session.Reset(c.ctx, c.state == stateUpgrade)
	return err
}
package server
import (
"encoding/base64"
"errors"
"regexp"
"strconv"
"strings"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/internal/textsmtp"
)
// decodeSASLResponse decodes a base64 SASL client response. A lone "=" stands
// for an empty response (RFC 4954) and yields an empty, non-nil slice.
func decodeSASLResponse(s string) ([]byte, error) {
	switch s {
	case "=":
		return []byte{}, nil
	default:
		return base64.StdEncoding.DecodeString(s)
	}
}
// hexcharRe matches the 'hexchar' token of
// https://tools.ietf.org/html/rfc4954#section-8 in a deliberately relaxed
// form: only the '+' is required. That way malformed escapes such as "+A" or
// "+HH" are matched as well and can be reported instead of passed through.
var hexcharRe = regexp.MustCompile(`\+[0-9A-F]?[0-9A-F]?`)

// decodeXtext decodes an xtext value, replacing every "+HH" escape with the
// character it encodes. An incomplete escape, or a value that does not fit in
// a signed byte, is an error. Values without '+' are returned unchanged.
func decodeXtext(val string) (string, error) {
	if !strings.Contains(val, "+") {
		return val, nil
	}
	var lastErr error
	decoded := hexcharRe.ReplaceAllStringFunc(val, func(m string) string {
		if len(m) != len("+HH") {
			lastErr = errors.New("incomplete hexchar")
			return ""
		}
		// ParseInt accepts the leading '+' as a sign.
		code, err := strconv.ParseInt(m, 16, 8)
		if err != nil {
			lastErr = err
			return ""
		}
		return string(rune(code))
	})
	if lastErr != nil {
		return "", lastErr
	}
	return decoded, nil
}
// eUOrDCharRe matches the 'EmbeddedUnicodeChar' token defined in
// https://datatracker.ietf.org/doc/html/rfc6533.html#section-3 (a "\x{HEX}"
// escape) in a deliberately relaxed form, plus every single character that
// QCHAR/QUCHAR disallow (control characters, space, '\', '+', '='). Matching
// the relaxed form lets malformed values be detected and reported.
var eUOrDCharRe = regexp.MustCompile(`\\x[{][0-9A-F]+[}]|[[:cntrl:] \\+=]`)

// decodeUTF8AddrXtext decodes the utf-8-addr-xtext or the utf-8-addr-unitext
// form (RFC 6533). Every "\x{HEXPOINT}" escape is replaced with the code point
// it names. It fails when a disallowed raw character is present, or when a
// HEXPOINT is not one of the forms the RFC grammar permits: a wrong digit
// count for its value, or a surrogate. When several matches are invalid, the
// error describes the last one.
func decodeUTF8AddrXtext(val string) (string, error) {
	var lastErr error
	decoded := eUOrDCharRe.ReplaceAllStringFunc(val, func(match string) string {
		// The escape alternative is at least 5 bytes long, so a 1-byte match
		// can only be a disallowed raw character.
		if len(match) == 1 {
			lastErr = errors.New("disallowed character:" + match)
			return ""
		}
		hexpoint := match[len(`\x{`) : len(match)-1]
		cp, err := strconv.ParseUint(hexpoint, 16, 21)
		if err != nil {
			lastErr = err
			return ""
		}
		var allowed bool
		switch len(hexpoint) {
		case 2:
			allowed = cp >= 0x01 && cp <= 0x09 || // xtext-specials 01-09
				cp >= 0x10 && cp <= 0x19 || // xtext-specials 10-19
				cp == 0x20 || cp == 0x2B || cp == 0x3D || cp == 0x7F || // remaining xtext-specials
				cp == 0x5C || cp >= 0x80 && cp <= 0xFF // other 2-digit forms
		case 3:
			allowed = cp >= 0x100 && cp <= 0xFFF
		case 4:
			// 4-digit forms exclude the surrogate range D800-DFFF.
			allowed = cp >= 0x1000 && cp <= 0xD7FF || cp >= 0xE000 && cp <= 0xFFFF
		case 5:
			allowed = cp >= 0x1_0000 && cp <= 0xF_FFFF
		case 6:
			allowed = cp >= 0x10_0000 && cp <= 0x10_FFFF
		}
		if !allowed {
			lastErr = errors.New("illegal hexpoint:" + hexpoint)
			return ""
		}
		return string(rune(cp))
	})
	if lastErr != nil {
		return "", lastErr
	}
	return decoded, nil
}
// decodeTypedAddress parses a DSN typed address ("addr-type;xtext", e.g. the
// ORCPT value). The type is matched case-insensitively and returned
// upper-cased. rfc822 addresses are xtext-decoded and must be printable
// ASCII; utf-8 addresses are decoded with decodeUTF8AddrXtext. Any other type
// is rejected.
func decodeTypedAddress(val string) (smtp.DSNAddressType, string, error) {
	rawType, rawAddr, found := strings.Cut(val, ";")
	if !found || rawType == "" || rawAddr == "" {
		return "", "", errors.New("bad address")
	}
	addrType := smtp.DSNAddressType(strings.ToUpper(rawType))

	var (
		addr string
		err  error
	)
	switch addrType {
	case smtp.DSNAddressTypeRFC822:
		addr, err = decodeXtext(rawAddr)
		if err == nil && !textsmtp.IsPrintableASCII(addr) {
			err = errors.New("illegal address:" + addr)
		}
	case smtp.DSNAddressTypeUTF8:
		addr, err = decodeUTF8AddrXtext(rawAddr)
	default:
		err = errors.New("unknown address type:" + string(addrType))
	}
	if err != nil {
		return "", "", err
	}
	return addrType, addr, nil
}
package server
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"runtime/debug"
"time"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/internal/textsmtp"
)
// Serve accepts incoming connections on the Listener l and serves each one in
// its own goroutine, tracked by the server's WaitGroup.
//
// l is registered so that Close and Shutdown can close it. Serve returns nil
// once the server has been closed. Accept errors that report a timeout are
// retried with exponential backoff (5ms doubling up to 1s), and the backoff
// starts over after the next successful accept. Any other accept error is
// returned.
func (s *Server) Serve(ctx context.Context, l net.Listener) error {
	s.locker.Lock()
	s.listeners = append(s.listeners, l)
	s.locker.Unlock()
	var tempDelay time.Duration // how long to sleep on accept failure
	for {
		c, err := l.Accept()
		if err != nil {
			select {
			case <-s.done:
				// we called Close()
				return nil
			default:
			}
			if ne, ok := err.(net.Error); ok && ne.Timeout() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				s.logger.ErrorContext(
					ctx,
					"accept error, retrying",
					slog.Any("err", err),
					slog.Any("temp_delay", tempDelay),
				)
				time.Sleep(tempDelay)
				continue
			}
			return err
		}
		// A successful accept ends the failure streak; without this, a later
		// transient error would start backing off from the previous maximum.
		tempDelay = 0
		s.wg.Add(1)
		go s.handleConn(ctx, c)
	}
}
// handleConn serves a single accepted connection until it finishes.
//
// The connection is registered in s.conns for the duration of the call so
// that Close can force it shut. On exit it is unregistered, s.wg.Done is
// called (the caller must have called s.wg.Add(1)) and the connection context
// is cancelled. A panic is recovered, answered with 421, and logged with its
// stack.
//
// For implicit-TLS connections, the handshake is performed explicitly. It is
// bounded by the configured read/write timeouts and by the connection
// context, matching what handleStartTLS does.
func (s *Server) handleConn(ctx context.Context, conn net.Conn) {
	ctx, cancel := context.WithCancel(ctx)
	c := &Conn{
		ctx:    ctx,
		server: s,
		conn:   conn,
		text:   textsmtp.NewTextproto(conn, s.readerSize, s.writerSize, s.maxLineLength),
	}
	s.locker.Lock()
	s.conns[c] = struct{}{}
	s.locker.Unlock()
	defer func() {
		if p := recover(); p != nil {
			c.writeResponse(421, smtp.EnhancedCode{4, 0, 0}, "Internal server error")
			stack := debug.Stack()
			c.logger().ErrorContext(
				c.ctx,
				"panic serving",
				slog.Any("err", p),
				slog.Any("stack", string(stack)),
			)
			c.Close(errors.New("recovered from panic inside handleConn"))
		}
		s.locker.Lock()
		delete(s.conns, c)
		s.locker.Unlock()
		s.wg.Done()
		cancel()
	}()
	sctx, session, err := s.backend.NewSession(ctx, c)
	if err != nil {
		c.Close(fmt.Errorf("couldn't create connection wrapper: %w", err))
		return
	}
	// update ctx and set session
	c.ctx = sctx
	c.session = session
	c.logger().DebugContext(c.ctx, "connection is opened")
	// explicit tls handshake call
	if tlsConn, ok := c.conn.(*tls.Conn); ok {
		if d := s.readTimeout; d != 0 {
			_ = c.conn.SetReadDeadline(time.Now().Add(d))
		}
		if d := s.writeTimeout; d != 0 {
			_ = c.conn.SetWriteDeadline(time.Now().Add(d))
		}
		// HandshakeContext lets cancellation of the connection context abort
		// a stalled handshake instead of waiting for the deadlines.
		if err := tlsConn.HandshakeContext(c.ctx); err != nil {
			c.handleError(err)
			return
		}
	}
	// run always returns an error when finished
	c.handleError(c.run())
}
// Listen opens a listener on the configured network and address without
// serving it; pass the result to Serve.
//
// The network defaults to "tcp" and the address to ":smtp" when they are
// empty. With implicit TLS enabled, the listener is created with tls.Listen
// using the server's TLS configuration.
func (s *Server) Listen() (net.Listener, error) {
	network := s.network
	if network == "" {
		network = "tcp"
	}
	addr := s.addr
	if addr == "" {
		addr = ":smtp"
	}
	var l net.Listener
	var err error
	if s.implicitTLS {
		l, err = tls.Listen(network, addr, s.tlsConfig)
	} else {
		l, err = net.Listen(network, addr)
	}
	if err != nil {
		return nil, err
	}
	return l, nil
}
// ListenAndServe opens a listener exactly as Listen does, then calls Serve on
// it. It returns the listen error, or whatever Serve returns.
func (s *Server) ListenAndServe(ctx context.Context) error {
	l, err := s.Listen()
	if err != nil {
		return err
	}
	return s.Serve(ctx, l)
}
// Close immediately closes all registered listeners and forcibly closes the
// underlying network connection of every active client. It does not wait for
// the connection handlers to return.
//
// Close returns ErrServerClosed if the server was already closed or shut
// down. Otherwise it returns the first error from closing a listener.
func (s *Server) Close() error {
	select {
	case <-s.done:
		return ErrServerClosed
	default:
		close(s.done)
	}

	s.locker.Lock()
	defer s.locker.Unlock()

	var firstErr error
	for _, l := range s.listeners {
		if err := l.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	for c := range s.conns {
		// directly close underlying connection
		_ = c.conn.Close()
	}
	return firstErr
}
// Shutdown gracefully shuts the server down. It closes all registered
// listeners and then waits, without interrupting active connections, until
// every connection handler has returned.
//
// It returns ErrServerClosed if the server was already closed or shut down.
// If ctx is done before all handlers finish, ctx.Err() is returned; otherwise
// the first error from closing a listener is returned.
func (s *Server) Shutdown(ctx context.Context) error {
	select {
	case <-s.done:
		return ErrServerClosed
	default:
		close(s.done)
	}

	var closeErr error
	s.locker.Lock()
	for _, l := range s.listeners {
		if err := l.Close(); err != nil && closeErr == nil {
			closeErr = err
		}
	}
	s.locker.Unlock()

	idle := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(idle)
	}()

	select {
	case <-idle:
		return closeErr
	case <-ctx.Done():
		return ctx.Err()
	}
}
package server
import (
"crypto/tls"
"errors"
"log/slog"
"net"
"sync"
"time"
)
// ErrServerClosed is returned by Close and Shutdown when the server has
// already been closed or shut down.
var ErrServerClosed = errors.New("smtp: server already closed")
// Server implements an SMTP server. Create it with New and configure it with
// the With* options.
type Server struct {
	// The type of network, "tcp" or "unix".
	network string
	// TCP or Unix address to listen on.
	addr string
	// The server TLS configuration, used for implicit-TLS listeners and for
	// STARTTLS. STARTTLS is neither advertised nor accepted while it is nil.
	tlsConfig *tls.Config
	// Hostname announced in the 220 greeting.
	hostname string
	// Maximum number of recipients per transaction; 0 means no limit.
	// A positive value is advertised as LIMITS RCPTMAX.
	maxRecipients int
	// Max line length for every command except data and bdat.
	maxLineLength int
	// Maximum size when receiving data and bdat. A positive value is
	// advertised with SIZE and checked against the MAIL FROM SIZE= parameter.
	maxMessageBytes int64
	// Reader buffer size.
	readerSize int
	// Writer buffer size.
	writerSize int
	// Read deadline applied before each command line is read and before an
	// implicit TLS handshake; 0 disables it.
	readTimeout time.Duration
	// Write deadline applied before each reply is written and before an
	// implicit TLS handshake; 0 disables it.
	writeTimeout time.Duration
	// Wrap listeners created by Listen/ListenAndServe in TLS.
	implicitTLS bool
	// Enforces usage of implicit TLS or STARTTLS before accepting commands
	// other than HELO, EHLO, NOOP, VRFY, STARTTLS, or QUIT.
	enforceSecureConnection bool
	// Enforces successful authentication before mail transactions.
	enforceAuthentication bool
	// Advertise SMTPUTF8 (RFC 6531) capability.
	// Should be used only if backend supports it.
	enableSMTPUTF8 bool
	// Advertise REQUIRETLS (RFC 8689) capability (on TLS connections only).
	// Should be used only if backend supports it.
	enableREQUIRETLS bool
	// Advertise CHUNKING (RFC 3030) capability; BDAT is refused otherwise.
	enableCHUNKING bool
	// Advertise BINARYMIME (RFC 3030) capability.
	// Should be used only if backend supports it.
	enableBINARYMIME bool
	// Advertise DSN (RFC 3461) capability.
	// Should be used only if backend supports it.
	enableDSN bool
	// Advertise XOORG capability.
	// Should be used only if backend supports it.
	enableXOORG bool
	// The server backend; it creates one session per accepted connection.
	backend Backend
	// Logger for server-level events such as accept errors.
	logger *slog.Logger
	// Tracks running connection handlers; Shutdown waits on it.
	wg sync.WaitGroup
	// Closed by Close/Shutdown to signal that the server is shutting down.
	done chan struct{}
	// Guards listeners and conns.
	locker sync.Mutex
	// Listeners registered by Serve.
	listeners []net.Listener
	// Active connections; Close force-closes their network connections.
	conns map[*Conn]struct{}
}
// Backend returns the Backend the server uses to create per-connection
// sessions.
func (s *Server) Backend() Backend {
	b := s.backend
	return b
}
// Option configures a Server. Options are passed to New and applied in
// order, so a later option overrides an earlier one for the same setting.
type Option func(*Server)
// New creates an SMTP server with hostname "localhost", applies opts in
// order, and falls back to slog.Default when no logger was configured.
func New(opts ...Option) *Server {
	srv := &Server{
		hostname: "localhost",
		conns:    make(map[*Conn]struct{}),
		done:     make(chan struct{}, 1),
	}
	for _, apply := range opts {
		apply(srv)
	}
	if srv.logger == nil {
		srv.logger = slog.Default()
	}
	return srv
}
// WithLogger sets the logger used for server-level events such as accept
// errors. When no logger (or a nil logger) is configured, New uses
// slog.Default.
func WithLogger(logger *slog.Logger) Option {
	return func(s *Server) {
		s.logger = logger
	}
}
// WithBackend sets the Backend that creates a session for every accepted
// connection.
func WithBackend(backend Backend) Option {
	return func(srv *Server) {
		srv.backend = backend
	}
}
// WithNetwork sets the network to listen on, e.g. "tcp" or "unix". Listen
// uses "tcp" when it is empty.
func WithNetwork(network string) Option {
	return func(srv *Server) {
		srv.network = network
	}
}
// WithReadTimeout sets the read deadline applied before each command line is
// read; zero disables it.
func WithReadTimeout(readTimeout time.Duration) Option {
	return func(srv *Server) {
		srv.readTimeout = readTimeout
	}
}
// WithWriteTimeout sets the write deadline applied before each reply is
// written; zero disables it.
func WithWriteTimeout(writeTimeout time.Duration) Option {
	return func(srv *Server) {
		srv.writeTimeout = writeTimeout
	}
}
// WithMaxMessageBytes sets the maximum message size for DATA and BDAT. A
// positive value is advertised via SIZE, and larger MAIL FROM SIZE= values
// are rejected.
func WithMaxMessageBytes(maxMessageBytes int64) Option {
	return func(srv *Server) {
		srv.maxMessageBytes = maxMessageBytes
	}
}
// WithMaxLineLength sets the maximum length of a command line (DATA and BDAT
// content excluded). Longer lines end the connection with a 500 reply.
func WithMaxLineLength(maxLineLength int) Option {
	return func(srv *Server) {
		srv.maxLineLength = maxLineLength
	}
}
// WithMaxRecipients sets the maximum number of recipients per transaction.
// Zero means no limit; a positive value is advertised as LIMITS RCPTMAX.
func WithMaxRecipients(maxRecipients int) Option {
	return func(srv *Server) {
		srv.maxRecipients = maxRecipients
	}
}
// WithAddr sets the address to listen on. Listen uses ":smtp" when it is
// empty.
func WithAddr(addr string) Option {
	return func(srv *Server) {
		srv.addr = addr
	}
}
// WithEnableXOORG controls whether XOORG is advertised and the XOORG MAIL
// parameter is accepted.
func WithEnableXOORG(enableXOORG bool) Option {
	return func(srv *Server) {
		srv.enableXOORG = enableXOORG
	}
}
// WithEnableBINARYMIME controls whether BINARYMIME is advertised and
// BODY=BINARYMIME is accepted in MAIL FROM.
func WithEnableBINARYMIME(enableBINARYMIME bool) Option {
	return func(srv *Server) {
		srv.enableBINARYMIME = enableBINARYMIME
	}
}
// WithEnableREQUIRETLS controls whether REQUIRETLS is advertised (on TLS
// connections) and the REQUIRETLS MAIL parameter is accepted.
func WithEnableREQUIRETLS(enableREQUIRETLS bool) Option {
	return func(srv *Server) {
		srv.enableREQUIRETLS = enableREQUIRETLS
	}
}
// WithEnableCHUNKING controls whether CHUNKING is advertised and the BDAT
// command is accepted.
func WithEnableCHUNKING(enableCHUNKING bool) Option {
	return func(srv *Server) {
		srv.enableCHUNKING = enableCHUNKING
	}
}
// WithEnableSMTPUTF8 controls whether SMTPUTF8 is advertised and the
// SMTPUTF8 parameter is accepted on MAIL and VRFY.
func WithEnableSMTPUTF8(enableSMTPUTF8 bool) Option {
	return func(srv *Server) {
		srv.enableSMTPUTF8 = enableSMTPUTF8
	}
}
// WithEnableDSN returns an Option that sets whether DSN is enabled on the
// Server.
func WithEnableDSN(enableDSN bool) Option {
	apply := func(srv *Server) {
		srv.enableDSN = enableDSN
	}
	return apply
}
// WithImplicitTLS returns an Option that sets whether the Server uses
// implicit TLS.
func WithImplicitTLS(implicitTLS bool) Option {
	apply := func(srv *Server) {
		srv.implicitTLS = implicitTLS
	}
	return apply
}
// WithHostname returns an Option that sets the host name of the Server.
func WithHostname(hostname string) Option {
	apply := func(srv *Server) {
		srv.hostname = hostname
	}
	return apply
}
// WithTLSConfig returns an Option that sets the TLS configuration (including
// certificates) used by the Server.
func WithTLSConfig(tlsConfig *tls.Config) Option {
	apply := func(srv *Server) {
		srv.tlsConfig = tlsConfig
	}
	return apply
}
// WithEnforceSecureConnection returns an Option that sets whether the Server
// enforces a secure connection (implicit TLS or STARTTLS).
func WithEnforceSecureConnection(enforceSecureConnection bool) Option {
	apply := func(srv *Server) {
		srv.enforceSecureConnection = enforceSecureConnection
	}
	return apply
}
// WithEnforceAuthentication returns an Option that sets whether the Server
// requires authentication before mail can be sent.
func WithEnforceAuthentication(enforceAuthentication bool) Option {
	apply := func(srv *Server) {
		srv.enforceAuthentication = enforceAuthentication
	}
	return apply
}
// WithReaderSize returns an Option that sets the Server's reader size.
func WithReaderSize(readerSize int) Option {
	apply := func(srv *Server) {
		srv.readerSize = readerSize
	}
	return apply
}
// WithWriterSize returns an Option that sets the Server's writer size.
func WithWriterSize(writerSize int) Option {
	apply := func(srv *Server) {
		srv.writerSize = writerSize
	}
	return apply
}
package tester
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"math/big"
"time"
)
// GenX509KeyPair generates a self-signed certificate and a fresh RSA-2048
// private key for an SMTP server reachable as domain. The certificate is valid
// from one day before now for 999 years and allows both server and client
// authentication.
//
// domain is written to the subject CommonName and, as a dNSName entry, to the
// subjectAltName extension, which is what TLS clients check when verifying the
// host name.
func GenX509KeyPair(domain string) (tls.Certificate, error) {
	now := time.Now()
	// Country must be a two-letter ISO 3166 code (RFC 5280, X520countryName).
	template := &x509.Certificate{
		SerialNumber: big.NewInt(now.Unix()),
		Subject: pkix.Name{
			CommonName:         domain,
			Country:            []string{"DE"},
			Organization:       []string{"UPONU GmbH"},
			OrganizationalUnit: []string{"mail:u secure"},
		},
		NotBefore:             now.AddDate(0, 0, -1),
		NotAfter:              now.AddDate(999, 0, 0),
		SubjectKeyId:          []byte{113, 117, 105, 99, 107, 115, 101, 114, 118, 101},
		BasicConstraintsValid: true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		DNSNames:              []string{domain},
	}

	// subjectAltName (OID 2.5.29.17) is a SEQUENCE OF GeneralName, and a
	// dNSName is the context-specific, implicitly tagged [2] IA5String.
	// x509.CreateCertificate does not synthesise a SAN from DNSNames when one
	// is present in ExtraExtensions, so this one must be encoded correctly; a
	// plain string here would leave the certificate without any DNS name.
	sanValue, err := asn1.Marshal([]asn1.RawValue{
		{Class: asn1.ClassContextSpecific, Tag: 2, Bytes: []byte(domain)},
	})
	if err != nil {
		return tls.Certificate{}, err
	}
	template.ExtraExtensions = []pkix.Extension{{
		Id:       asn1.ObjectIdentifier{2, 5, 29, 17},
		Critical: false,
		Value:    sanValue,
	}}

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return tls.Certificate{}, err
	}

	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, priv.Public(), priv)
	if err != nil {
		return tls.Certificate{}, err
	}

	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  priv,
	}, nil
}
package tester
import (
"bytes"
"embed"
"io"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
// getAllFilenames returns the paths of all non-directory entries below path in
// fs, descending into sub-directories recursively. An empty path means the
// root of fs. The returned paths are slash-separated, as embed.FS requires, so
// they can be passed straight to fs.ReadFile.
func getAllFilenames(fs *embed.FS, path string) (out []string, err error) {
	if path == "" {
		path = "."
	}
	entries, err := fs.ReadDir(path)
	if err != nil {
		return nil, err
	}
	for _, entry := range entries {
		// embed.FS only accepts forward slashes, but filepath.Join uses the OS
		// separator (a backslash on Windows), so convert back.
		fp := filepath.ToSlash(filepath.Join(path, entry.Name()))
		if !entry.IsDir() {
			out = append(out, fp)
			continue
		}
		sub, err := getAllFilenames(fs, fp)
		if err != nil {
			return nil, err
		}
		out = append(out, sub...)
	}
	return out, nil
}
// chunkSlice splits slice into consecutive chunks of chunkSize bytes; the last
// chunk holds the remainder and may be shorter. An empty slice yields nil.
// A non-positive chunkSize yields the whole slice as a single chunk, rather
// than looping forever (0) or panicking (negative).
//
// The chunks alias slice's backing array but have their capacity capped at
// their length, so appending to one chunk cannot overwrite the next one.
func chunkSlice(slice []byte, chunkSize int) [][]byte {
	if len(slice) == 0 {
		return nil
	}
	if chunkSize <= 0 || chunkSize > len(slice) {
		chunkSize = len(slice)
	}
	chunks := make([][]byte, 0, (len(slice)+chunkSize-1)/chunkSize)
	for len(slice) > 0 {
		n := chunkSize
		if n > len(slice) {
			n = len(slice)
		}
		chunks = append(chunks, slice[:n:n])
		slice = slice[n:]
	}
	return chunks
}
// checkExpectedBufferAgainsActual writes b once through the writer built by
// expected. It then writes b through a fresh writer built by actual for every
// chunk size from 1 to 4047 (never larger than len(b)), issuing one Write call
// per chunk. The test fails unless each actual output equals the expected
// output byte for byte.
func checkExpectedBufferAgainsActual(t *testing.T, b []byte, expected func(io.Writer) io.WriteCloser, actual func(io.Writer) io.WriteCloser) {
	t.Helper()

	var want bytes.Buffer
	w := expected(&want)
	_, err := w.Write(b)
	require.NoError(t, err)
	require.NoError(t, w.Close())

	for size := 1; size < 4048 && size <= len(b); size++ {
		var got bytes.Buffer
		w := actual(&got)
		for _, chunk := range chunkSlice(b, size) {
			_, err := w.Write(chunk)
			require.NoError(t, err, "chunk size %d", size)
		}
		require.NoError(t, w.Close(), "chunk size %d", size)
		// Compare the contents, not the bytes.Buffer structs: a struct
		// comparison depends on Buffer's unexported fields (e.g. a nil versus
		// an empty internal slice) and prints an unreadable diff on failure.
		require.Equal(t, want.Bytes(), got.Bytes(), "chunk size %d", size)
	}
}
// WriterCompareTest reads every file below path in fs and, for each one,
// compares the output of the writer built by expected with the output of the
// writer built by actual.
//
// To simulate Write calls of different sizes, each file is also fed to actual
// in chunks of every size from 1 to 4047 bytes (bounded by the file length).
func WriterCompareTest(t *testing.T, fs *embed.FS, path string, expected func(io.Writer) io.WriteCloser, actual func(io.Writer) io.WriteCloser) {
	t.Helper()
	files, err := getAllFilenames(fs, path)
	require.NoError(t, err)
	for _, file := range files {
		dat, err := fs.ReadFile(file)
		require.NoError(t, err, "reading %s", file)
		checkExpectedBufferAgainsActual(t, dat, expected, actual)
	}
}
// checkRaderExpectedAgainsActual feeds b through an io.Pipe to expected in a
// single Write. It then feeds b to actual once for every chunk size from 1 to
// 4047 (never larger than len(b)), one Write per chunk.
//
// Every read must end with an error matching io.ErrUnexpectedEOF, and every
// result of actual must equal the result of expected.
func checkRaderExpectedAgainsActual(t *testing.T, b []byte, expected func(io.Reader) ([]byte, error), actual func(io.Reader) ([]byte, error)) {
	t.Helper()

	// readChunks runs read against a pipe that a background goroutine fills
	// with chunks. The goroutine never touches t, because FailNow must only be
	// called from the test goroutine, and it shares no variables with the
	// caller. It is always waited for: closing the read side unblocks a
	// pending Write if read returned before draining the pipe, so it cannot
	// leak.
	readChunks := func(read func(io.Reader) ([]byte, error), chunks [][]byte) ([]byte, error) {
		pr, pw := io.Pipe()
		done := make(chan struct{})
		go func() {
			defer close(done)
			for _, chunk := range chunks {
				if _, err := pw.Write(chunk); err != nil {
					// The read side was closed: read stopped consuming early.
					return
				}
			}
			_ = pw.Close()
		}()
		out, err := read(pr)
		_ = pr.Close()
		<-done
		return out, err
	}

	want, err := readChunks(expected, [][]byte{b})
	require.ErrorIs(t, err, io.ErrUnexpectedEOF)

	for size := 1; size < 4048 && size <= len(b); size++ {
		got, err := readChunks(actual, chunkSlice(b, size))
		require.ErrorIs(t, err, io.ErrUnexpectedEOF, "chunk size %d", size)
		require.Equal(t, want, got, "chunk size %d", size)
	}
}
// ReaderCompareTest reads every file below path in fs and, for each one,
// compares the result of expected with the result of actual when both read
// the file's contents.
//
// To simulate Read calls of different sizes, each file is also delivered to
// actual in chunks of every size from 1 to 4047 bytes (bounded by the file
// length).
func ReaderCompareTest(t *testing.T, fs *embed.FS, path string, expected func(io.Reader) ([]byte, error), actual func(io.Reader) ([]byte, error)) {
	t.Helper()
	files, err := getAllFilenames(fs, path)
	require.NoError(t, err)
	for _, file := range files {
		dat, err := fs.ReadFile(file)
		require.NoError(t, err, "reading %s", file)
		checkRaderExpectedAgainsActual(t, dat, expected, actual)
	}
}
package tester
import (
"bytes"
"io"
"net"
"strings"
"time"
)
// FakeConn is an in-memory stand-in for a net.Conn in tests. Read and Write go
// to the embedded io.ReadWriter. Close and the deadline setters do nothing and
// return nil. LocalAddr returns nil, and RemoteAddr returns RemoteAddrReturn.
type FakeConn struct {
	io.ReadWriter
	// RemoteAddrReturn is the value returned by RemoteAddr.
	RemoteAddrReturn net.Addr
}

// Compile-time check that FakeConn (and therefore *FakeConn) is a net.Conn.
var _ net.Conn = FakeConn{}

// NewFakeConnStream creates a new FakeConn that reads from in and writes to
// out.
func NewFakeConnStream(in io.Reader, out *bytes.Buffer) *FakeConn {
	rw := struct {
		io.Reader
		io.Writer
	}{
		Reader: in,
		Writer: out,
	}
	return &FakeConn{
		ReadWriter: rw,
	}
}

// NewFakeConn creates a new FakeConn that reads the string in and writes to
// out.
func NewFakeConn(in string, out *bytes.Buffer) *FakeConn {
	return NewFakeConnStream(strings.NewReader(in), out)
}

// Close always returns nil.
func (FakeConn) Close() error { return nil }

// LocalAddr always returns nil.
func (FakeConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns RemoteAddrReturn.
func (f FakeConn) RemoteAddr() net.Addr { return f.RemoteAddrReturn }

// SetDeadline always returns nil and does nothing.
func (FakeConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline always returns nil and does nothing.
func (FakeConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline always returns nil and does nothing.
func (FakeConn) SetWriteDeadline(time.Time) error { return nil }
package tester
import "strings"
// Mail is one mail received by the SMTP server.
type Mail struct {
	From       string
	Recipients []string
	Data       []byte
}

// LookupKey returns the lookup key of m, i.e. LookupKey(m.From, m.Recipients).
// Delegating keeps the key format defined in exactly one place, so keys built
// from a Mail and keys built from a sender/recipient list always agree.
func (m *Mail) LookupKey() string {
	return LookupKey(m.From, m.Recipients)
}

// LookupKey returns a key of the format:
//
//	f+r[0]+r[1]...
//
// i.e. the sender followed by each recipient, all separated by "+". The
// recipient order is significant.
func LookupKey(f string, r []string) string {
	return f + "+" + strings.Join(r, "+")
}
// Package tester implements a simple SMTP server for testing. All received
// mails are saved in a sync.Map under the key
//
//	From+Recipient1+Recipient2
//
// A mail with the same sender and the same recipients, in the same order,
// overwrites the previously received one.
package tester
import (
"context"
"crypto/tls"
"io"
"log/slog"
"sync"
"time"
"github.com/uponusolutions/go-sasl"
"github.com/uponusolutions/go-smtp"
"github.com/uponusolutions/go-smtp/server"
)
// Standard returns a standard SMTP server listening on a random port. It is
// StandardWithAddress(":0"), so both constructors share one configuration.
func Standard() *server.Server {
	// Port 0 lets the operating system pick a free port when listening.
	return StandardWithAddress(":0")
}
// StandardWithAddress returns a standard SMTP server listening on addr. The
// server uses 10 second read and write timeouts, accepts messages of up to
// 1 MiB with at most 100 recipients, and stores received mails in a fresh
// Backend.
func StandardWithAddress(addr string) *server.Server {
	const ioTimeout = 10 * time.Second
	opts := []server.Option{
		server.WithAddr(addr),
		server.WithReadTimeout(ioTimeout),
		server.WithWriteTimeout(ioTimeout),
		server.WithMaxMessageBytes(1 << 20), // 1 MiB
		server.WithMaxRecipients(100),
		server.WithBackend(NewBackend()),
	}
	return server.New(opts...)
}
///////////////////////////////////////////////////////////////////////////
// Backend
///////////////////////////////////////////////////////////////////////////
// Backend is the backend for our test server. It keeps every received mail in
// Mails.
type Backend struct {
	// Mails maps LookupKey(from, recipients) to the *Mail most recently
	// received for that sender and recipient list.
	Mails sync.Map
}

// NewBackend returns a new Backend with an empty Mails map.
func NewBackend() *Backend {
	// The zero sync.Map is empty and ready to use.
	return &Backend{}
}

// NewSession returns a new Session that stores completed mails in b. ctx is
// returned unchanged.
func (b *Backend) NewSession(ctx context.Context, _ *server.Conn) (context.Context, server.Session, error) {
	return ctx, newSession(b), nil
}

// GetBackend returns the concrete *Backend used by s. It returns nil when s
// has no backend or its backend is of another type.
func GetBackend(s *server.Server) *Backend {
	// A failed assertion, including one on a nil interface, yields nil.
	b, _ := s.Backend().(*Backend)
	return b
}

// Add stores m under m.LookupKey(), replacing any mail stored under that key.
func (b *Backend) Add(m *Mail) {
	b.Mails.Store(m.LookupKey(), m)
}

// Load returns the mail sent from 'from' to 'recipients' (in that order). The
// ok result reports whether such a mail was found. A value of another type
// stored directly in the exported Mails map is reported as not found instead
// of causing a panic.
func (b *Backend) Load(from string, recipients []string) (*Mail, bool) {
	v, ok := b.Mails.Load(LookupKey(from, recipients))
	if !ok {
		return nil, false
	}
	m, ok := v.(*Mail)
	return m, ok
}
///////////////////////////////////////////////////////////////////////////
// Session
///////////////////////////////////////////////////////////////////////////
// Session assembles one mail for the test server and hands it to its Backend
// once DATA has been read. Sessions are created by Backend.NewSession.
type Session struct {
	backend *Backend // receives every completed mail
	mail    *Mail    // the mail currently being assembled
}

// newSession returns a Session bound to b with an empty pending mail.
func newSession(b *Backend) *Session {
	sess := &Session{backend: b}
	sess.mail = new(Mail)
	return sess
}

// Reset discards the pending mail and returns ctx unchanged.
func (s *Session) Reset(ctx context.Context, _ bool) (context.Context, error) {
	s.mail = new(Mail)
	return ctx, nil
}

// Close discards the pending mail.
func (s *Session) Close(context.Context, error) {
	s.mail = new(Mail)
}

// Logger returns nil; the session supplies no logger of its own.
func (Session) Logger(context.Context) *slog.Logger {
	return nil
}

// Verify accepts every address and returns nil.
func (Session) Verify(context.Context, string, *smtp.VrfyOptions) error {
	return nil
}

// Mail records from as the sender of the pending mail.
func (s *Session) Mail(_ context.Context, from string, _ *smtp.MailOptions) error {
	s.mail.From = from
	return nil
}

// Rcpt appends to to the recipients of the pending mail.
func (s *Session) Rcpt(_ context.Context, to string, _ *smtp.RcptOptions) error {
	current := s.mail.Recipients
	s.mail.Recipients = append(current, to)
	return nil
}

// Data reads the whole message body into the pending mail. On success it
// stores the mail in the backend. On a read error the partially read body is
// kept on the pending mail and the error is returned without storing
// anything.
func (s *Session) Data(_ context.Context, r func() io.Reader) (string, error) {
	body, err := io.ReadAll(r())
	s.mail.Data = body
	if err != nil {
		return "", err
	}
	s.backend.Add(s.mail)
	return "", nil
}

// AuthMechanisms returns nil: no authentication mechanisms are offered.
func (Session) AuthMechanisms(context.Context) []string {
	return nil
}

// Auth returns neither a SASL server nor an error.
func (Session) Auth(context.Context, string) (sasl.Server, error) {
	return nil, nil
}

// STARTTLS returns config unchanged.
func (Session) STARTTLS(_ context.Context, config *tls.Config) (*tls.Config, error) {
	return config, nil
}
package smtp
import (
"net"
"time"
)
// Timeout sets a deadline of duration from now on conn and returns a function
// that clears the deadline again.
//
// A non-positive duration means "no timeout", following the net.Dialer
// convention: no deadline is set and the returned function does nothing.
// Without this guard a zero duration would set a deadline that has already
// passed, making every subsequent read and write fail at once.
func Timeout(conn net.Conn, duration time.Duration) func() {
	if duration <= 0 {
		return func() {}
	}
	// Deadline errors are deliberately ignored (best effort): failing to set
	// or clear a deadline is not treated as fatal here.
	_ = conn.SetDeadline(time.Now().Add(duration))
	return func() {
		_ = conn.SetDeadline(time.Time{})
	}
}