server.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. package guerrilla
  2. import (
  3. "crypto/rand"
  4. "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net"
  8. "strings"
  9. "time"
  10. "runtime"
  11. log "github.com/Sirupsen/logrus"
  12. "sync"
  13. "sync/atomic"
  14. "github.com/flashmob/go-guerrilla/backends"
  15. )
  16. const (
  17. CommandVerbMaxLength = 16
  18. CommandLineMaxLength = 1024
  19. // Number of allowed unrecognized commands before we terminate the connection
  20. MaxUnrecognizedCommands = 5
  21. // The maximum total length of a reverse-path or forward-path is 256
  22. RFC2821LimitPath = 256
  23. // The maximum total length of a user name or other local-part is 64
  24. RFC2832LimitLocalPart = 64
  25. //The maximum total length of a domain name or number is 255
  26. RFC2821LimitDomain = 255
  27. // The minimum total number of recipients that must be buffered is 100
  28. RFC2821LimitRecipients = 100
  29. )
  30. const (
  31. // server has just been created
  32. ServerStateNew = iota
  33. // Server has just been stopped
  34. ServerStateStopped
  35. // Server has been started and is running
  36. ServerStateRunning
  37. // Server could not start due to an error
  38. ServerStateStartError
  39. )
  40. // Server listens for SMTP clients on the port specified in its config
  41. type server struct {
  42. //mainConfigStore atomic.Value // stores guerrilla.Config
  43. configStore atomic.Value // stores guerrilla.ServerConfig
  44. //config *ServerConfig
  45. backend backends.Backend
  46. tlsConfig *tls.Config
  47. tlsConfigStore atomic.Value
  48. timeout atomic.Value // stores time.Duration
  49. listenInterface string
  50. clientPool *Pool
  51. wg sync.WaitGroup // for waiting to shutdown
  52. listener net.Listener
  53. closedListener chan (bool)
  54. hosts allowedHosts // stores map[string]bool for faster lookup
  55. state int
  56. }
  57. type allowedHosts struct {
  58. table map[string]bool // host lookup table
  59. m sync.Mutex // guard access to the map
  60. }
  61. // Creates and returns a new ready-to-run Server from a configuration
  62. func newServer(sc *ServerConfig, b backends.Backend) (*server, error) {
  63. server := &server{
  64. backend: b,
  65. clientPool: NewPool(sc.MaxClients),
  66. closedListener: make(chan (bool), 1),
  67. listenInterface: sc.ListenInterface,
  68. state: ServerStateNew,
  69. }
  70. server.setConfig(sc)
  71. server.setTimeout(sc.Timeout)
  72. if err := server.configureSSL(); err != nil {
  73. return server, err
  74. }
  75. return server, nil
  76. }
  77. func (s *server) configureSSL() error {
  78. sConfig := s.configStore.Load().(ServerConfig)
  79. if sConfig.TLSAlwaysOn || sConfig.StartTLSOn {
  80. cert, err := tls.LoadX509KeyPair(sConfig.PublicKeyFile, sConfig.PrivateKeyFile)
  81. if err != nil {
  82. return fmt.Errorf("error while loading the certificate: %s", err)
  83. }
  84. tlsConfig := &tls.Config{
  85. Certificates: []tls.Certificate{cert},
  86. ClientAuth: tls.VerifyClientCertIfGiven,
  87. ServerName: sConfig.Hostname,
  88. }
  89. tlsConfig.Rand = rand.Reader
  90. s.tlsConfigStore.Store(tlsConfig)
  91. }
  92. return nil
  93. }
  94. // Set the timeout for the server and all clients
  95. func (server *server) setTimeout(seconds int) {
  96. duration := time.Duration(int64(seconds))
  97. server.clientPool.SetTimeout(duration)
  98. server.timeout.Store(duration)
  99. }
  100. // goroutine safe config store
  101. func (server *server) setConfig(sc *ServerConfig) {
  102. server.configStore.Store(*sc)
  103. }
  104. // goroutine safe
  105. func (server *server) isEnabled() bool {
  106. sc := server.configStore.Load().(ServerConfig)
  107. return sc.IsEnabled
  108. }
  109. // Set the allowed hosts for the server
  110. func (server *server) setAllowedHosts(allowedHosts []string) {
  111. defer server.hosts.m.Unlock()
  112. server.hosts.m.Lock()
  113. server.hosts.table = make(map[string]bool, len(allowedHosts))
  114. for _, h := range allowedHosts {
  115. server.hosts.table[strings.ToLower(h)] = true
  116. }
  117. }
  118. // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
  119. func (server *server) Start(startWG *sync.WaitGroup) error {
  120. var clientID uint64
  121. clientID = 0
  122. listener, err := net.Listen("tcp", server.listenInterface)
  123. server.listener = listener
  124. if err != nil {
  125. startWG.Done() // don't wait for me
  126. server.state = ServerStateStartError
  127. return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
  128. }
  129. log.Infof("Listening on TCP %s", server.listenInterface)
  130. server.state = ServerStateRunning
  131. startWG.Done() // start successful, don't wait for me
  132. for {
  133. log.Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
  134. conn, err := listener.Accept()
  135. clientID++
  136. if err != nil {
  137. if e, ok := err.(net.Error); ok && !e.Temporary() {
  138. log.Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
  139. // the listener has been closed, wait for clients to exit
  140. log.Infof("shutting down pool [%s]", server.listenInterface)
  141. server.clientPool.ShutdownState()
  142. server.clientPool.ShutdownWait()
  143. server.state = ServerStateStopped
  144. server.closedListener <- true
  145. return nil
  146. }
  147. log.WithError(err).Info("Temporary error accepting client")
  148. continue
  149. }
  150. go func(p Poolable, borrow_err error) {
  151. c := p.(*client)
  152. if borrow_err == nil {
  153. server.handleClient(c)
  154. server.clientPool.Return(c)
  155. } else {
  156. log.WithError(borrow_err).Info("couldn't borrow a new client")
  157. // we could not get a client, so close the connection.
  158. conn.Close()
  159. }
  160. // intentionally placed Borrow in args so that it's called in the
  161. // same main goroutine.
  162. }(server.clientPool.Borrow(conn, clientID))
  163. }
  164. }
  165. func (server *server) Shutdown() {
  166. if server.listener != nil {
  167. // This will cause Start function to return, by causing an error on listener.Accept
  168. server.listener.Close()
  169. // wait for the listener to listener.Accept
  170. <-server.closedListener
  171. // At this point Start will exit and close down the pool
  172. } else {
  173. server.clientPool.ShutdownState()
  174. // listener already closed, wait for clients to exit
  175. server.clientPool.ShutdownWait()
  176. server.state = ServerStateStopped
  177. }
  178. }
  179. func (server *server) GetActiveClientsCount() int {
  180. return server.clientPool.GetActiveClientsCount()
  181. }
  182. // Verifies that the host is a valid recipient.
  183. func (server *server) allowsHost(host string) bool {
  184. defer server.hosts.m.Unlock()
  185. server.hosts.m.Lock()
  186. if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
  187. return true
  188. }
  189. return false
  190. }
  191. // Reads from the client until a terminating sequence is encountered,
  192. // or until a timeout occurs.
  193. func (server *server) read(client *client, maxSize int64) (string, error) {
  194. var input, reply string
  195. var err error
  196. // In command state, stop reading at line breaks
  197. suffix := "\r\n"
  198. if client.state == ClientData {
  199. // In data state, stop reading at solo periods
  200. suffix = "\r\n.\r\n"
  201. }
  202. for {
  203. client.setTimeout(server.timeout.Load().(time.Duration))
  204. reply, err = client.bufin.ReadString('\n')
  205. input = input + reply
  206. if err == nil && client.state == ClientData {
  207. if reply != "" {
  208. // Extract the subject while we're at it
  209. client.scanSubject(reply)
  210. }
  211. if int64(len(input)) > maxSize {
  212. return input, fmt.Errorf("Maximum DATA size exceeded (%d)", maxSize)
  213. }
  214. }
  215. if err != nil {
  216. break
  217. }
  218. if strings.HasSuffix(input, suffix) {
  219. // discard the suffix and stop reading
  220. input = input[0 : len(input)-len(suffix)]
  221. break
  222. }
  223. }
  224. return input, err
  225. }
  226. // Writes a response to the client.
  227. func (server *server) writeResponse(client *client) error {
  228. client.setTimeout(server.timeout.Load().(time.Duration))
  229. size, err := client.bufout.WriteString(client.response)
  230. if err != nil {
  231. return err
  232. }
  233. err = client.bufout.Flush()
  234. if err != nil {
  235. return err
  236. }
  237. client.response = client.response[size:]
  238. return nil
  239. }
  240. func (server *server) isShuttingDown() bool {
  241. return server.clientPool.IsShuttingDown()
  242. }
  243. // Handles an entire client SMTP exchange
  244. func (server *server) handleClient(client *client) {
  245. defer client.closeConn()
  246. sc := server.configStore.Load().(ServerConfig)
  247. log.Infof("Handle client [%s], id: %d", client.RemoteAddress, client.ID)
  248. // Initial greeting
  249. greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
  250. sc.Hostname, Version, client.ID,
  251. server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
  252. helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
  253. // ehlo is a multi-line reply and need additional \r\n at the end
  254. ehlo := fmt.Sprintf("250-%s Hello\r\n", sc.Hostname)
  255. // Extended feature advertisements
  256. messageSize := fmt.Sprintf("250-SIZE %d\r\n", sc.MaxSize)
  257. pipelining := "250-PIPELINING\r\n"
  258. advertiseTLS := "250-STARTTLS\r\n"
  259. // the last line doesn't need \r\n since string will be printed as a new line
  260. help := "250 HELP"
  261. if sc.TLSAlwaysOn {
  262. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  263. if ok && client.upgradeToTLS(tlsConfig) {
  264. advertiseTLS = ""
  265. } else {
  266. // server requires TLS, but can't handshake
  267. client.kill()
  268. }
  269. }
  270. if !sc.StartTLSOn {
  271. // STARTTLS turned off, don't advertise it
  272. advertiseTLS = ""
  273. }
  274. for client.isAlive() {
  275. switch client.state {
  276. case ClientGreeting:
  277. client.responseAdd(greeting)
  278. client.state = ClientCmd
  279. case ClientCmd:
  280. client.bufin.setLimit(CommandLineMaxLength)
  281. input, err := server.read(client, sc.MaxSize)
  282. log.Debugf("Client sent: %s", input)
  283. if err == io.EOF {
  284. log.WithError(err).Warnf("Client closed the connection: %s", client.RemoteAddress)
  285. return
  286. } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  287. log.WithError(err).Warnf("Timeout: %s", client.RemoteAddress)
  288. return
  289. } else if err == LineLimitExceeded {
  290. client.responseAdd("500 Line too long.")
  291. client.kill()
  292. break
  293. } else if err != nil {
  294. log.WithError(err).Warnf("Read error: %s", client.RemoteAddress)
  295. client.kill()
  296. break
  297. }
  298. if server.isShuttingDown() {
  299. client.state = ClientShutdown
  300. continue
  301. }
  302. input = strings.Trim(input, " \r\n")
  303. cmdLen := len(input)
  304. if cmdLen > CommandVerbMaxLength {
  305. cmdLen = CommandVerbMaxLength
  306. }
  307. cmd := strings.ToUpper(input[:cmdLen])
  308. switch {
  309. case strings.Index(cmd, "HELO") == 0:
  310. client.Helo = strings.Trim(input[4:], " ")
  311. client.resetTransaction()
  312. client.responseAdd(helo)
  313. case strings.Index(cmd, "EHLO") == 0:
  314. client.Helo = strings.Trim(input[4:], " ")
  315. client.resetTransaction()
  316. client.responseAdd(ehlo + messageSize + pipelining + advertiseTLS + help)
  317. case strings.Index(cmd, "HELP") == 0:
  318. client.responseAdd("214 OK\r\n" + messageSize + pipelining + advertiseTLS + help)
  319. case strings.Index(cmd, "MAIL FROM:") == 0:
  320. if client.isInTransaction() {
  321. client.responseAdd("503 Error: nested MAIL command")
  322. break
  323. }
  324. from, err := extractEmail(input[10:])
  325. if err != nil {
  326. client.responseAdd(err.Error())
  327. } else {
  328. client.MailFrom = from
  329. client.responseAdd("250 OK")
  330. }
  331. case strings.Index(cmd, "RCPT TO:") == 0:
  332. if len(client.RcptTo) > RFC2821LimitRecipients {
  333. client.responseAdd("452 Too many recipients")
  334. break
  335. }
  336. to, err := extractEmail(input[8:])
  337. if err != nil {
  338. client.responseAdd(err.Error())
  339. } else {
  340. if !server.allowsHost(to.Host) {
  341. client.responseAdd("454 Error: Relay access denied: " + to.Host)
  342. } else {
  343. client.RcptTo = append(client.RcptTo, *to)
  344. client.responseAdd("250 OK")
  345. }
  346. }
  347. case strings.Index(cmd, "RSET") == 0:
  348. client.resetTransaction()
  349. client.responseAdd("250 OK")
  350. case strings.Index(cmd, "VRFY") == 0:
  351. client.responseAdd("252 Cannot verify user")
  352. case strings.Index(cmd, "NOOP") == 0:
  353. client.responseAdd("250 OK")
  354. case strings.Index(cmd, "QUIT") == 0:
  355. client.responseAdd("221 Bye")
  356. client.kill()
  357. case strings.Index(cmd, "DATA") == 0:
  358. if client.MailFrom.IsEmpty() {
  359. client.responseAdd("503 Error: No sender")
  360. break
  361. }
  362. if len(client.RcptTo) == 0 {
  363. client.responseAdd("503 Error: No recipients")
  364. break
  365. }
  366. client.responseAdd("354 Enter message, ending with '.' on a line by itself")
  367. client.state = ClientData
  368. case sc.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
  369. client.responseAdd("220 Ready to start TLS")
  370. client.state = ClientStartTLS
  371. default:
  372. client.responseAdd("500 Unrecognized command: " + cmd)
  373. client.errors++
  374. if client.errors > MaxUnrecognizedCommands {
  375. client.responseAdd("554 Too many unrecognized commands")
  376. client.kill()
  377. }
  378. }
  379. case ClientData:
  380. var err error
  381. // intentionally placed the limit 1MB above so that reading does not return with an error
  382. // if the client goes a little over. Anything above will err
  383. client.bufin.setLimit(int64(sc.MaxSize) + 1024000) // This a hard limit.
  384. client.Data, err = server.read(client, sc.MaxSize)
  385. if err != nil {
  386. if err == LineLimitExceeded {
  387. client.responseAdd("550 Error: " + LineLimitExceeded.Error())
  388. client.kill()
  389. } else if err == MessageSizeExceeded {
  390. client.responseAdd("550 Error: " + MessageSizeExceeded.Error())
  391. client.kill()
  392. } else {
  393. client.kill()
  394. client.responseAdd("451 Error: " + err.Error())
  395. }
  396. log.WithError(err).Warn("Error reading data")
  397. break
  398. }
  399. res := server.backend.Process(client.Envelope)
  400. if res.Code() < 300 {
  401. client.messagesSent++
  402. }
  403. client.responseAdd(res.String())
  404. client.state = ClientCmd
  405. if server.isShuttingDown() {
  406. client.state = ClientShutdown
  407. }
  408. client.resetTransaction()
  409. case ClientStartTLS:
  410. if !client.TLS && sc.StartTLSOn {
  411. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  412. if ok && client.upgradeToTLS(tlsConfig) {
  413. advertiseTLS = ""
  414. client.resetTransaction()
  415. }
  416. }
  417. // change to command state
  418. client.state = ClientCmd
  419. case ClientShutdown:
  420. // shutdown state
  421. client.responseAdd("421 Server is shutting down. Please try again later. Sayonara!")
  422. client.kill()
  423. }
  424. if len(client.response) > 0 {
  425. log.Debugf("Writing response to client: \n%s", client.response)
  426. err := server.writeResponse(client)
  427. if err != nil {
  428. log.WithError(err).Debug("Error writing response")
  429. return
  430. }
  431. }
  432. }
  433. }