querysrv.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. // Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
  2. package main
  3. import (
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "encoding/json"
  8. "encoding/pem"
  9. "fmt"
  10. "log"
  11. "math/rand"
  12. "net"
  13. "net/http"
  14. "net/url"
  15. "strconv"
  16. "sync"
  17. "time"
  18. "github.com/golang/groupcache/lru"
  19. "github.com/juju/ratelimit"
  20. "github.com/syncthing/syncthing/lib/protocol"
  21. "golang.org/x/net/context"
  22. )
  23. type querysrv struct {
  24. addr string
  25. db *sql.DB
  26. prep map[string]*sql.Stmt
  27. limiter *safeCache
  28. cert tls.Certificate
  29. listener net.Listener
  30. }
  // announcement is the JSON document exchanged with clients: decoded from
  // the body of a POST and encoded as the answer to a successful GET.
  type announcement struct {
  	Seen      time.Time `json:"seen"`      // last-seen time as stored for the device
  	Addresses []string  `json:"addresses"` // address URLs, e.g. "tcp://1.2.3.4:22000"
  }
  35. type safeCache struct {
  36. *lru.Cache
  37. mut sync.Mutex
  38. }
  39. func (s *safeCache) Get(key string) (val interface{}, ok bool) {
  40. s.mut.Lock()
  41. val, ok = s.Cache.Get(key)
  42. s.mut.Unlock()
  43. return
  44. }
  45. func (s *safeCache) Add(key string, val interface{}) {
  46. s.mut.Lock()
  47. s.Cache.Add(key, val)
  48. s.mut.Unlock()
  49. }
  // requestID tags every log line that belongs to one HTTP request.
  type requestID int64

  // String formats the ID as 16 zero-padded lowercase hex digits, so the
  // non-negative IDs produced by handler line up in the log.
  func (i requestID) String() string {
  	n := int64(i)
  	return fmt.Sprintf("%016x", n)
  }
  54. func negCacheFor(lastSeen time.Time) int {
  55. since := time.Since(lastSeen).Seconds()
  56. if since >= maxDeviceAge {
  57. return maxNegCache
  58. }
  59. if since < 0 {
  60. // That's weird
  61. return minNegCache
  62. }
  63. // Return a value linearly scaled from minNegCache (at zero seconds ago)
  64. // to maxNegCache (at maxDeviceAge seconds ago).
  65. r := since / maxDeviceAge
  66. return int(minNegCache + r*(maxNegCache-minNegCache))
  67. }
  68. func (s *querysrv) Serve() {
  69. s.limiter = &safeCache{
  70. Cache: lru.New(lruSize),
  71. }
  72. if useHTTP {
  73. listener, err := net.Listen("tcp", s.addr)
  74. if err != nil {
  75. log.Println("Listen:", err)
  76. return
  77. }
  78. s.listener = listener
  79. } else {
  80. tlsCfg := &tls.Config{
  81. Certificates: []tls.Certificate{s.cert},
  82. ClientAuth: tls.RequestClientCert,
  83. SessionTicketsDisabled: true,
  84. MinVersion: tls.VersionTLS12,
  85. CipherSuites: []uint16{
  86. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  87. tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  88. tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
  89. tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
  90. tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
  91. tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
  92. },
  93. }
  94. tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg)
  95. if err != nil {
  96. log.Println("Listen:", err)
  97. return
  98. }
  99. s.listener = tlsListener
  100. }
  101. http.HandleFunc("/v2/", s.handler)
  102. http.HandleFunc("/ping", handlePing)
  103. srv := &http.Server{
  104. ReadTimeout: 5 * time.Second,
  105. WriteTimeout: 5 * time.Second,
  106. MaxHeaderBytes: 1 << 10,
  107. }
  108. if err := srv.Serve(s.listener); err != nil {
  109. log.Println("Serve:", err)
  110. }
  111. }
  var (
  	// topCtx is the root context from which each request's context is
  	// derived in handler.
  	topCtx = context.Background()
  )
  113. func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) {
  114. reqID := requestID(rand.Int63())
  115. ctx := context.WithValue(topCtx, "id", reqID)
  116. if debug {
  117. log.Println(reqID, req.Method, req.URL)
  118. }
  119. t0 := time.Now()
  120. defer func() {
  121. diff := time.Since(t0)
  122. var comment string
  123. if diff > time.Second {
  124. comment = "(very slow request)"
  125. } else if diff > 100*time.Millisecond {
  126. comment = "(slow request)"
  127. }
  128. if comment != "" || debug {
  129. log.Println(reqID, req.Method, req.URL, "completed in", diff, comment)
  130. }
  131. }()
  132. var remoteIP net.IP
  133. if useHTTP {
  134. remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
  135. } else {
  136. addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
  137. if err != nil {
  138. log.Println("remoteAddr:", err)
  139. http.Error(w, "Internal Server Error", http.StatusInternalServerError)
  140. return
  141. }
  142. remoteIP = addr.IP
  143. }
  144. if s.limit(remoteIP) {
  145. if debug {
  146. log.Println(remoteIP, "is limited")
  147. }
  148. w.Header().Set("Retry-After", "60")
  149. http.Error(w, "Too Many Requests", 429)
  150. return
  151. }
  152. switch req.Method {
  153. case "GET":
  154. s.handleGET(ctx, w, req)
  155. case "POST":
  156. s.handlePOST(ctx, remoteIP, w, req)
  157. default:
  158. globalStats.Error()
  159. http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
  160. }
  161. }
  162. func (s *querysrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) {
  163. reqID := ctx.Value("id").(requestID)
  164. deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
  165. if err != nil {
  166. if debug {
  167. log.Println(reqID, "bad device param")
  168. }
  169. globalStats.Error()
  170. http.Error(w, "Bad Request", http.StatusBadRequest)
  171. return
  172. }
  173. var ann announcement
  174. ann.Seen, err = s.getDeviceSeen(deviceID)
  175. negCache := strconv.Itoa(negCacheFor(ann.Seen))
  176. w.Header().Set("Retry-After", negCache)
  177. w.Header().Set("Cache-Control", "public, max-age="+negCache)
  178. if err != nil {
  179. // The device is not in the database.
  180. globalStats.Query()
  181. http.Error(w, "Not Found", http.StatusNotFound)
  182. return
  183. }
  184. t0 := time.Now()
  185. ann.Addresses, err = s.getAddresses(ctx, deviceID)
  186. if err != nil {
  187. log.Println(reqID, "getAddresses:", err)
  188. globalStats.Error()
  189. http.Error(w, "Internal Server Error", http.StatusInternalServerError)
  190. return
  191. }
  192. if debug {
  193. log.Println(reqID, "getAddresses in", time.Since(t0))
  194. }
  195. globalStats.Query()
  196. if len(ann.Addresses) == 0 {
  197. http.Error(w, "Not Found", http.StatusNotFound)
  198. return
  199. }
  200. globalStats.Answer()
  201. w.Header().Set("Content-Type", "application/json")
  202. json.NewEncoder(w).Encode(ann)
  203. }
  204. func (s *querysrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
  205. reqID := ctx.Value("id").(requestID)
  206. rawCert := certificateBytes(req)
  207. if rawCert == nil {
  208. if debug {
  209. log.Println(reqID, "no certificates")
  210. }
  211. globalStats.Error()
  212. http.Error(w, "Forbidden", http.StatusForbidden)
  213. return
  214. }
  215. var ann announcement
  216. if err := json.NewDecoder(req.Body).Decode(&ann); err != nil {
  217. if debug {
  218. log.Println(reqID, "decode:", err)
  219. }
  220. globalStats.Error()
  221. http.Error(w, "Bad Request", http.StatusBadRequest)
  222. return
  223. }
  224. deviceID := protocol.NewDeviceID(rawCert)
  225. // handleAnnounce returns *two* errors. The first indicates a problem with
  226. // something the client posted to us. We should return a 400 Bad Request
  227. // and not worry about it. The second indicates that the request was fine,
  228. // but something internal messed up. We should log it and respond with a
  229. // more apologetic 500 Internal Server Error.
  230. userErr, internalErr := s.handleAnnounce(ctx, remoteIP, deviceID, ann.Addresses)
  231. if userErr != nil {
  232. if debug {
  233. log.Println(reqID, "handleAnnounce:", userErr)
  234. }
  235. globalStats.Error()
  236. http.Error(w, "Bad Request", http.StatusBadRequest)
  237. return
  238. }
  239. if internalErr != nil {
  240. log.Println(reqID, "handleAnnounce:", internalErr)
  241. globalStats.Error()
  242. http.Error(w, "Internal Server Error", http.StatusInternalServerError)
  243. return
  244. }
  245. globalStats.Announce()
  246. // TODO: Slowly increase this for stable clients
  247. w.Header().Set("Reannounce-After", "1800")
  248. // We could return the lookup result here, but it's kind of unnecessarily
  249. // expensive to go query the database again so we let the client decide to
  250. // do a lookup if they really care.
  251. w.WriteHeader(http.StatusNoContent)
  252. }
  253. func (s *querysrv) Stop() {
  254. s.listener.Close()
  255. }
  256. func (s *querysrv) handleAnnounce(ctx context.Context, remote net.IP, deviceID protocol.DeviceID, addresses []string) (userErr, internalErr error) {
  257. reqID := ctx.Value("id").(requestID)
  258. tx, err := s.db.Begin()
  259. if err != nil {
  260. internalErr = err
  261. return
  262. }
  263. defer func() {
  264. // Since we return from a bunch of different places, we handle
  265. // rollback in the defer.
  266. if internalErr != nil || userErr != nil {
  267. tx.Rollback()
  268. }
  269. }()
  270. for _, annAddr := range addresses {
  271. uri, err := url.Parse(annAddr)
  272. if err != nil {
  273. userErr = err
  274. return
  275. }
  276. host, port, err := net.SplitHostPort(uri.Host)
  277. if err != nil {
  278. userErr = err
  279. return
  280. }
  281. ip := net.ParseIP(host)
  282. if host == "" || ip.IsUnspecified() {
  283. // Do not use IPv6 remote address if requested scheme is tcp4
  284. if uri.Scheme == "tcp4" && remote.To4() == nil {
  285. continue
  286. }
  287. // Do not use IPv4 remote address if requested scheme is tcp6
  288. if uri.Scheme == "tcp6" && remote.To4() != nil {
  289. continue
  290. }
  291. host = remote.String()
  292. }
  293. uri.Host = net.JoinHostPort(host, port)
  294. if err := s.updateAddress(ctx, tx, deviceID, uri.String()); err != nil {
  295. internalErr = err
  296. return
  297. }
  298. }
  299. if err := s.updateDevice(ctx, tx, deviceID); err != nil {
  300. internalErr = err
  301. return
  302. }
  303. t0 := time.Now()
  304. internalErr = tx.Commit()
  305. if debug {
  306. log.Println(reqID, "commit in", time.Since(t0))
  307. }
  308. return
  309. }
  310. func (s *querysrv) limit(remote net.IP) bool {
  311. key := remote.String()
  312. bkt, ok := s.limiter.Get(key)
  313. if ok {
  314. bkt := bkt.(*ratelimit.Bucket)
  315. if bkt.TakeAvailable(1) != 1 {
  316. // Rate limit exceeded; ignore packet
  317. return true
  318. }
  319. } else {
  320. // One packet per ten seconds average rate, burst ten packets
  321. s.limiter.Add(key, ratelimit.NewBucket(10*time.Second/time.Duration(limitAvg), int64(limitBurst)))
  322. }
  323. return false
  324. }
  325. func (s *querysrv) updateDevice(ctx context.Context, tx *sql.Tx, device protocol.DeviceID) error {
  326. reqID := ctx.Value("id").(requestID)
  327. t0 := time.Now()
  328. res, err := tx.Stmt(s.prep["updateDevice"]).Exec(device.String())
  329. if err != nil {
  330. return err
  331. }
  332. if debug {
  333. log.Println(reqID, "updateDevice in", time.Since(t0))
  334. }
  335. if rows, _ := res.RowsAffected(); rows == 0 {
  336. t0 = time.Now()
  337. _, err := tx.Stmt(s.prep["insertDevice"]).Exec(device.String())
  338. if err != nil {
  339. return err
  340. }
  341. if debug {
  342. log.Println(reqID, "insertDevice in", time.Since(t0))
  343. }
  344. }
  345. return nil
  346. }
  347. func (s *querysrv) updateAddress(ctx context.Context, tx *sql.Tx, device protocol.DeviceID, uri string) error {
  348. res, err := tx.Stmt(s.prep["updateAddress"]).Exec(device.String(), uri)
  349. if err != nil {
  350. return err
  351. }
  352. if rows, _ := res.RowsAffected(); rows == 0 {
  353. _, err := tx.Stmt(s.prep["insertAddress"]).Exec(device.String(), uri)
  354. if err != nil {
  355. return err
  356. }
  357. }
  358. return nil
  359. }
  360. func (s *querysrv) getAddresses(ctx context.Context, device protocol.DeviceID) ([]string, error) {
  361. rows, err := s.prep["selectAddress"].Query(device.String())
  362. if err != nil {
  363. return nil, err
  364. }
  365. defer rows.Close()
  366. var res []string
  367. for rows.Next() {
  368. var addr string
  369. err := rows.Scan(&addr)
  370. if err != nil {
  371. log.Println("Scan:", err)
  372. continue
  373. }
  374. res = append(res, addr)
  375. }
  376. return res, nil
  377. }
  378. func (s *querysrv) getDeviceSeen(device protocol.DeviceID) (time.Time, error) {
  379. row := s.prep["selectDevice"].QueryRow(device.String())
  380. var seen time.Time
  381. if err := row.Scan(&seen); err != nil {
  382. return time.Time{}, err
  383. }
  384. return seen, nil
  385. }
  // handlePing is a liveness check: it always answers 204 No Content with an
  // empty body, regardless of method or request contents.
  func handlePing(w http.ResponseWriter, r *http.Request) {
  	w.WriteHeader(http.StatusNoContent)
  }
  // certificateBytes returns the DER bytes of the client certificate for req,
  // or nil if none can be found.
  //
  // A certificate presented on the TLS connection itself takes precedence.
  // Otherwise the X-SSL-Cert header is used. It is expected to hold a PEM
  // certificate whose newlines were replaced by spaces (presumably by a
  // TLS-terminating proxy).
  func certificateBytes(req *http.Request) []byte {
  	if state := req.TLS; state != nil && len(state.PeerCertificates) > 0 {
  		return state.PeerCertificates[0].Raw
  	}

  	hdr := req.Header.Get("X-SSL-Cert")
  	if hdr == "" {
  		return nil
  	}

  	// Turn the spaces back into newlines so the PEM decoder accepts the
  	// text. The first and last spaces belong to the "BEGIN CERTIFICATE" and
  	// "END CERTIFICATE" markers and must stay.
  	pemBytes := []byte(hdr)
  	first := bytes.IndexByte(pemBytes, ' ')
  	last := bytes.LastIndexByte(pemBytes, ' ')
  	if last > first {
  		inner := pemBytes[first+1 : last] // aliases pemBytes
  		for i, c := range inner {
  			if c == ' ' {
  				inner[i] = '\n'
  			}
  		}
  	}

  	block, _ := pem.Decode(pemBytes)
  	if block == nil {
  		return nil // not valid PEM
  	}
  	return block.Bytes
  }