server.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. package guerrilla
  2. import (
  3. "crypto/rand"
  4. "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net"
  8. "strings"
  9. "time"
  10. "runtime"
  11. log "github.com/Sirupsen/logrus"
  12. "sync"
  13. "sync/atomic"
  14. "github.com/flashmob/go-guerrilla/backends"
  15. "github.com/flashmob/go-guerrilla/response"
  16. )
  17. const (
  18. CommandVerbMaxLength = 16
  19. CommandLineMaxLength = 1024
  20. // Number of allowed unrecognized commands before we terminate the connection
  21. MaxUnrecognizedCommands = 5
  22. // The maximum total length of a reverse-path or forward-path is 256
  23. RFC2821LimitPath = 256
  24. // The maximum total length of a user name or other local-part is 64
  25. RFC2832LimitLocalPart = 64
  26. //The maximum total length of a domain name or number is 255
  27. RFC2821LimitDomain = 255
  28. // The minimum total number of recipients that must be buffered is 100
  29. RFC2821LimitRecipients = 100
  30. )
  31. const (
  32. // server has just been created
  33. ServerStateNew = iota
  34. // Server has just been stopped
  35. ServerStateStopped
  36. // Server has been started and is running
  37. ServerStateRunning
  38. // Server could not start due to an error
  39. ServerStateStartError
  40. )
  41. // Server listens for SMTP clients on the port specified in its config
  42. type server struct {
  43. configStore atomic.Value // stores guerrilla.ServerConfig
  44. backend backends.Backend
  45. tlsConfig *tls.Config
  46. tlsConfigStore atomic.Value
  47. timeout atomic.Value // stores time.Duration
  48. listenInterface string
  49. clientPool *Pool
  50. wg sync.WaitGroup // for waiting to shutdown
  51. listener net.Listener
  52. closedListener chan (bool)
  53. hosts allowedHosts // stores map[string]bool for faster lookup
  54. state int
  55. }
  56. type allowedHosts struct {
  57. table map[string]bool // host lookup table
  58. m sync.Mutex // guard access to the map
  59. }
  60. // Creates and returns a new ready-to-run Server from a configuration
  61. func newServer(sc *ServerConfig, b backends.Backend) (*server, error) {
  62. server := &server{
  63. backend: b,
  64. clientPool: NewPool(sc.MaxClients),
  65. closedListener: make(chan (bool), 1),
  66. listenInterface: sc.ListenInterface,
  67. state: ServerStateNew,
  68. }
  69. server.setConfig(sc)
  70. server.setTimeout(sc.Timeout)
  71. if err := server.configureSSL(); err != nil {
  72. return server, err
  73. }
  74. return server, nil
  75. }
  76. func (s *server) configureSSL() error {
  77. sConfig := s.configStore.Load().(ServerConfig)
  78. if sConfig.TLSAlwaysOn || sConfig.StartTLSOn {
  79. cert, err := tls.LoadX509KeyPair(sConfig.PublicKeyFile, sConfig.PrivateKeyFile)
  80. if err != nil {
  81. return fmt.Errorf("error while loading the certificate: %s", err)
  82. }
  83. tlsConfig := &tls.Config{
  84. Certificates: []tls.Certificate{cert},
  85. ClientAuth: tls.VerifyClientCertIfGiven,
  86. ServerName: sConfig.Hostname,
  87. }
  88. tlsConfig.Rand = rand.Reader
  89. s.tlsConfigStore.Store(tlsConfig)
  90. }
  91. return nil
  92. }
  93. // Set the timeout for the server and all clients
  94. func (server *server) setTimeout(seconds int) {
  95. duration := time.Duration(int64(seconds))
  96. server.clientPool.SetTimeout(duration)
  97. server.timeout.Store(duration)
  98. }
  99. // goroutine safe config store
  100. func (server *server) setConfig(sc *ServerConfig) {
  101. server.configStore.Store(*sc)
  102. }
  103. // goroutine safe
  104. func (server *server) isEnabled() bool {
  105. sc := server.configStore.Load().(ServerConfig)
  106. return sc.IsEnabled
  107. }
  108. // Set the allowed hosts for the server
  109. func (server *server) setAllowedHosts(allowedHosts []string) {
  110. defer server.hosts.m.Unlock()
  111. server.hosts.m.Lock()
  112. server.hosts.table = make(map[string]bool, len(allowedHosts))
  113. for _, h := range allowedHosts {
  114. server.hosts.table[strings.ToLower(h)] = true
  115. }
  116. }
  117. // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
  118. func (server *server) Start(startWG *sync.WaitGroup) error {
  119. var clientID uint64
  120. clientID = 0
  121. listener, err := net.Listen("tcp", server.listenInterface)
  122. server.listener = listener
  123. if err != nil {
  124. startWG.Done() // don't wait for me
  125. server.state = ServerStateStartError
  126. return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
  127. }
  128. log.Infof("Listening on TCP %s", server.listenInterface)
  129. server.state = ServerStateRunning
  130. startWG.Done() // start successful, don't wait for me
  131. for {
  132. log.Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
  133. conn, err := listener.Accept()
  134. clientID++
  135. if err != nil {
  136. if e, ok := err.(net.Error); ok && !e.Temporary() {
  137. log.Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
  138. // the listener has been closed, wait for clients to exit
  139. log.Infof("shutting down pool [%s]", server.listenInterface)
  140. server.clientPool.ShutdownState()
  141. server.clientPool.ShutdownWait()
  142. server.state = ServerStateStopped
  143. server.closedListener <- true
  144. return nil
  145. }
  146. log.WithError(err).Info("Temporary error accepting client")
  147. continue
  148. }
  149. go func(p Poolable, borrow_err error) {
  150. c := p.(*client)
  151. if borrow_err == nil {
  152. server.handleClient(c)
  153. server.clientPool.Return(c)
  154. } else {
  155. log.WithError(borrow_err).Info("couldn't borrow a new client")
  156. // we could not get a client, so close the connection.
  157. conn.Close()
  158. }
  159. // intentionally placed Borrow in args so that it's called in the
  160. // same main goroutine.
  161. }(server.clientPool.Borrow(conn, clientID))
  162. }
  163. }
  164. func (server *server) Shutdown() {
  165. if server.listener != nil {
  166. // This will cause Start function to return, by causing an error on listener.Accept
  167. server.listener.Close()
  168. // wait for the listener to listener.Accept
  169. <-server.closedListener
  170. // At this point Start will exit and close down the pool
  171. } else {
  172. server.clientPool.ShutdownState()
  173. // listener already closed, wait for clients to exit
  174. server.clientPool.ShutdownWait()
  175. server.state = ServerStateStopped
  176. }
  177. }
  178. func (server *server) GetActiveClientsCount() int {
  179. return server.clientPool.GetActiveClientsCount()
  180. }
  181. // Verifies that the host is a valid recipient.
  182. func (server *server) allowsHost(host string) bool {
  183. defer server.hosts.m.Unlock()
  184. server.hosts.m.Lock()
  185. if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
  186. return true
  187. }
  188. return false
  189. }
  190. // Reads from the client until a terminating sequence is encountered,
  191. // or until a timeout occurs.
  192. func (server *server) readCommand(client *client, maxSize int64) (string, error) {
  193. var input, reply string
  194. var err error
  195. // In command state, stop reading at line breaks
  196. suffix := "\r\n"
  197. for {
  198. client.setTimeout(server.timeout.Load().(time.Duration))
  199. reply, err = client.bufin.ReadString('\n')
  200. input = input + reply
  201. if err != nil {
  202. break
  203. }
  204. if strings.HasSuffix(input, suffix) {
  205. // discard the suffix and stop reading
  206. input = input[0 : len(input)-len(suffix)]
  207. break
  208. }
  209. }
  210. return input, err
  211. }
  212. // Writes a response to the client.
  213. func (server *server) writeResponse(client *client) error {
  214. client.setTimeout(server.timeout.Load().(time.Duration))
  215. size, err := client.bufout.WriteString(client.response)
  216. if err != nil {
  217. return err
  218. }
  219. err = client.bufout.Flush()
  220. if err != nil {
  221. return err
  222. }
  223. client.response = client.response[size:]
  224. return nil
  225. }
  226. func (server *server) isShuttingDown() bool {
  227. return server.clientPool.IsShuttingDown()
  228. }
  229. // Handles an entire client SMTP exchange
  230. func (server *server) handleClient(client *client) {
  231. defer client.closeConn()
  232. sc := server.configStore.Load().(ServerConfig)
  233. log.Infof("Handle client [%s], id: %d", client.RemoteAddress, client.ID)
  234. // Initial greeting
  235. greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
  236. sc.Hostname, Version, client.ID,
  237. server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
  238. helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
  239. // ehlo is a multi-line reply and need additional \r\n at the end
  240. ehlo := fmt.Sprintf("250-%s Hello\r\n", sc.Hostname)
  241. // Extended feature advertisements
  242. messageSize := fmt.Sprintf("250-SIZE %d\r\n", sc.MaxSize)
  243. pipelining := "250-PIPELINING\r\n"
  244. advertiseTLS := "250-STARTTLS\r\n"
  245. advertiseEnhancedStatusCodes := "250-ENHANCEDSTATUSCODES\r\n"
  246. // the last line doesn't need \r\n since string will be printed as a new line
  247. help := "250 HELP"
  248. if sc.TLSAlwaysOn {
  249. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  250. if ok && client.upgradeToTLS(tlsConfig) {
  251. advertiseTLS = ""
  252. } else {
  253. // server requires TLS, but can't handshake
  254. client.kill()
  255. }
  256. }
  257. if !sc.StartTLSOn {
  258. // STARTTLS turned off, don't advertise it
  259. advertiseTLS = ""
  260. }
  261. for client.isAlive() {
  262. switch client.state {
  263. case ClientGreeting:
  264. client.responseAdd(greeting)
  265. client.state = ClientCmd
  266. case ClientCmd:
  267. client.bufin.setLimit(CommandLineMaxLength)
  268. input, err := server.readCommand(client, sc.MaxSize)
  269. log.Debugf("Client sent: %s", input)
  270. if err == io.EOF {
  271. log.WithError(err).Warnf("Client closed the connection: %s", client.RemoteAddress)
  272. return
  273. } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  274. log.WithError(err).Warnf("Timeout: %s", client.RemoteAddress)
  275. return
  276. } else if err == LineLimitExceeded {
  277. client.responseAdd(response.CustomString(response.InvalidCommand, 554, response.ClassPermanentFailure, "Line too long."))
  278. client.kill()
  279. break
  280. } else if err != nil {
  281. log.WithError(err).Warnf("Read error: %s", client.RemoteAddress)
  282. client.kill()
  283. break
  284. }
  285. if server.isShuttingDown() {
  286. client.state = ClientShutdown
  287. continue
  288. }
  289. input = strings.Trim(input, " \r\n")
  290. cmdLen := len(input)
  291. if cmdLen > CommandVerbMaxLength {
  292. cmdLen = CommandVerbMaxLength
  293. }
  294. cmd := strings.ToUpper(input[:cmdLen])
  295. switch {
  296. case strings.Index(cmd, "HELO") == 0:
  297. client.Helo = strings.Trim(input[4:], " ")
  298. client.resetTransaction()
  299. client.responseAdd(helo)
  300. case strings.Index(cmd, "EHLO") == 0:
  301. client.Helo = strings.Trim(input[4:], " ")
  302. client.resetTransaction()
  303. client.responseAdd(ehlo + messageSize + pipelining + advertiseTLS + advertiseEnhancedStatusCodes + help)
  304. case strings.Index(cmd, "HELP") == 0:
  305. client.responseAdd("214 OK\r\n" + messageSize + pipelining + advertiseTLS + help)
  306. case strings.Index(cmd, "MAIL FROM:") == 0:
  307. if client.isInTransaction() {
  308. client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: nested MAIL command"))
  309. break
  310. }
  311. from, err := extractEmail(input[10:])
  312. if err != nil {
  313. client.responseAdd(err.Error())
  314. } else {
  315. client.MailFrom = from
  316. client.responseAdd(response.CustomString(response.OtherAddressStatus, 250, response.ClassSuccess, "OK"))
  317. }
  318. case strings.Index(cmd, "RCPT TO:") == 0:
  319. if len(client.RcptTo) > RFC2821LimitRecipients {
  320. client.responseAdd(response.CustomString(response.TooManyRecipients, 452, response.ClassTransientFailure, "Too many recipients"))
  321. break
  322. }
  323. to, err := extractEmail(input[8:])
  324. if err != nil {
  325. client.responseAdd(err.Error())
  326. } else {
  327. if !server.allowsHost(to.Host) {
  328. client.responseAdd(response.CustomString(response.BadDestinationMailboxAddress, 454, response.ClassTransientFailure, "Error: Relay access denied: "+to.Host))
  329. } else {
  330. client.RcptTo = append(client.RcptTo, *to)
  331. client.responseAdd(response.String(response.DestinationMailboxAddressValid, response.ClassSuccess))
  332. }
  333. }
  334. case strings.Index(cmd, "RSET") == 0:
  335. client.resetTransaction()
  336. client.responseAdd(response.CustomString(response.OtherAddressStatus, 250, response.ClassSuccess, "OK"))
  337. case strings.Index(cmd, "VRFY") == 0:
  338. client.responseAdd(response.CustomString(response.OtherOrUndefinedProtocolStatus, 252, response.ClassSuccess, "Cannot verify user"))
  339. case strings.Index(cmd, "NOOP") == 0:
  340. client.responseAdd(response.String(response.DestinationMailboxAddressValid, response.ClassSuccess))
  341. case strings.Index(cmd, "QUIT") == 0:
  342. client.responseAdd(response.CustomString(response.OtherStatus, 221, response.ClassSuccess, "Bye"))
  343. client.kill()
  344. case strings.Index(cmd, "DATA") == 0:
  345. if client.MailFrom.IsEmpty() {
  346. client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: No sender"))
  347. break
  348. }
  349. if len(client.RcptTo) == 0 {
  350. client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: No recipients"))
  351. break
  352. }
  353. client.responseAdd("354 Enter message, ending with '.' on a line by itself")
  354. client.state = ClientData
  355. case sc.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
  356. client.responseAdd(response.CustomString(response.OtherStatus, 220, response.ClassSuccess, "Ready to start TLS"))
  357. client.state = ClientStartTLS
  358. default:
  359. client.responseAdd(response.CustomString(response.SyntaxError, 500, response.ClassPermanentFailure, "Unrecognized command: "+cmd))
  360. client.errors++
  361. if client.errors > MaxUnrecognizedCommands {
  362. client.responseAdd(response.CustomString(response.InvalidCommand, 554, response.ClassPermanentFailure, "Too many unrecognized commands"))
  363. client.kill()
  364. }
  365. }
  366. case ClientData:
  367. // intentionally placed the limit 1MB above so that reading does not return with an error
  368. // if the client goes a little over. Anything above will err
  369. client.bufin.setLimit(int64(sc.MaxSize) + 1024000) // This a hard limit.
  370. n, err := client.Data.ReadFrom(client.smtpReader.DotReader())
  371. if n > sc.MaxSize {
  372. err = fmt.Errorf("Maximum DATA size exceeded (%d)", sc.MaxSize)
  373. }
  374. if err != nil {
  375. if err == LineLimitExceeded {
  376. client.responseAdd(response.CustomString(response.SyntaxError, 550, response.ClassPermanentFailure, "Error: "+LineLimitExceeded.Error()))
  377. client.kill()
  378. } else if err == MessageSizeExceeded {
  379. client.responseAdd(response.CustomString(response.SyntaxError, 550, response.ClassPermanentFailure, "Error: "+MessageSizeExceeded.Error()))
  380. client.kill()
  381. } else {
  382. client.kill()
  383. client.responseAdd(response.CustomString(response.OtherOrUndefinedMailSystemStatus, 451, response.ClassTransientFailure, "Error: "+err.Error()))
  384. }
  385. log.WithError(err).Warn("Error reading data")
  386. break
  387. }
  388. res := server.backend.Process(client.Envelope)
  389. if res.Code() < 300 {
  390. client.messagesSent++
  391. }
  392. client.responseAdd(res.String())
  393. client.state = ClientCmd
  394. if server.isShuttingDown() {
  395. client.state = ClientShutdown
  396. }
  397. client.resetTransaction()
  398. case ClientStartTLS:
  399. if !client.TLS && sc.StartTLSOn {
  400. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  401. if ok && client.upgradeToTLS(tlsConfig) {
  402. advertiseTLS = ""
  403. client.resetTransaction()
  404. }
  405. }
  406. // change to command state
  407. client.state = ClientCmd
  408. case ClientShutdown:
  409. // shutdown state
  410. client.responseAdd(response.CustomString(response.OtherOrUndefinedMailSystemStatus, 421, response.ClassTransientFailure, "Server is shutting down. Please try again later. Sayonara!"))
  411. client.kill()
  412. }
  413. if len(client.response) > 0 {
  414. log.Debugf("Writing response to client: \n%s", client.response)
  415. err := server.writeResponse(client)
  416. if err != nil {
  417. log.WithError(err).Debug("Error writing response")
  418. return
  419. }
  420. }
  421. }
  422. }