session.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. // Copyright (C) 2015 Audrius Butkevicius and Contributors.
  2. package main
  3. import (
  4. "crypto/rand"
  5. "encoding/hex"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "golang.org/x/time/rate"
  14. syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
  15. "github.com/syncthing/syncthing/lib/relay/protocol"
  16. )
  17. var (
  18. sessionMut = sync.RWMutex{}
  19. activeSessions = make([]*session, 0)
  20. pendingSessions = make(map[string]*session)
  21. numProxies atomic.Int64
  22. bytesProxied atomic.Int64
  23. )
  24. func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionLimitBps int, globalRateLimit *rate.Limiter) *session {
  25. serverkey := make([]byte, 32)
  26. _, err := rand.Read(serverkey)
  27. if err != nil {
  28. return nil
  29. }
  30. clientkey := make([]byte, 32)
  31. _, err = rand.Read(clientkey)
  32. if err != nil {
  33. return nil
  34. }
  35. var sessionRateLimit *rate.Limiter
  36. if sessionLimitBps > 0 {
  37. sessionRateLimit = rate.NewLimiter(rate.Limit(sessionLimitBps), 2*sessionLimitBps)
  38. }
  39. ses := &session{
  40. serverkey: serverkey,
  41. serverid: serverid,
  42. clientkey: clientkey,
  43. clientid: clientid,
  44. rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
  45. limiter: sessionRateLimit,
  46. connsChan: make(chan net.Conn),
  47. conns: make([]net.Conn, 0, 2),
  48. }
  49. if debug {
  50. log.Println("New session", ses)
  51. }
  52. sessionMut.Lock()
  53. pendingSessions[string(ses.serverkey)] = ses
  54. pendingSessions[string(ses.clientkey)] = ses
  55. sessionMut.Unlock()
  56. return ses
  57. }
  58. func findSession(key string) *session {
  59. sessionMut.Lock()
  60. defer sessionMut.Unlock()
  61. ses, ok := pendingSessions[key]
  62. if !ok {
  63. return nil
  64. }
  65. delete(pendingSessions, key)
  66. return ses
  67. }
  68. func dropSessions(id syncthingprotocol.DeviceID) {
  69. sessionMut.RLock()
  70. for _, session := range activeSessions {
  71. if session.HasParticipant(id) {
  72. if debug {
  73. log.Println("Dropping session", session, "involving", id)
  74. }
  75. session.CloseConns()
  76. }
  77. }
  78. sessionMut.RUnlock()
  79. }
  80. func hasSessions(id syncthingprotocol.DeviceID) bool {
  81. sessionMut.RLock()
  82. has := false
  83. for _, session := range activeSessions {
  84. if session.HasParticipant(id) {
  85. has = true
  86. break
  87. }
  88. }
  89. sessionMut.RUnlock()
  90. return has
  91. }
  92. type session struct {
  93. mut sync.Mutex
  94. serverkey []byte
  95. serverid syncthingprotocol.DeviceID
  96. clientkey []byte
  97. clientid syncthingprotocol.DeviceID
  98. rateLimit func(bytes int)
  99. limiter *rate.Limiter
  100. connsChan chan net.Conn
  101. conns []net.Conn
  102. }
  103. func (s *session) AddConnection(conn net.Conn) bool {
  104. if debug {
  105. log.Println("New connection for", s, "from", conn.RemoteAddr())
  106. }
  107. select {
  108. case s.connsChan <- conn:
  109. return true
  110. default:
  111. }
  112. return false
  113. }
  114. func (s *session) Serve() {
  115. timedout := time.After(messageTimeout)
  116. if debug {
  117. log.Println("Session", s, "serving")
  118. }
  119. for {
  120. select {
  121. case conn := <-s.connsChan:
  122. s.mut.Lock()
  123. s.conns = append(s.conns, conn)
  124. s.mut.Unlock()
  125. // We're the only ones mutating s.conns, hence we are free to read it.
  126. if len(s.conns) < 2 {
  127. continue
  128. }
  129. close(s.connsChan)
  130. if debug {
  131. log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
  132. }
  133. wg := sync.WaitGroup{}
  134. var err0 error
  135. wg.Go(func() { err0 = s.proxy(s.conns[0], s.conns[1]) })
  136. var err1 error
  137. wg.Go(func() { err1 = s.proxy(s.conns[1], s.conns[0]) })
  138. sessionMut.Lock()
  139. activeSessions = append(activeSessions, s)
  140. sessionMut.Unlock()
  141. wg.Wait()
  142. if debug {
  143. log.Println("Session", s, "ended, outcomes:", err0, "and", err1)
  144. }
  145. goto done
  146. case <-timedout:
  147. if debug {
  148. log.Println("Session", s, "timed out")
  149. }
  150. goto done
  151. }
  152. }
  153. done:
  154. // We can end up here in 3 cases:
  155. // 1. Timeout joining, in which case there are potentially entries in pendingSessions
  156. // 2. General session end/timeout, in which case there are entries in activeSessions
  157. // 3. Protocol handler calls dropSession as one of its clients disconnects.
  158. sessionMut.Lock()
  159. delete(pendingSessions, string(s.serverkey))
  160. delete(pendingSessions, string(s.clientkey))
  161. for i, session := range activeSessions {
  162. if session == s {
  163. l := len(activeSessions) - 1
  164. activeSessions[i] = activeSessions[l]
  165. activeSessions[l] = nil
  166. activeSessions = activeSessions[:l]
  167. }
  168. }
  169. sessionMut.Unlock()
  170. // If we are here because of case 2 or 3, we are potentially closing some or
  171. // all connections a second time.
  172. s.CloseConns()
  173. if debug {
  174. log.Println("Session", s, "stopping")
  175. }
  176. }
  177. func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
  178. return protocol.SessionInvitation{
  179. From: s.serverid[:],
  180. Key: s.clientkey,
  181. Address: sessionAddress,
  182. Port: sessionPort,
  183. ServerSocket: false,
  184. }
  185. }
  186. func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
  187. return protocol.SessionInvitation{
  188. From: s.clientid[:],
  189. Key: s.serverkey,
  190. Address: sessionAddress,
  191. Port: sessionPort,
  192. ServerSocket: true,
  193. }
  194. }
  195. func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
  196. return s.clientid == id || s.serverid == id
  197. }
  198. func (s *session) CloseConns() {
  199. s.mut.Lock()
  200. for _, conn := range s.conns {
  201. conn.Close()
  202. }
  203. s.mut.Unlock()
  204. }
  205. func (s *session) proxy(c1, c2 net.Conn) error {
  206. if debug {
  207. log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
  208. }
  209. numProxies.Add(1)
  210. defer numProxies.Add(-1)
  211. buf := make([]byte, networkBufferSize)
  212. for {
  213. c1.SetReadDeadline(time.Now().Add(networkTimeout))
  214. n, err := c1.Read(buf)
  215. if err != nil {
  216. return err
  217. }
  218. bytesProxied.Add(int64(n))
  219. if debug {
  220. log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
  221. }
  222. if s.rateLimit != nil {
  223. s.rateLimit(n)
  224. }
  225. c2.SetWriteDeadline(time.Now().Add(networkTimeout))
  226. _, err = c2.Write(buf[:n])
  227. if err != nil {
  228. return err
  229. }
  230. }
  231. }
  232. func (s *session) String() string {
  233. return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
  234. }
  235. func makeRateLimitFunc(sessionRateLimit, globalRateLimit *rate.Limiter) func(int) {
  236. // This may be a case of super duper premature optimization... We build an
  237. // optimized function to do the rate limiting here based on what we need
  238. // to do and then use it in the loop.
  239. if sessionRateLimit == nil && globalRateLimit == nil {
  240. // No limiting needed. We could equally well return a func(int64){} and
  241. // not do a nil check were we use it, but I think the nil check there
  242. // makes it clear that there will be no limiting if none is
  243. // configured...
  244. return nil
  245. }
  246. if sessionRateLimit == nil {
  247. // We only have a global limiter
  248. return func(bytes int) {
  249. take(bytes, globalRateLimit)
  250. }
  251. }
  252. if globalRateLimit == nil {
  253. // We only have a session limiter
  254. return func(bytes int) {
  255. take(bytes, sessionRateLimit)
  256. }
  257. }
  258. // We have both. Queue the bytes on both the global and session specific
  259. // rate limiters.
  260. return func(bytes int) {
  261. take(bytes, sessionRateLimit, globalRateLimit)
  262. }
  263. }
  264. // take is a utility function to consume tokens from a set of rate.Limiters.
  265. // Tokens are consumed in parallel on all limiters, respecting their
  266. // individual burst sizes.
  267. func take(tokens int, ls ...*rate.Limiter) {
  268. // minBurst is the smallest burst size supported by all limiters.
  269. minBurst := int(math.MaxInt32)
  270. for _, l := range ls {
  271. if burst := l.Burst(); burst < minBurst {
  272. minBurst = burst
  273. }
  274. }
  275. for tokens > 0 {
  276. // chunk is how many tokens we can consume at a time
  277. chunk := tokens
  278. if chunk > minBurst {
  279. chunk = minBurst
  280. }
  281. // maxDelay is the longest delay mandated by any of the limiters for
  282. // the chosen chunk size.
  283. var maxDelay time.Duration
  284. for _, l := range ls {
  285. res := l.ReserveN(time.Now(), chunk)
  286. if del := res.Delay(); del > maxDelay {
  287. maxDelay = del
  288. }
  289. }
  290. time.Sleep(maxDelay)
  291. tokens -= chunk
  292. }
  293. }