server.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. package guerrilla
  2. import (
  3. "crypto/rand"
  4. "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net"
  8. "runtime"
  9. "strings"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/flashmob/go-guerrilla/backends"
  14. "github.com/flashmob/go-guerrilla/envelope"
  15. "github.com/flashmob/go-guerrilla/log"
  16. "github.com/flashmob/go-guerrilla/response"
  17. )
  // Limits applied while reading and dispatching SMTP command lines.
  const (
  	// CommandVerbMaxLength is how many leading bytes of a command line are
  	// upper-cased and inspected when matching the command verb.
  	CommandVerbMaxLength = 16
  	// CommandLineMaxLength is the read limit set on the client's input buffer
  	// while it is in the command state.
  	CommandLineMaxLength = 1024
  	// MaxUnrecognizedCommands is the number of unrecognized commands after
  	// which the connection is terminated.
  	MaxUnrecognizedCommands = 5
  )

  // Size limits taken from RFC 2821.
  const (
  	// RFC2821LimitPath is the maximum total length of a reverse-path or forward-path.
  	RFC2821LimitPath = 256
  	// RFC2832LimitLocalPart is the maximum total length of a user name or other local-part.
  	// NOTE(review): "2832" looks like a typo for RFC 2821; the exported name is kept
  	// so existing callers keep compiling.
  	RFC2832LimitLocalPart = 64
  	// RFC2821LimitDomain is the maximum total length of a domain name or number.
  	RFC2821LimitDomain = 255
  	// RFC2821LimitRecipients is the minimum total number of recipients that must be buffered.
  	RFC2821LimitRecipients = 100
  )
  // Lifecycle states stored in server.state.
  const (
  	ServerStateNew        = iota // the server has just been created
  	ServerStateStopped           // the server has been stopped
  	ServerStateRunning           // the server has been started and is running
  	ServerStateStartError        // the server could not start due to an error
  )
  42. // Server listens for SMTP clients on the port specified in its config
  43. type server struct {
  44. configStore atomic.Value // stores guerrilla.ServerConfig
  45. backend backends.Backend
  46. tlsConfigStore atomic.Value
  47. timeout atomic.Value // stores time.Duration
  48. listenInterface string
  49. clientPool *Pool
  50. wg sync.WaitGroup // for waiting to shutdown
  51. listener net.Listener
  52. closedListener chan (bool)
  53. hosts allowedHosts // stores map[string]bool for faster lookup
  54. state int
  55. mainlog log.Logger
  56. log log.Logger
  57. // If log changed after a config reload, newLogStore stores the value here until it's safe to change it
  58. logStore atomic.Value
  59. mainlogStore atomic.Value
  60. }
  // allowedHosts is the set of recipient hosts the server accepts mail for;
  // RCPT TO is answered with a relay-denied error for any host not in it.
  // setAllowedHosts stores keys lower-cased and allowsHost lower-cases its
  // argument, so lookups are case-insensitive.
  type allowedHosts struct {
  	table map[string]bool // host lookup table; only presence of a key is checked
  	m     sync.Mutex      // guards access to table
  }
  65. // Creates and returns a new ready-to-run Server from a configuration
  66. func newServer(sc *ServerConfig, b backends.Backend, l log.Logger) (*server, error) {
  67. server := &server{
  68. backend: b,
  69. clientPool: NewPool(sc.MaxClients),
  70. closedListener: make(chan (bool), 1),
  71. listenInterface: sc.ListenInterface,
  72. state: ServerStateNew,
  73. mainlog: l,
  74. }
  75. var logOpenError error
  76. if sc.LogFile == "" {
  77. // none set, use the same log file as mainlog
  78. server.log, logOpenError = log.GetLogger(server.mainlog.GetLogDest())
  79. } else {
  80. server.log, logOpenError = log.GetLogger(sc.LogFile)
  81. }
  82. if logOpenError != nil {
  83. server.log.WithError(logOpenError).Errorf("Failed creating a logger for server [%s]", sc.ListenInterface)
  84. }
  85. // set to same level
  86. server.log.SetLevel(server.mainlog.GetLevel())
  87. server.setConfig(sc)
  88. server.setTimeout(sc.Timeout)
  89. if err := server.configureSSL(); err != nil {
  90. return server, err
  91. }
  92. return server, nil
  93. }
  94. func (s *server) configureSSL() error {
  95. sConfig := s.configStore.Load().(ServerConfig)
  96. if sConfig.TLSAlwaysOn || sConfig.StartTLSOn {
  97. cert, err := tls.LoadX509KeyPair(sConfig.PublicKeyFile, sConfig.PrivateKeyFile)
  98. if err != nil {
  99. return fmt.Errorf("error while loading the certificate: %s", err)
  100. }
  101. tlsConfig := &tls.Config{
  102. Certificates: []tls.Certificate{cert},
  103. ClientAuth: tls.VerifyClientCertIfGiven,
  104. ServerName: sConfig.Hostname,
  105. }
  106. tlsConfig.Rand = rand.Reader
  107. s.tlsConfigStore.Store(tlsConfig)
  108. }
  109. return nil
  110. }
  111. // configureLog checks to see if there is a new logger, so that the server.log can be safely changed
  112. // this function is not gorotine safe, although it'll read the new value safely
  113. func (s *server) configureLog() {
  114. // when log changed
  115. if l, ok := s.logStore.Load().(log.Logger); ok {
  116. if l != s.log {
  117. s.log = l
  118. }
  119. }
  120. // when mainlog changed
  121. if ml, ok := s.mainlogStore.Load().(log.Logger); ok {
  122. if ml != s.mainlog {
  123. s.mainlog = ml
  124. }
  125. }
  126. }
  127. // Set the timeout for the server and all clients
  128. func (server *server) setTimeout(seconds int) {
  129. duration := time.Duration(int64(seconds))
  130. server.clientPool.SetTimeout(duration)
  131. server.timeout.Store(duration)
  132. }
  133. // goroutine safe config store
  134. func (server *server) setConfig(sc *ServerConfig) {
  135. server.configStore.Store(*sc)
  136. }
  137. // goroutine safe
  138. func (server *server) isEnabled() bool {
  139. sc := server.configStore.Load().(ServerConfig)
  140. return sc.IsEnabled
  141. }
  142. // Set the allowed hosts for the server
  143. func (server *server) setAllowedHosts(allowedHosts []string) {
  144. defer server.hosts.m.Unlock()
  145. server.hosts.m.Lock()
  146. server.hosts.table = make(map[string]bool, len(allowedHosts))
  147. for _, h := range allowedHosts {
  148. server.hosts.table[strings.ToLower(h)] = true
  149. }
  150. }
  151. // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
  152. func (server *server) Start(startWG *sync.WaitGroup) error {
  153. var clientID uint64
  154. clientID = 0
  155. listener, err := net.Listen("tcp", server.listenInterface)
  156. server.listener = listener
  157. if err != nil {
  158. startWG.Done() // don't wait for me
  159. server.state = ServerStateStartError
  160. return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
  161. }
  162. server.log.Infof("Listening on TCP %s", server.listenInterface)
  163. server.state = ServerStateRunning
  164. startWG.Done() // start successful, don't wait for me
  165. for {
  166. server.log.Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
  167. conn, err := listener.Accept()
  168. server.configureLog()
  169. clientID++
  170. if err != nil {
  171. if e, ok := err.(net.Error); ok && !e.Temporary() {
  172. server.log.Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
  173. // the listener has been closed, wait for clients to exit
  174. server.log.Infof("shutting down pool [%s]", server.listenInterface)
  175. server.clientPool.ShutdownState()
  176. server.clientPool.ShutdownWait()
  177. server.state = ServerStateStopped
  178. server.closedListener <- true
  179. return nil
  180. }
  181. server.mainlog.WithError(err).Info("Temporary error accepting client")
  182. continue
  183. }
  184. go func(p Poolable, borrow_err error) {
  185. c := p.(*client)
  186. if borrow_err == nil {
  187. server.handleClient(c)
  188. server.clientPool.Return(c)
  189. } else {
  190. server.log.WithError(borrow_err).Info("couldn't borrow a new client")
  191. // we could not get a client, so close the connection.
  192. conn.Close()
  193. }
  194. // intentionally placed Borrow in args so that it's called in the
  195. // same main goroutine.
  196. }(server.clientPool.Borrow(conn, clientID, server.log))
  197. }
  198. }
  199. func (server *server) Shutdown() {
  200. if server.listener != nil {
  201. // This will cause Start function to return, by causing an error on listener.Accept
  202. server.listener.Close()
  203. // wait for the listener to listener.Accept
  204. <-server.closedListener
  205. // At this point Start will exit and close down the pool
  206. } else {
  207. server.clientPool.ShutdownState()
  208. // listener already closed, wait for clients to exit
  209. server.clientPool.ShutdownWait()
  210. server.state = ServerStateStopped
  211. }
  212. }
  213. func (server *server) GetActiveClientsCount() int {
  214. return server.clientPool.GetActiveClientsCount()
  215. }
  216. // Verifies that the host is a valid recipient.
  217. func (server *server) allowsHost(host string) bool {
  218. defer server.hosts.m.Unlock()
  219. server.hosts.m.Lock()
  220. if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
  221. return true
  222. }
  223. return false
  224. }
  225. // Reads from the client until a terminating sequence is encountered,
  226. // or until a timeout occurs.
  227. func (server *server) readCommand(client *client, maxSize int64) (string, error) {
  228. var input, reply string
  229. var err error
  230. // In command state, stop reading at line breaks
  231. suffix := "\r\n"
  232. for {
  233. client.setTimeout(server.timeout.Load().(time.Duration))
  234. reply, err = client.bufin.ReadString('\n')
  235. input = input + reply
  236. if err != nil {
  237. break
  238. }
  239. if strings.HasSuffix(input, suffix) {
  240. // discard the suffix and stop reading
  241. input = input[0 : len(input)-len(suffix)]
  242. break
  243. }
  244. }
  245. return input, err
  246. }
  247. // flushResponse a response to the client. Flushes the client.bufout buffer to the connection
  248. func (server *server) flushResponse(client *client) error {
  249. client.setTimeout(server.timeout.Load().(time.Duration))
  250. return client.bufout.Flush()
  251. }
  252. func (server *server) isShuttingDown() bool {
  253. return server.clientPool.IsShuttingDown()
  254. }
  255. // Handles an entire client SMTP exchange
  256. func (server *server) handleClient(client *client) {
  257. defer func() {
  258. server.log.WithFields(map[string]interface{}{
  259. "event": "disconnect",
  260. "id": client.ID,
  261. }).Info("Disconnect client")
  262. client.closeConn()
  263. }()
  264. sc := server.configStore.Load().(ServerConfig)
  265. server.log.WithFields(map[string]interface{}{
  266. "event": "connect",
  267. "id": client.ID,
  268. }).Info("Handle client")
  269. // Initial greeting
  270. greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
  271. sc.Hostname, Version, client.ID,
  272. server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
  273. helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
  274. // ehlo is a multi-line reply and need additional \r\n at the end
  275. ehlo := fmt.Sprintf("250-%s Hello\r\n", sc.Hostname)
  276. // Extended feature advertisements
  277. messageSize := fmt.Sprintf("250-SIZE %d\r\n", sc.MaxSize)
  278. pipelining := "250-PIPELINING\r\n"
  279. advertiseTLS := "250-STARTTLS\r\n"
  280. advertiseEnhancedStatusCodes := "250-ENHANCEDSTATUSCODES\r\n"
  281. // The last line doesn't need \r\n since string will be printed as a new line.
  282. // Also, Last line has no dash -
  283. help := "250 HELP"
  284. if sc.TLSAlwaysOn {
  285. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  286. if !ok {
  287. server.mainlog.Error("Failed to load *tls.Config")
  288. } else if err := client.upgradeToTLS(tlsConfig); err == nil {
  289. advertiseTLS = ""
  290. } else {
  291. server.log.WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteAddress)
  292. // server requires TLS, but can't handshake
  293. client.kill()
  294. }
  295. }
  296. if !sc.StartTLSOn {
  297. // STARTTLS turned off, don't advertise it
  298. advertiseTLS = ""
  299. }
  300. for client.isAlive() {
  301. switch client.state {
  302. case ClientGreeting:
  303. client.sendResponse(greeting)
  304. client.state = ClientCmd
  305. case ClientCmd:
  306. client.bufin.setLimit(CommandLineMaxLength)
  307. input, err := server.readCommand(client, sc.MaxSize)
  308. server.log.Debugf("Client sent: %s", input)
  309. if err == io.EOF {
  310. server.log.WithError(err).Warnf("Client closed the connection: %s", client.RemoteAddress)
  311. return
  312. } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  313. server.log.WithError(err).Warnf("Timeout: %s", client.RemoteAddress)
  314. return
  315. } else if err == LineLimitExceeded {
  316. client.sendResponse(response.Canned.FailLineTooLong)
  317. client.kill()
  318. break
  319. } else if err != nil {
  320. server.log.WithError(err).Warnf("Read error: %s", client.RemoteAddress)
  321. client.kill()
  322. break
  323. }
  324. if server.isShuttingDown() {
  325. client.state = ClientShutdown
  326. continue
  327. }
  328. input = strings.Trim(input, " \r\n")
  329. cmdLen := len(input)
  330. if cmdLen > CommandVerbMaxLength {
  331. cmdLen = CommandVerbMaxLength
  332. }
  333. cmd := strings.ToUpper(input[:cmdLen])
  334. switch {
  335. case strings.Index(cmd, "HELO") == 0:
  336. client.Helo = strings.Trim(input[4:], " ")
  337. client.resetTransaction()
  338. client.sendResponse(helo)
  339. case strings.Index(cmd, "EHLO") == 0:
  340. client.Helo = strings.Trim(input[4:], " ")
  341. client.resetTransaction()
  342. client.sendResponse(ehlo,
  343. messageSize,
  344. pipelining,
  345. advertiseTLS,
  346. advertiseEnhancedStatusCodes,
  347. help)
  348. case strings.Index(cmd, "HELP") == 0:
  349. quote := response.GetQuote()
  350. client.sendResponse("214-OK\r\n" + quote)
  351. case strings.Index(cmd, "MAIL FROM:") == 0:
  352. if client.isInTransaction() {
  353. client.sendResponse(response.Canned.FailNestedMailCmd)
  354. break
  355. }
  356. mail := input[10:]
  357. from := envelope.EmailAddress{}
  358. if !(strings.Index(mail, "<>") == 0) &&
  359. !(strings.Index(mail, " <>") == 0) {
  360. // Not Bounce, extract mail.
  361. from, err = extractEmail(mail)
  362. }
  363. if err != nil {
  364. client.sendResponse(err)
  365. } else {
  366. server.log.WithFields(map[string]interface{}{
  367. "event": "mailfrom",
  368. "helo": client.Helo,
  369. "domain": from.Host,
  370. "address": client.RemoteAddress,
  371. "id": client.ID,
  372. }).Info("Mail from")
  373. client.MailFrom = from
  374. client.sendResponse(response.Canned.SuccessMailCmd)
  375. }
  376. case strings.Index(cmd, "RCPT TO:") == 0:
  377. if len(client.RcptTo) > RFC2821LimitRecipients {
  378. client.sendResponse(response.Canned.ErrorTooManyRecipients)
  379. break
  380. }
  381. to, err := extractEmail(input[8:])
  382. if err != nil {
  383. client.sendResponse(err.Error())
  384. } else {
  385. if !server.allowsHost(to.Host) {
  386. client.sendResponse(response.Canned.ErrorRelayDenied, to.Host)
  387. } else {
  388. client.RcptTo = append(client.RcptTo, to)
  389. client.sendResponse(response.Canned.SuccessRcptCmd)
  390. }
  391. }
  392. case strings.Index(cmd, "RSET") == 0:
  393. client.resetTransaction()
  394. client.sendResponse(response.Canned.SuccessResetCmd)
  395. case strings.Index(cmd, "VRFY") == 0:
  396. client.sendResponse(response.Canned.SuccessVerifyCmd)
  397. case strings.Index(cmd, "NOOP") == 0:
  398. client.sendResponse(response.Canned.SuccessNoopCmd)
  399. case strings.Index(cmd, "QUIT") == 0:
  400. client.sendResponse(response.Canned.SuccessQuitCmd)
  401. client.kill()
  402. case strings.Index(cmd, "DATA") == 0:
  403. if client.MailFrom.IsEmpty() {
  404. client.sendResponse(response.Canned.FailNoSenderDataCmd)
  405. break
  406. }
  407. if len(client.RcptTo) == 0 {
  408. client.sendResponse(response.Canned.FailNoRecipientsDataCmd)
  409. break
  410. }
  411. client.sendResponse(response.Canned.SuccessDataCmd)
  412. client.state = ClientData
  413. case sc.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
  414. client.sendResponse(response.Canned.SuccessStartTLSCmd)
  415. client.state = ClientStartTLS
  416. default:
  417. client.errors++
  418. if client.errors >= MaxUnrecognizedCommands {
  419. client.sendResponse(response.Canned.FailMaxUnrecognizedCmd)
  420. client.kill()
  421. } else {
  422. client.sendResponse(response.Canned.FailUnrecognizedCmd)
  423. }
  424. }
  425. case ClientData:
  426. // intentionally placed the limit 1MB above so that reading does not return with an error
  427. // if the client goes a little over. Anything above will err
  428. client.bufin.setLimit(int64(sc.MaxSize) + 1024000) // This a hard limit.
  429. n, err := client.Data.ReadFrom(client.smtpReader.DotReader())
  430. if n > sc.MaxSize {
  431. err = fmt.Errorf("Maximum DATA size exceeded (%d)", sc.MaxSize)
  432. }
  433. if err != nil {
  434. if err == LineLimitExceeded {
  435. client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, LineLimitExceeded.Error())
  436. client.kill()
  437. } else if err == MessageSizeExceeded {
  438. client.sendResponse(response.Canned.FailMessageSizeExceeded, MessageSizeExceeded.Error())
  439. client.kill()
  440. } else {
  441. client.sendResponse(response.Canned.FailReadErrorDataCmd, err.Error())
  442. client.kill()
  443. }
  444. server.log.WithError(err).Warn("Error reading data")
  445. break
  446. }
  447. res := server.backend.Process(client.Envelope)
  448. if res.Code() < 300 {
  449. client.messagesSent++
  450. server.log.WithFields(map[string]interface{}{
  451. "helo": client.Helo,
  452. "remoteAddress": client.RemoteAddress,
  453. "success": true,
  454. }).Info("Received message")
  455. }
  456. client.sendResponse(res.String())
  457. client.state = ClientCmd
  458. if server.isShuttingDown() {
  459. client.state = ClientShutdown
  460. }
  461. client.resetTransaction()
  462. case ClientStartTLS:
  463. if !client.TLS && sc.StartTLSOn {
  464. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  465. if !ok {
  466. server.mainlog.Error("Failed to load *tls.Config")
  467. } else if err := client.upgradeToTLS(tlsConfig); err == nil {
  468. advertiseTLS = ""
  469. client.resetTransaction()
  470. } else {
  471. server.log.WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteAddress)
  472. // Don't disconnect, let the client decide if it wants to continue
  473. }
  474. }
  475. // change to command state
  476. client.state = ClientCmd
  477. case ClientShutdown:
  478. // shutdown state
  479. client.sendResponse(response.Canned.ErrorShutdown)
  480. client.kill()
  481. }
  482. if client.bufout.Buffered() > 0 {
  483. if server.log.IsDebug() {
  484. server.log.Debugf("Writing response to client: \n%s", client.response.String())
  485. }
  486. err := server.flushResponse(client)
  487. if err != nil {
  488. server.log.WithError(err).Debug("Error writing response")
  489. return
  490. }
  491. }
  492. }
  493. }