smtpd.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. package main
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "os"
  12. "strconv"
  13. "strings"
  14. "time"
  15. )
  16. const commandMaxLength = 1024
  17. type Client struct {
  18. state int
  19. helo string
  20. mail_from string
  21. rcpt_to string
  22. response string
  23. address string
  24. data string
  25. subject string
  26. hash string
  27. time int64
  28. tls_on bool
  29. conn net.Conn
  30. bufin *smtpBufferedReader
  31. bufout *bufio.Writer
  32. kill_time int64
  33. errors int
  34. clientId int64
  35. savedNotify chan int
  36. }
  37. type SmtpdServer struct {
  38. tlsConfig *tls.Config
  39. max_size int // max email DATA size
  40. timeout time.Duration
  41. allowedHosts map[string]bool
  42. sem chan int // currently active client list
  43. Config ServerConfig
  44. logger *log.Logger
  45. }
  46. func (server *SmtpdServer) logln(level int, s string) {
  47. if mainConfig.Verbose {
  48. fmt.Println(s)
  49. }
  50. // fatal errors
  51. if level == 2 {
  52. server.logger.Fatalf(s)
  53. }
  54. // warnings
  55. if level == 1 && len(server.Config.Log_file) > 0 {
  56. server.logger.Println(s)
  57. }
  58. }
  59. func (server *SmtpdServer) openLog() {
  60. server.logger = log.New(&bytes.Buffer{}, "", log.Lshortfile)
  61. // custom log file
  62. if len(server.Config.Log_file) > 0 {
  63. logfile, err := os.OpenFile(
  64. server.Config.Log_file,
  65. os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_SYNC, 0600)
  66. if err != nil {
  67. server.logln(1, fmt.Sprintf("Unable to open log file [%s]: %s ", server.Config.Log_file, err))
  68. }
  69. server.logger.SetOutput(logfile)
  70. }
  71. }
  72. // Upgrades the connection to TLS
  73. // Sets up buffers with the upgraded connection
  74. func (server *SmtpdServer) upgradeToTls(client *Client) bool {
  75. var tlsConn *tls.Conn
  76. tlsConn = tls.Server(client.conn, server.tlsConfig)
  77. err := tlsConn.Handshake()
  78. if err == nil {
  79. client.conn = net.Conn(tlsConn)
  80. client.bufin = newSmtpBufferedReader(client.conn)
  81. client.bufout = bufio.NewWriter(client.conn)
  82. client.tls_on = true
  83. return true
  84. } else {
  85. server.logln(1, fmt.Sprintf("Could not TLS handshake:%v", err))
  86. return false
  87. }
  88. }
  89. func (server *SmtpdServer) handleClient(client *Client) {
  90. defer server.closeClient(client)
  91. advertiseTls := "250-STARTTLS\r\n"
  92. if server.Config.Tls_always_on {
  93. if server.upgradeToTls(client) {
  94. advertiseTls = ""
  95. }
  96. }
  97. greeting := "220 " + server.Config.Host_name +
  98. " SMTP Guerrilla-SMTPd #" +
  99. strconv.FormatInt(client.clientId, 10) +
  100. " (" + strconv.Itoa(len(server.sem)) + ") " + time.Now().Format(time.RFC1123Z)
  101. if !server.Config.Start_tls_on {
  102. // STARTTLS turned off
  103. advertiseTls = ""
  104. }
  105. for i := 0; i < 100; i++ {
  106. switch client.state {
  107. case 0:
  108. responseAdd(client, greeting)
  109. client.state = 1
  110. case 1:
  111. client.bufin.setLimit(commandMaxLength)
  112. input, err := server.readSmtp(client)
  113. if err != nil {
  114. if err == io.EOF {
  115. // client closed the connection already
  116. server.logln(0, fmt.Sprintf("%s: %v", client.address, err))
  117. return
  118. } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
  119. // too slow, timeout
  120. server.logln(0, fmt.Sprintf("%s: %v", client.address, err))
  121. return
  122. } else if err == INPUT_LIMIT_EXCEEDED {
  123. responseAdd(client, "500 Line too long.")
  124. // kill it so that another one can connect
  125. killClient(client)
  126. }
  127. server.logln(1, fmt.Sprintf("Read error: %v", err))
  128. break
  129. }
  130. input = strings.Trim(input, " \n\r")
  131. bound := len(input)
  132. if bound > 16 {
  133. bound = 16
  134. }
  135. cmd := strings.ToUpper(input[0:bound])
  136. switch {
  137. case strings.Index(cmd, "HELO") == 0:
  138. if len(input) > 5 {
  139. client.helo = input[5:]
  140. }
  141. responseAdd(client, "250 "+server.Config.Host_name+" Hello ")
  142. case strings.Index(cmd, "EHLO") == 0:
  143. if len(input) > 5 {
  144. client.helo = input[5:]
  145. }
  146. responseAdd(client, "250-"+server.Config.Host_name+
  147. " Hello "+client.helo+"["+client.address+"]"+"\r\n"+
  148. "250-SIZE "+strconv.Itoa(server.Config.Max_size)+"\r\n"+
  149. "250-PIPELINING \r\n"+
  150. advertiseTls+"250 HELP")
  151. case strings.Index(cmd, "HELP") == 0:
  152. responseAdd(client, "250 Help! I need somebody...")
  153. case strings.Index(cmd, "MAIL FROM:") == 0:
  154. if len(input) > 10 {
  155. client.mail_from = input[10:]
  156. }
  157. responseAdd(client, "250 Ok")
  158. case strings.Index(cmd, "XCLIENT") == 0:
  159. // Nginx sends this
  160. // XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]
  161. client.address = input[13:]
  162. client.address = client.address[0:strings.Index(client.address, " ")]
  163. fmt.Println("client address:[" + client.address + "]")
  164. responseAdd(client, "250 OK")
  165. case strings.Index(cmd, "RCPT TO:") == 0:
  166. if len(input) > 8 {
  167. client.rcpt_to = input[8:]
  168. }
  169. responseAdd(client, "250 Accepted")
  170. case strings.Index(cmd, "NOOP") == 0:
  171. responseAdd(client, "250 OK")
  172. case strings.Index(cmd, "RSET") == 0:
  173. client.mail_from = ""
  174. client.rcpt_to = ""
  175. responseAdd(client, "250 OK")
  176. case strings.Index(cmd, "DATA") == 0:
  177. responseAdd(client, "354 Enter message, ending with \".\" on a line by itself")
  178. client.state = 2
  179. case (strings.Index(cmd, "STARTTLS") == 0) &&
  180. !client.tls_on &&
  181. server.Config.Start_tls_on:
  182. responseAdd(client, "220 Ready to start TLS")
  183. // go to start TLS state
  184. client.state = 3
  185. case strings.Index(cmd, "QUIT") == 0:
  186. responseAdd(client, "221 Bye")
  187. killClient(client)
  188. default:
  189. responseAdd(client, "500 unrecognized command: "+cmd)
  190. client.errors++
  191. if client.errors > 3 {
  192. responseAdd(client, "500 Too many unrecognized commands")
  193. killClient(client)
  194. }
  195. }
  196. case 2:
  197. var err error
  198. client.bufin.setLimit(int64(server.Config.Max_size) + 1024000) // This is a hard limit.
  199. client.data, err = server.readSmtp(client)
  200. if err == nil {
  201. if _, _, mailErr := validateEmailData(client); mailErr == nil {
  202. // to do: timeout when adding to SaveMailChan
  203. // place on the channel so that one of the save mail workers can pick it up
  204. SaveMailChan <- &savePayload{client: client, server: server}
  205. // wait for the save to complete
  206. // or timeout
  207. select {
  208. case status := <-client.savedNotify:
  209. if status == 1 {
  210. responseAdd(client, "250 OK : queued as "+client.hash)
  211. } else {
  212. responseAdd(client, "554 Error: transaction failed, blame it on the weather")
  213. }
  214. case <-time.After(time.Second * 30):
  215. fmt.Println("timeout 1")
  216. responseAdd(client, "554 Error: transaction timeout")
  217. }
  218. } else {
  219. responseAdd(client, "550 Error: "+mailErr.Error())
  220. }
  221. } else {
  222. if (err == INPUT_LIMIT_EXCEEDED) {
  223. // hard limit reached, end to make room for other clients
  224. responseAdd(client, "550 Error: DATA limit exceeded by more than a megabyte!")
  225. killClient(client)
  226. } else {
  227. responseAdd(client, "550 Error: "+err.Error())
  228. }
  229. server.logln(1, fmt.Sprintf("DATA read error: %v", err))
  230. }
  231. client.state = 1
  232. case 3:
  233. // upgrade to TLS
  234. if server.upgradeToTls(client) {
  235. advertiseTls = ""
  236. client.state = 1
  237. }
  238. }
  239. // Send a response back to the client
  240. err := server.responseWrite(client)
  241. if err != nil {
  242. if err == io.EOF {
  243. // client closed the connection already
  244. return
  245. }
  246. if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
  247. // too slow, timeout
  248. return
  249. }
  250. }
  251. if client.kill_time > 1 {
  252. return
  253. }
  254. }
  255. }
  256. // add a response on the response buffer
  257. func responseAdd(client *Client, line string) {
  258. client.response = line + "\r\n"
  259. }
  260. func (server *SmtpdServer) closeClient(client *Client) {
  261. client.conn.Close()
  262. <-server.sem // Done; enable next client to run.
  263. }
  264. func killClient(client *Client) {
  265. client.kill_time = time.Now().Unix()
  266. }
  267. var INPUT_LIMIT_EXCEEDED = errors.New("Line too long") // 500 Line too long.
  268. // we need to adjust the limit, so we embed io.LimitedReader
  269. type adjustableLimitedReader struct {
  270. R *io.LimitedReader
  271. }
  272. // bolt this on so we can adjust the limit
  273. func (alr *adjustableLimitedReader) setLimit(n int64) {
  274. alr.R.N = n
  275. }
  276. // this just delegates to the underlying reader in order to satisfy the Reader interface
  277. // Since the vanilla limited reader returns io.EOF when the limit is reached, we need a more specific
  278. // error so that we can distinguish when a limit is reached
  279. func (alr *adjustableLimitedReader) Read(p []byte) (n int, err error) {
  280. n, err = alr.R.Read(p)
  281. if err == io.EOF && alr.R.N <= 0 {
  282. // return our custom error since std lib returns EOF
  283. err = INPUT_LIMIT_EXCEEDED
  284. }
  285. return
  286. }
  287. // allocate a new adjustableLimitedReader
  288. func newAdjustableLimitedReader(r io.Reader, n int64) *adjustableLimitedReader {
  289. lr := &io.LimitedReader{R: r, N: n}
  290. return &adjustableLimitedReader{lr}
  291. }
  292. // This is a bufio.Reader what will use our adjustable limit reader
  293. // We 'extend' buffio to have the limited reader feature
  294. type smtpBufferedReader struct {
  295. *bufio.Reader
  296. alr *adjustableLimitedReader
  297. }
  298. // delegate to the adjustable limited reader
  299. func (sbr *smtpBufferedReader) setLimit(n int64) {
  300. sbr.alr.setLimit(n)
  301. }
  302. // allocate a new smtpBufferedReader
  303. func newSmtpBufferedReader(rd io.Reader) *smtpBufferedReader {
  304. alr := newAdjustableLimitedReader(rd, commandMaxLength)
  305. s := &smtpBufferedReader{bufio.NewReader(alr), alr}
  306. return s
  307. }
  308. // Reads from the smtpBufferedReader, can be in command state or data state.
  309. func (server *SmtpdServer) readSmtp(client *Client) (input string, err error) {
  310. var reply string
  311. // Command state terminator by default
  312. suffix := "\r\n"
  313. if client.state == 2 {
  314. // DATA state ends with a dot on a line by itself
  315. suffix = "\r\n.\r\n"
  316. }
  317. for err == nil {
  318. client.conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
  319. reply, err = client.bufin.ReadString('\n')
  320. if reply != "" {
  321. input = input + reply
  322. if len(input) > server.Config.Max_size {
  323. err = errors.New("Maximum DATA size exceeded (" + strconv.Itoa(server.Config.Max_size) + ")")
  324. return input, err
  325. }
  326. if client.state == 2 {
  327. // Extract the subject while we are at it.
  328. scanSubject(client, reply)
  329. }
  330. }
  331. if err != nil {
  332. break
  333. }
  334. if strings.HasSuffix(input, suffix) {
  335. break
  336. }
  337. }
  338. return input, err
  339. }
  340. // Scan the data part for a Subject line. Can be a multi-line
  341. func scanSubject(client *Client, reply string) {
  342. if client.subject == "" && (len(reply) > 8) {
  343. test := strings.ToUpper(reply[0:9])
  344. if i := strings.Index(test, "SUBJECT: "); i == 0 {
  345. // first line with \r\n
  346. client.subject = reply[9:]
  347. }
  348. } else if strings.HasSuffix(client.subject, "\r\n") {
  349. // chop off the \r\n
  350. client.subject = client.subject[0 : len(client.subject)-2]
  351. if (strings.HasPrefix(reply, " ")) || (strings.HasPrefix(reply, "\t")) {
  352. // subject is multi-line
  353. client.subject = client.subject + reply[1:]
  354. }
  355. }
  356. }
  357. func (server *SmtpdServer) responseWrite(client *Client) (err error) {
  358. var size int
  359. client.conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
  360. size, err = client.bufout.WriteString(client.response)
  361. client.bufout.Flush()
  362. client.response = client.response[size:]
  363. return err
  364. }