apisrv.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. // Copyright (C) 2018 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. package main
  7. import (
  8. "bytes"
  9. "compress/gzip"
  10. "context"
  11. "crypto/tls"
  12. "encoding/base64"
  13. "encoding/json"
  14. "encoding/pem"
  15. "errors"
  16. "fmt"
  17. io "io"
  18. "log"
  19. "math/rand"
  20. "net"
  21. "net/http"
  22. "net/url"
  23. "slices"
  24. "strconv"
  25. "strings"
  26. "sync"
  27. "time"
  28. "github.com/syncthing/syncthing/lib/protocol"
  29. "github.com/syncthing/syncthing/lib/stringutil"
  30. )
// announcement is the JSON document exchanged with clients. handlePOST decodes
// announce requests into it (only Addresses is used from those), and
// handleGET encodes it as the body of a successful lookup response, with Seen
// filled in from the database record. Field order determines the order of the
// encoded JSON keys.
type announcement struct {
	// Seen is when the device was last seen; handleGET truncates it to whole
	// seconds.
	Seen time.Time `json:"seen"`
	// Addresses are the device's address URLs.
	Addresses []string `json:"addresses"`
}
// apiSrv is the HTTP(S) front end of the discovery server: GET requests are
// device lookups answered from db, POST requests are announcements merged into
// db (and forwarded to repl when it is set).
type apiSrv struct {
	addr     string          // listen address, passed to net.Listen / tls.Listen in Serve
	cert     tls.Certificate // server certificate; only used when useHTTP is false
	db       database        // backing store for lookups (get) and announcements (merge)
	listener net.Listener    // created by Serve, closed by Stop
	repl     replicator      // optional; nil disables replication
	// useHTTP selects plain HTTP instead of TLS. The client's IP and port are
	// then taken from the X-Forwarded-For and X-Client-Port headers
	// (presumably set by a TLS-terminating reverse proxy — TODO confirm).
	useHTTP bool
	// compression enables gzip-encoded lookup responses for clients whose
	// Accept-Encoding mentions gzip.
	compression bool
	gzipWriters sync.Pool // reusable *gzip.Writer instances for handleGET
	// Retry-After trackers for not-found lookups: seenTracker for devices
	// that have been seen before, notSeenTracker for devices never seen.
	seenTracker    *retryAfterTracker
	notSeenTracker *retryAfterTracker
}
// replicator receives every accepted announcement from handleAnnounce before
// it is merged into the local database (presumably to propagate it to peer
// servers — TODO confirm against the implementation). seen is a Unix
// timestamp in nanoseconds.
type replicator interface {
	send(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64)
}
  51. type requestID int64
  52. func (i requestID) String() string {
  53. return fmt.Sprintf("%016x", int64(i))
  54. }
// contextKey is an unexported type for request-context keys, so values stored
// under it cannot collide with context keys defined by other packages.
type contextKey int

// idKey is the context key under which handler stores the request's requestID.
const idKey contextKey = iota
  57. func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool) *apiSrv {
  58. return &apiSrv{
  59. addr: addr,
  60. cert: cert,
  61. db: db,
  62. repl: repl,
  63. useHTTP: useHTTP,
  64. compression: compression,
  65. seenTracker: &retryAfterTracker{
  66. name: "seenTracker",
  67. bucketStarts: time.Now(),
  68. desiredRate: 250,
  69. currentDelay: notFoundRetryUnknownMinSeconds,
  70. },
  71. notSeenTracker: &retryAfterTracker{
  72. name: "notSeenTracker",
  73. bucketStarts: time.Now(),
  74. desiredRate: 250,
  75. currentDelay: notFoundRetryUnknownMaxSeconds / 2,
  76. },
  77. }
  78. }
  79. func (s *apiSrv) Serve(ctx context.Context) error {
  80. if s.useHTTP {
  81. listener, err := net.Listen("tcp", s.addr)
  82. if err != nil {
  83. log.Println("Listen:", err)
  84. return err
  85. }
  86. s.listener = listener
  87. } else {
  88. tlsCfg := &tls.Config{
  89. Certificates: []tls.Certificate{s.cert},
  90. ClientAuth: tls.RequestClientCert,
  91. MinVersion: tls.VersionTLS12,
  92. NextProtos: []string{"h2", "http/1.1"},
  93. }
  94. tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg)
  95. if err != nil {
  96. log.Println("Listen:", err)
  97. return err
  98. }
  99. s.listener = tlsListener
  100. }
  101. http.HandleFunc("/", s.handler)
  102. http.HandleFunc("/ping", handlePing)
  103. srv := &http.Server{
  104. ReadTimeout: httpReadTimeout,
  105. WriteTimeout: httpWriteTimeout,
  106. MaxHeaderBytes: httpMaxHeaderBytes,
  107. ErrorLog: log.New(io.Discard, "", 0),
  108. }
  109. go func() {
  110. <-ctx.Done()
  111. srv.Shutdown(context.Background())
  112. }()
  113. err := srv.Serve(s.listener)
  114. if err != nil {
  115. log.Println("Serve:", err)
  116. }
  117. return err
  118. }
  119. func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
  120. t0 := time.Now()
  121. lw := NewLoggingResponseWriter(w)
  122. defer func() {
  123. diff := time.Since(t0)
  124. apiRequestsSeconds.WithLabelValues(req.Method).Observe(diff.Seconds())
  125. apiRequestsTotal.WithLabelValues(req.Method, strconv.Itoa(lw.statusCode)).Inc()
  126. }()
  127. reqID := requestID(rand.Int63())
  128. req = req.WithContext(context.WithValue(req.Context(), idKey, reqID))
  129. if debug {
  130. log.Println(reqID, req.Method, req.URL, req.Proto)
  131. }
  132. remoteAddr := &net.TCPAddr{
  133. IP: nil,
  134. Port: -1,
  135. }
  136. if s.useHTTP {
  137. // X-Forwarded-For can have multiple client IPs; split using the comma separator
  138. forwardIP, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",")
  139. // net.ParseIP will return nil if leading/trailing whitespace exists; use strings.TrimSpace()
  140. remoteAddr.IP = net.ParseIP(strings.TrimSpace(forwardIP))
  141. if parsedPort, err := strconv.ParseInt(req.Header.Get("X-Client-Port"), 10, 0); err == nil {
  142. remoteAddr.Port = int(parsedPort)
  143. }
  144. } else {
  145. var err error
  146. remoteAddr, err = net.ResolveTCPAddr("tcp", req.RemoteAddr)
  147. if err != nil {
  148. log.Println("remoteAddr:", err)
  149. lw.Header().Set("Retry-After", errorRetryAfterString())
  150. http.Error(lw, "Internal Server Error", http.StatusInternalServerError)
  151. apiRequestsTotal.WithLabelValues("no_remote_addr").Inc()
  152. return
  153. }
  154. }
  155. switch req.Method {
  156. case http.MethodGet:
  157. s.handleGET(lw, req)
  158. case http.MethodPost:
  159. s.handlePOST(remoteAddr, lw, req)
  160. default:
  161. http.Error(lw, "Method Not Allowed", http.StatusMethodNotAllowed)
  162. }
  163. }
  164. func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
  165. reqID := req.Context().Value(idKey).(requestID)
  166. deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
  167. if err != nil {
  168. if debug {
  169. log.Println(reqID, "bad device param")
  170. }
  171. lookupRequestsTotal.WithLabelValues("bad_request").Inc()
  172. w.Header().Set("Retry-After", errorRetryAfterString())
  173. http.Error(w, "Bad Request", http.StatusBadRequest)
  174. return
  175. }
  176. rec, err := s.db.get(&deviceID)
  177. if err != nil {
  178. // some sort of internal error
  179. lookupRequestsTotal.WithLabelValues("internal_error").Inc()
  180. w.Header().Set("Retry-After", errorRetryAfterString())
  181. http.Error(w, "Internal Server Error", http.StatusInternalServerError)
  182. return
  183. }
  184. if len(rec.Addresses) == 0 {
  185. var afterS int
  186. if rec.Seen == 0 {
  187. afterS = s.notSeenTracker.retryAfterS()
  188. lookupRequestsTotal.WithLabelValues("not_found_ever").Inc()
  189. } else {
  190. afterS = s.seenTracker.retryAfterS()
  191. lookupRequestsTotal.WithLabelValues("not_found_recent").Inc()
  192. }
  193. w.Header().Set("Retry-After", strconv.Itoa(afterS))
  194. http.Error(w, "Not Found", http.StatusNotFound)
  195. return
  196. }
  197. lookupRequestsTotal.WithLabelValues("success").Inc()
  198. w.Header().Set("Content-Type", "application/json")
  199. var bw io.Writer = w
  200. // Use compression if the client asks for it
  201. if s.compression && strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
  202. gw, ok := s.gzipWriters.Get().(*gzip.Writer)
  203. if ok {
  204. gw.Reset(w)
  205. } else {
  206. gw = gzip.NewWriter(w)
  207. }
  208. w.Header().Set("Content-Encoding", "gzip")
  209. defer gw.Close()
  210. defer s.gzipWriters.Put(gw)
  211. bw = gw
  212. }
  213. json.NewEncoder(bw).Encode(announcement{
  214. Seen: time.Unix(0, rec.Seen).Truncate(time.Second),
  215. Addresses: addressStrs(rec.Addresses),
  216. })
  217. }
  218. func (s *apiSrv) handlePOST(remoteAddr *net.TCPAddr, w http.ResponseWriter, req *http.Request) {
  219. reqID := req.Context().Value(idKey).(requestID)
  220. rawCert, err := certificateBytes(req)
  221. if err != nil {
  222. if debug {
  223. log.Println(reqID, "no certificates:", err)
  224. }
  225. announceRequestsTotal.WithLabelValues("no_certificate").Inc()
  226. w.Header().Set("Retry-After", errorRetryAfterString())
  227. http.Error(w, "Forbidden", http.StatusForbidden)
  228. return
  229. }
  230. var ann announcement
  231. if err := json.NewDecoder(req.Body).Decode(&ann); err != nil {
  232. if debug {
  233. log.Println(reqID, "decode:", err)
  234. }
  235. announceRequestsTotal.WithLabelValues("bad_request").Inc()
  236. w.Header().Set("Retry-After", errorRetryAfterString())
  237. http.Error(w, "Bad Request", http.StatusBadRequest)
  238. return
  239. }
  240. deviceID := protocol.NewDeviceID(rawCert)
  241. addresses := fixupAddresses(remoteAddr, ann.Addresses)
  242. if len(addresses) == 0 {
  243. announceRequestsTotal.WithLabelValues("bad_request").Inc()
  244. w.Header().Set("Retry-After", errorRetryAfterString())
  245. http.Error(w, "Bad Request", http.StatusBadRequest)
  246. return
  247. }
  248. if err := s.handleAnnounce(deviceID, addresses); err != nil {
  249. announceRequestsTotal.WithLabelValues("internal_error").Inc()
  250. w.Header().Set("Retry-After", errorRetryAfterString())
  251. http.Error(w, "Internal Server Error", http.StatusInternalServerError)
  252. return
  253. }
  254. announceRequestsTotal.WithLabelValues("success").Inc()
  255. w.Header().Set("Reannounce-After", reannounceAfterString())
  256. w.WriteHeader(http.StatusNoContent)
  257. }
  258. func (s *apiSrv) Stop() {
  259. s.listener.Close()
  260. }
  261. func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) error {
  262. now := time.Now()
  263. expire := now.Add(addressExpiryTime).UnixNano()
  264. // The address slice must always be sorted for database merges to work
  265. // properly.
  266. slices.Sort(addresses)
  267. addresses = slices.Compact(addresses)
  268. dbAddrs := make([]DatabaseAddress, len(addresses))
  269. for i := range addresses {
  270. dbAddrs[i].Address = addresses[i]
  271. dbAddrs[i].Expires = expire
  272. }
  273. seen := now.UnixNano()
  274. if s.repl != nil {
  275. s.repl.send(&deviceID, dbAddrs, seen)
  276. }
  277. return s.db.merge(&deviceID, dbAddrs, seen)
  278. }
  279. func handlePing(w http.ResponseWriter, _ *http.Request) {
  280. w.WriteHeader(204)
  281. }
  282. func certificateBytes(req *http.Request) ([]byte, error) {
  283. if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
  284. return req.TLS.PeerCertificates[0].Raw, nil
  285. }
  286. var bs []byte
  287. if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" {
  288. if strings.Contains(hdr, "%") {
  289. // Nginx using $ssl_client_escaped_cert
  290. // The certificate is in PEM format with url encoding.
  291. // We need to decode for the PEM decoder
  292. hdr, err := url.QueryUnescape(hdr)
  293. if err != nil {
  294. // Decoding failed
  295. return nil, err
  296. }
  297. bs = []byte(hdr)
  298. } else {
  299. // Nginx using $ssl_client_cert
  300. // The certificate is in PEM format but with spaces for newlines. We
  301. // need to reinstate the newlines for the PEM decoder. But we need to
  302. // leave the spaces in the BEGIN and END lines - the first and last
  303. // space - alone.
  304. bs = []byte(hdr)
  305. firstSpace := bytes.Index(bs, []byte(" "))
  306. lastSpace := bytes.LastIndex(bs, []byte(" "))
  307. for i := firstSpace + 1; i < lastSpace; i++ {
  308. if bs[i] == ' ' {
  309. bs[i] = '\n'
  310. }
  311. }
  312. }
  313. } else if hdr := req.Header.Get("X-Tls-Client-Cert-Der-Base64"); hdr != "" {
  314. // Caddy {tls_client_certificate_der_base64}
  315. hdr, err := base64.StdEncoding.DecodeString(hdr)
  316. if err != nil {
  317. // Decoding failed
  318. return nil, err
  319. }
  320. bs = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: hdr})
  321. } else if cert := req.Header.Get("X-Forwarded-Tls-Client-Cert"); cert != "" {
  322. // Traefik 2 passtlsclientcert
  323. //
  324. // The certificate is in PEM format, maybe with URL encoding
  325. // (depends on Traefik version) but without newlines and start/end
  326. // statements. We need to decode, reinstate the newlines every 64
  327. // character and add statements for the PEM decoder
  328. if strings.Contains(cert, "%") {
  329. if unesc, err := url.QueryUnescape(cert); err == nil {
  330. cert = unesc
  331. }
  332. }
  333. const (
  334. header = "-----BEGIN CERTIFICATE-----"
  335. footer = "-----END CERTIFICATE-----"
  336. )
  337. var b bytes.Buffer
  338. b.Grow(len(header) + 1 + len(cert) + len(cert)/64 + 1 + len(footer) + 1)
  339. b.WriteString(header)
  340. b.WriteByte('\n')
  341. for i := 0; i < len(cert); i += 64 {
  342. end := i + 64
  343. if end > len(cert) {
  344. end = len(cert)
  345. }
  346. b.WriteString(cert[i:end])
  347. b.WriteByte('\n')
  348. }
  349. b.WriteString(footer)
  350. b.WriteByte('\n')
  351. bs = b.Bytes()
  352. }
  353. if bs == nil {
  354. return nil, errors.New("empty certificate header")
  355. }
  356. block, _ := pem.Decode(bs)
  357. if block == nil {
  358. // Decoding failed
  359. return nil, errors.New("certificate decode result is empty")
  360. }
  361. return block.Bytes, nil
  362. }
  363. // fixupAddresses checks the list of addresses, removing invalid ones and
  364. // replacing unspecified IPs with the given remote IP.
  365. func fixupAddresses(remote *net.TCPAddr, addresses []string) []string {
  366. fixed := make([]string, 0, len(addresses))
  367. for _, annAddr := range addresses {
  368. uri, err := url.Parse(annAddr)
  369. if err != nil {
  370. continue
  371. }
  372. host, port, err := net.SplitHostPort(uri.Host)
  373. if err != nil {
  374. continue
  375. }
  376. ip := net.ParseIP(host)
  377. // Some classes of IP are no-go.
  378. if ip.IsLoopback() || ip.IsMulticast() {
  379. continue
  380. }
  381. if host == "" || ip.IsUnspecified() {
  382. if remote != nil {
  383. // Replace the unspecified IP with the request source.
  384. // ... unless the request source is the loopback address or
  385. // multicast/unspecified (can't happen, really).
  386. if remote.IP == nil || remote.IP.IsLoopback() || remote.IP.IsMulticast() || remote.IP.IsUnspecified() {
  387. continue
  388. }
  389. // Do not use IPv6 remote address if requested scheme is ...4
  390. // (i.e., tcp4, etc.)
  391. if strings.HasSuffix(uri.Scheme, "4") && remote.IP.To4() == nil {
  392. continue
  393. }
  394. // Do not use IPv4 remote address if requested scheme is ...6
  395. if strings.HasSuffix(uri.Scheme, "6") && remote.IP.To4() != nil {
  396. continue
  397. }
  398. host = remote.IP.String()
  399. } else {
  400. // remote is nil, unable to determine host IP
  401. continue
  402. }
  403. }
  404. // If zero port was specified, use remote port.
  405. if port == "0" {
  406. if remote != nil && remote.Port > 0 {
  407. // use remote port
  408. port = strconv.Itoa(remote.Port)
  409. } else {
  410. // unable to determine remote port
  411. continue
  412. }
  413. }
  414. uri.Host = net.JoinHostPort(host, port)
  415. fixed = append(fixed, uri.String())
  416. }
  417. // Remove duplicate addresses
  418. fixed = stringutil.UniqueTrimmedStrings(fixed)
  419. return fixed
  420. }
  421. type loggingResponseWriter struct {
  422. http.ResponseWriter
  423. statusCode int
  424. }
  425. func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
  426. return &loggingResponseWriter{w, http.StatusOK}
  427. }
  428. func (lrw *loggingResponseWriter) WriteHeader(code int) {
  429. lrw.statusCode = code
  430. lrw.ResponseWriter.WriteHeader(code)
  431. }
  432. func addressStrs(dbAddrs []DatabaseAddress) []string {
  433. res := make([]string, len(dbAddrs))
  434. for i, a := range dbAddrs {
  435. res[i] = a.Address
  436. }
  437. return res
  438. }
  439. func errorRetryAfterString() string {
  440. return strconv.Itoa(errorRetryAfterSeconds + rand.Intn(errorRetryFuzzSeconds))
  441. }
  442. func reannounceAfterString() string {
  443. return strconv.Itoa(reannounceAfterSeconds + rand.Intn(reannounzeFuzzSeconds))
  444. }
  445. type retryAfterTracker struct {
  446. name string
  447. desiredRate float64 // requests per second
  448. mut sync.Mutex
  449. lastCount int // requests in the last bucket
  450. curCount int // requests in the current bucket
  451. bucketStarts time.Time // start of the current bucket
  452. currentDelay int // current delay in seconds
  453. }
  454. func (t *retryAfterTracker) retryAfterS() int {
  455. now := time.Now()
  456. t.mut.Lock()
  457. if durS := now.Sub(t.bucketStarts).Seconds(); durS > float64(t.currentDelay) {
  458. t.bucketStarts = now
  459. t.lastCount = t.curCount
  460. lastRate := float64(t.lastCount) / durS
  461. switch {
  462. case t.currentDelay > notFoundRetryUnknownMinSeconds &&
  463. lastRate < 0.75*t.desiredRate:
  464. t.currentDelay = max(8*t.currentDelay/10, notFoundRetryUnknownMinSeconds)
  465. case t.currentDelay < notFoundRetryUnknownMaxSeconds &&
  466. lastRate > 1.25*t.desiredRate:
  467. t.currentDelay = min(3*t.currentDelay/2, notFoundRetryUnknownMaxSeconds)
  468. }
  469. t.curCount = 0
  470. }
  471. if t.curCount == 0 {
  472. retryAfterLevel.WithLabelValues(t.name).Set(float64(t.currentDelay))
  473. }
  474. t.curCount++
  475. t.mut.Unlock()
  476. return t.currentDelay + rand.Intn(t.currentDelay/4)
  477. }