replication.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Copyright (C) 2018 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. package main
  7. import (
  8. "context"
  9. "crypto/tls"
  10. "encoding/binary"
  11. "fmt"
  12. io "io"
  13. "log"
  14. "net"
  15. "time"
  16. "github.com/syncthing/syncthing/lib/protocol"
  17. )
  18. const replicationReadTimeout = time.Minute
  19. const replicationHeartbeatInterval = time.Second * 30
  20. type replicator interface {
  21. send(key string, addrs []DatabaseAddress, seen int64)
  22. }
  23. // a replicationSender tries to connect to the remote address and provide
  24. // them with a feed of replication updates.
  25. type replicationSender struct {
  26. dst string
  27. cert tls.Certificate // our certificate
  28. allowedIDs []protocol.DeviceID
  29. outbox chan ReplicationRecord
  30. }
  31. func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
  32. return &replicationSender{
  33. dst: dst,
  34. cert: cert,
  35. allowedIDs: allowedIDs,
  36. outbox: make(chan ReplicationRecord, replicationOutboxSize),
  37. }
  38. }
  39. func (s *replicationSender) Serve(ctx context.Context) error {
  40. // Sleep a little at startup. Peers often restart at the same time, and
  41. // this avoid the service failing and entering backoff state
  42. // unnecessarily, while also reducing the reconnect rate to something
  43. // reasonable by default.
  44. time.Sleep(2 * time.Second)
  45. tlsCfg := &tls.Config{
  46. Certificates: []tls.Certificate{s.cert},
  47. MinVersion: tls.VersionTLS12,
  48. InsecureSkipVerify: true,
  49. }
  50. // Dial the TLS connection.
  51. conn, err := tls.Dial("tcp", s.dst, tlsCfg)
  52. if err != nil {
  53. log.Println("Replication connect:", err)
  54. return err
  55. }
  56. defer func() {
  57. conn.SetWriteDeadline(time.Now().Add(time.Second))
  58. conn.Close()
  59. }()
  60. // Get the other side device ID.
  61. remoteID, err := deviceID(conn)
  62. if err != nil {
  63. log.Println("Replication connect:", err)
  64. return err
  65. }
  66. // Verify it's in the set of allowed device IDs.
  67. if !deviceIDIn(remoteID, s.allowedIDs) {
  68. log.Println("Replication connect: unexpected device ID:", remoteID)
  69. return err
  70. }
  71. heartBeatTicker := time.NewTicker(replicationHeartbeatInterval)
  72. defer heartBeatTicker.Stop()
  73. // Send records.
  74. buf := make([]byte, 1024)
  75. for {
  76. select {
  77. case <-heartBeatTicker.C:
  78. if len(s.outbox) > 0 {
  79. // No need to send heartbeats if there are events/prevrious
  80. // heartbeats to send, they will keep the connection alive.
  81. continue
  82. }
  83. // Empty replication message is the heartbeat:
  84. s.outbox <- ReplicationRecord{}
  85. case rec := <-s.outbox:
  86. // Buffer must hold record plus four bytes for size
  87. size := rec.Size()
  88. if len(buf) < size+4 {
  89. buf = make([]byte, size+4)
  90. }
  91. // Record comes after the four bytes size
  92. n, err := rec.MarshalTo(buf[4:])
  93. if err != nil {
  94. // odd to get an error here, but we haven't sent anything
  95. // yet so it's not fatal
  96. replicationSendsTotal.WithLabelValues("error").Inc()
  97. log.Println("Replication marshal:", err)
  98. continue
  99. }
  100. binary.BigEndian.PutUint32(buf, uint32(n))
  101. // Send
  102. conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
  103. if _, err := conn.Write(buf[:4+n]); err != nil {
  104. replicationSendsTotal.WithLabelValues("error").Inc()
  105. log.Println("Replication write:", err)
  106. // Yes, we are loosing the replication event here.
  107. return err
  108. }
  109. replicationSendsTotal.WithLabelValues("success").Inc()
  110. case <-ctx.Done():
  111. return nil
  112. }
  113. }
  114. }
  115. func (s *replicationSender) String() string {
  116. return fmt.Sprintf("replicationSender(%q)", s.dst)
  117. }
  118. func (s *replicationSender) send(key string, ps []DatabaseAddress, _ int64) {
  119. item := ReplicationRecord{
  120. Key: key,
  121. Addresses: ps,
  122. }
  123. // The send should never block. The inbox is suitably buffered for at
  124. // least a few seconds of stalls, which shouldn't happen in practice.
  125. select {
  126. case s.outbox <- item:
  127. default:
  128. replicationSendsTotal.WithLabelValues("drop").Inc()
  129. }
  130. }
  131. // a replicationMultiplexer sends to multiple replicators
  132. type replicationMultiplexer []replicator
  133. func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
  134. for _, s := range m {
  135. // each send is nonblocking
  136. s.send(key, ps, seen)
  137. }
  138. }
  139. // replicationListener accepts incoming connections and reads replication
  140. // items from them. Incoming items are applied to the KV store.
  141. type replicationListener struct {
  142. addr string
  143. cert tls.Certificate
  144. allowedIDs []protocol.DeviceID
  145. db database
  146. }
  147. func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
  148. return &replicationListener{
  149. addr: addr,
  150. cert: cert,
  151. allowedIDs: allowedIDs,
  152. db: db,
  153. }
  154. }
  155. func (l *replicationListener) Serve(ctx context.Context) error {
  156. tlsCfg := &tls.Config{
  157. Certificates: []tls.Certificate{l.cert},
  158. ClientAuth: tls.RequestClientCert,
  159. MinVersion: tls.VersionTLS12,
  160. InsecureSkipVerify: true,
  161. }
  162. lst, err := tls.Listen("tcp", l.addr, tlsCfg)
  163. if err != nil {
  164. log.Println("Replication listen:", err)
  165. return err
  166. }
  167. defer lst.Close()
  168. for {
  169. select {
  170. case <-ctx.Done():
  171. return nil
  172. default:
  173. }
  174. // Accept a connection
  175. conn, err := lst.Accept()
  176. if err != nil {
  177. log.Println("Replication accept:", err)
  178. return err
  179. }
  180. // Figure out the other side device ID
  181. remoteID, err := deviceID(conn.(*tls.Conn))
  182. if err != nil {
  183. log.Println("Replication accept:", err)
  184. conn.SetWriteDeadline(time.Now().Add(time.Second))
  185. conn.Close()
  186. continue
  187. }
  188. // Verify it is in the set of allowed device IDs
  189. if !deviceIDIn(remoteID, l.allowedIDs) {
  190. log.Println("Replication accept: unexpected device ID:", remoteID)
  191. conn.SetWriteDeadline(time.Now().Add(time.Second))
  192. conn.Close()
  193. continue
  194. }
  195. go l.handle(ctx, conn)
  196. }
  197. }
  198. func (l *replicationListener) String() string {
  199. return fmt.Sprintf("replicationListener(%q)", l.addr)
  200. }
  201. func (l *replicationListener) handle(ctx context.Context, conn net.Conn) {
  202. defer func() {
  203. conn.SetWriteDeadline(time.Now().Add(time.Second))
  204. conn.Close()
  205. }()
  206. buf := make([]byte, 1024)
  207. for {
  208. select {
  209. case <-ctx.Done():
  210. return
  211. default:
  212. }
  213. conn.SetReadDeadline(time.Now().Add(replicationReadTimeout))
  214. // First four bytes are the size
  215. if _, err := io.ReadFull(conn, buf[:4]); err != nil {
  216. log.Println("Replication read size:", err)
  217. replicationRecvsTotal.WithLabelValues("error").Inc()
  218. return
  219. }
  220. // Read the rest of the record
  221. size := int(binary.BigEndian.Uint32(buf[:4]))
  222. if len(buf) < size {
  223. buf = make([]byte, size)
  224. }
  225. if size == 0 {
  226. // Heartbeat, ignore
  227. continue
  228. }
  229. if _, err := io.ReadFull(conn, buf[:size]); err != nil {
  230. log.Println("Replication read record:", err)
  231. replicationRecvsTotal.WithLabelValues("error").Inc()
  232. return
  233. }
  234. // Unmarshal
  235. var rec ReplicationRecord
  236. if err := rec.Unmarshal(buf[:size]); err != nil {
  237. log.Println("Replication unmarshal:", err)
  238. replicationRecvsTotal.WithLabelValues("error").Inc()
  239. continue
  240. }
  241. // Store
  242. l.db.merge(rec.Key, rec.Addresses, rec.Seen)
  243. replicationRecvsTotal.WithLabelValues("success").Inc()
  244. }
  245. }
  246. func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
  247. // Handshake may not be complete on the server side yet, which we need
  248. // to get the client certificate.
  249. if !conn.ConnectionState().HandshakeComplete {
  250. if err := conn.Handshake(); err != nil {
  251. return protocol.DeviceID{}, err
  252. }
  253. }
  254. // We expect exactly one certificate.
  255. certs := conn.ConnectionState().PeerCertificates
  256. if len(certs) != 1 {
  257. return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
  258. }
  259. return protocol.NewDeviceID(certs[0].Raw), nil
  260. }
  261. func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
  262. for _, candidate := range ids {
  263. if id == candidate {
  264. return true
  265. }
  266. }
  267. return false
  268. }