replication.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. // Copyright (C) 2018 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. package main
  7. import (
  8. "crypto/tls"
  9. "encoding/binary"
  10. "fmt"
  11. io "io"
  12. "log"
  13. "net"
  14. "time"
  15. "github.com/syncthing/syncthing/lib/protocol"
  16. )
  17. type replicator interface {
  18. send(key string, addrs []DatabaseAddress, seen int64)
  19. }
  20. // a replicationSender tries to connect to the remote address and provide
  21. // them with a feed of replication updates.
  22. type replicationSender struct {
  23. dst string
  24. cert tls.Certificate // our certificate
  25. allowedIDs []protocol.DeviceID
  26. outbox chan ReplicationRecord
  27. stop chan struct{}
  28. }
  29. func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
  30. return &replicationSender{
  31. dst: dst,
  32. cert: cert,
  33. allowedIDs: allowedIDs,
  34. outbox: make(chan ReplicationRecord, replicationOutboxSize),
  35. stop: make(chan struct{}),
  36. }
  37. }
  38. func (s *replicationSender) Serve() {
  39. // Sleep a little at startup. Peers often restart at the same time, and
  40. // this avoid the service failing and entering backoff state
  41. // unnecessarily, while also reducing the reconnect rate to something
  42. // reasonable by default.
  43. time.Sleep(2 * time.Second)
  44. tlsCfg := &tls.Config{
  45. Certificates: []tls.Certificate{s.cert},
  46. MinVersion: tls.VersionTLS12,
  47. InsecureSkipVerify: true,
  48. }
  49. // Dial the TLS connection.
  50. conn, err := tls.Dial("tcp", s.dst, tlsCfg)
  51. if err != nil {
  52. log.Println("Replication connect:", err)
  53. return
  54. }
  55. defer func() {
  56. conn.SetWriteDeadline(time.Now().Add(time.Second))
  57. conn.Close()
  58. }()
  59. // Get the other side device ID.
  60. remoteID, err := deviceID(conn)
  61. if err != nil {
  62. log.Println("Replication connect:", err)
  63. return
  64. }
  65. // Verify it's in the set of allowed device IDs.
  66. if !deviceIDIn(remoteID, s.allowedIDs) {
  67. log.Println("Replication connect: unexpected device ID:", remoteID)
  68. return
  69. }
  70. // Send records.
  71. buf := make([]byte, 1024)
  72. for {
  73. select {
  74. case rec := <-s.outbox:
  75. // Buffer must hold record plus four bytes for size
  76. size := rec.Size()
  77. if len(buf) < size+4 {
  78. buf = make([]byte, size+4)
  79. }
  80. // Record comes after the four bytes size
  81. n, err := rec.MarshalTo(buf[4:])
  82. if err != nil {
  83. // odd to get an error here, but we haven't sent anything
  84. // yet so it's not fatal
  85. replicationSendsTotal.WithLabelValues("error").Inc()
  86. log.Println("Replication marshal:", err)
  87. continue
  88. }
  89. binary.BigEndian.PutUint32(buf, uint32(n))
  90. // Send
  91. conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
  92. if _, err := conn.Write(buf[:4+n]); err != nil {
  93. replicationSendsTotal.WithLabelValues("error").Inc()
  94. log.Println("Replication write:", err)
  95. return
  96. }
  97. replicationSendsTotal.WithLabelValues("success").Inc()
  98. case <-s.stop:
  99. return
  100. }
  101. }
  102. }
  103. func (s *replicationSender) Stop() {
  104. close(s.stop)
  105. }
  106. func (s *replicationSender) String() string {
  107. return fmt.Sprintf("replicationSender(%q)", s.dst)
  108. }
  109. func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) {
  110. item := ReplicationRecord{
  111. Key: key,
  112. Addresses: ps,
  113. }
  114. // The send should never block. The inbox is suitably buffered for at
  115. // least a few seconds of stalls, which shouldn't happen in practice.
  116. select {
  117. case s.outbox <- item:
  118. default:
  119. replicationSendsTotal.WithLabelValues("drop").Inc()
  120. }
  121. }
  122. // a replicationMultiplexer sends to multiple replicators
  123. type replicationMultiplexer []replicator
  124. func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
  125. for _, s := range m {
  126. // each send is nonblocking
  127. s.send(key, ps, seen)
  128. }
  129. }
  130. // replicationListener accepts incoming connections and reads replication
  131. // items from them. Incoming items are applied to the KV store.
  132. type replicationListener struct {
  133. addr string
  134. cert tls.Certificate
  135. allowedIDs []protocol.DeviceID
  136. db database
  137. stop chan struct{}
  138. }
  139. func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
  140. return &replicationListener{
  141. addr: addr,
  142. cert: cert,
  143. allowedIDs: allowedIDs,
  144. db: db,
  145. stop: make(chan struct{}),
  146. }
  147. }
  148. func (l *replicationListener) Serve() {
  149. tlsCfg := &tls.Config{
  150. Certificates: []tls.Certificate{l.cert},
  151. ClientAuth: tls.RequestClientCert,
  152. MinVersion: tls.VersionTLS12,
  153. InsecureSkipVerify: true,
  154. }
  155. lst, err := tls.Listen("tcp", l.addr, tlsCfg)
  156. if err != nil {
  157. log.Println("Replication listen:", err)
  158. return
  159. }
  160. defer lst.Close()
  161. for {
  162. select {
  163. case <-l.stop:
  164. return
  165. default:
  166. }
  167. // Accept a connection
  168. conn, err := lst.Accept()
  169. if err != nil {
  170. log.Println("Replication accept:", err)
  171. return
  172. }
  173. // Figure out the other side device ID
  174. remoteID, err := deviceID(conn.(*tls.Conn))
  175. if err != nil {
  176. log.Println("Replication accept:", err)
  177. conn.SetWriteDeadline(time.Now().Add(time.Second))
  178. conn.Close()
  179. continue
  180. }
  181. // Verify it is in the set of allowed device IDs
  182. if !deviceIDIn(remoteID, l.allowedIDs) {
  183. log.Println("Replication accept: unexpected device ID:", remoteID)
  184. conn.SetWriteDeadline(time.Now().Add(time.Second))
  185. conn.Close()
  186. continue
  187. }
  188. go l.handle(conn)
  189. }
  190. }
  191. func (l *replicationListener) Stop() {
  192. close(l.stop)
  193. }
  194. func (l *replicationListener) String() string {
  195. return fmt.Sprintf("replicationListener(%q)", l.addr)
  196. }
  197. func (l *replicationListener) handle(conn net.Conn) {
  198. defer func() {
  199. conn.SetWriteDeadline(time.Now().Add(time.Second))
  200. conn.Close()
  201. }()
  202. buf := make([]byte, 1024)
  203. for {
  204. select {
  205. case <-l.stop:
  206. return
  207. default:
  208. }
  209. conn.SetReadDeadline(time.Now().Add(time.Minute))
  210. // First four bytes are the size
  211. if _, err := io.ReadFull(conn, buf[:4]); err != nil {
  212. log.Println("Replication read size:", err)
  213. replicationRecvsTotal.WithLabelValues("error").Inc()
  214. return
  215. }
  216. // Read the rest of the record
  217. size := int(binary.BigEndian.Uint32(buf[:4]))
  218. if len(buf) < size {
  219. buf = make([]byte, size)
  220. }
  221. if _, err := io.ReadFull(conn, buf[:size]); err != nil {
  222. log.Println("Replication read record:", err)
  223. replicationRecvsTotal.WithLabelValues("error").Inc()
  224. return
  225. }
  226. // Unmarshal
  227. var rec ReplicationRecord
  228. if err := rec.Unmarshal(buf[:size]); err != nil {
  229. log.Println("Replication unmarshal:", err)
  230. replicationRecvsTotal.WithLabelValues("error").Inc()
  231. continue
  232. }
  233. // Store
  234. l.db.merge(rec.Key, rec.Addresses, rec.Seen)
  235. replicationRecvsTotal.WithLabelValues("success").Inc()
  236. }
  237. }
  238. func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
  239. // Handshake may not be complete on the server side yet, which we need
  240. // to get the client certificate.
  241. if !conn.ConnectionState().HandshakeComplete {
  242. if err := conn.Handshake(); err != nil {
  243. return protocol.DeviceID{}, err
  244. }
  245. }
  246. // We expect exactly one certificate.
  247. certs := conn.ConnectionState().PeerCertificates
  248. if len(certs) != 1 {
  249. return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
  250. }
  251. return protocol.NewDeviceID(certs[0].Raw), nil
  252. }
  253. func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
  254. for _, candidate := range ids {
  255. if id == candidate {
  256. return true
  257. }
  258. }
  259. return false
  260. }