session.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. // Copyright (C) 2015 Audrius Butkevicius and Contributors.
  2. package main
  3. import (
  4. "crypto/rand"
  5. "encoding/hex"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "golang.org/x/time/rate"
  14. syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
  15. "github.com/syncthing/syncthing/lib/relay/protocol"
  16. )
  17. var (
  18. sessionMut = sync.RWMutex{}
  19. activeSessions = make([]*session, 0)
  20. pendingSessions = make(map[string]*session)
  21. numProxies atomic.Int64
  22. bytesProxied atomic.Int64
  23. )
  24. func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionLimitBps int, globalRateLimit *rate.Limiter) *session {
  25. serverkey := make([]byte, 32)
  26. _, err := rand.Read(serverkey)
  27. if err != nil {
  28. return nil
  29. }
  30. clientkey := make([]byte, 32)
  31. _, err = rand.Read(clientkey)
  32. if err != nil {
  33. return nil
  34. }
  35. var sessionRateLimit *rate.Limiter
  36. if sessionLimitBps > 0 {
  37. sessionRateLimit = rate.NewLimiter(rate.Limit(sessionLimitBps), 2*sessionLimitBps)
  38. }
  39. ses := &session{
  40. serverkey: serverkey,
  41. serverid: serverid,
  42. clientkey: clientkey,
  43. clientid: clientid,
  44. rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
  45. limiter: sessionRateLimit,
  46. connsChan: make(chan net.Conn),
  47. conns: make([]net.Conn, 0, 2),
  48. }
  49. if debug {
  50. log.Println("New session", ses)
  51. }
  52. sessionMut.Lock()
  53. pendingSessions[string(ses.serverkey)] = ses
  54. pendingSessions[string(ses.clientkey)] = ses
  55. sessionMut.Unlock()
  56. return ses
  57. }
  58. func findSession(key string) *session {
  59. sessionMut.Lock()
  60. defer sessionMut.Unlock()
  61. ses, ok := pendingSessions[key]
  62. if !ok {
  63. return nil
  64. }
  65. delete(pendingSessions, key)
  66. return ses
  67. }
  68. func dropSessions(id syncthingprotocol.DeviceID) {
  69. sessionMut.RLock()
  70. for _, session := range activeSessions {
  71. if session.HasParticipant(id) {
  72. if debug {
  73. log.Println("Dropping session", session, "involving", id)
  74. }
  75. session.CloseConns()
  76. }
  77. }
  78. sessionMut.RUnlock()
  79. }
  80. func hasSessions(id syncthingprotocol.DeviceID) bool {
  81. sessionMut.RLock()
  82. has := false
  83. for _, session := range activeSessions {
  84. if session.HasParticipant(id) {
  85. has = true
  86. break
  87. }
  88. }
  89. sessionMut.RUnlock()
  90. return has
  91. }
  // session is one relay rendezvous between a server-side and a client-side
// device. Each side is handed its own random key; the session joins the
// first two connections delivered to it and proxies bytes between them.
type session struct {
	// mut guards conns (Serve appends under it; CloseConns iterates under it).
	mut sync.Mutex

	// serverkey is the key sent to the server-side peer in its invitation.
	serverkey []byte
	serverid  syncthingprotocol.DeviceID
	// clientkey is the key sent to the client-side peer in its invitation.
	clientkey []byte
	clientid  syncthingprotocol.DeviceID

	// rateLimit is called with each chunk's byte count before the chunk is
	// forwarded; nil means no limiting is configured.
	rateLimit func(bytes int)
	// limiter is the per-session limiter, or nil if none was configured.
	limiter *rate.Limiter

	// connsChan hands joining connections from AddConnection to Serve;
	// Serve closes it once two connections have joined.
	connsChan chan net.Conn
	// conns holds the connections that have joined (at most two).
	conns []net.Conn
}
  103. func (s *session) AddConnection(conn net.Conn) bool {
  104. if debug {
  105. log.Println("New connection for", s, "from", conn.RemoteAddr())
  106. }
  107. select {
  108. case s.connsChan <- conn:
  109. return true
  110. default:
  111. }
  112. return false
  113. }
  114. func (s *session) Serve() {
  115. timedout := time.After(messageTimeout)
  116. if debug {
  117. log.Println("Session", s, "serving")
  118. }
  119. for {
  120. select {
  121. case conn := <-s.connsChan:
  122. s.mut.Lock()
  123. s.conns = append(s.conns, conn)
  124. s.mut.Unlock()
  125. // We're the only ones mutating s.conns, hence we are free to read it.
  126. if len(s.conns) < 2 {
  127. continue
  128. }
  129. close(s.connsChan)
  130. if debug {
  131. log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
  132. }
  133. wg := sync.WaitGroup{}
  134. wg.Add(2)
  135. var err0 error
  136. go func() {
  137. err0 = s.proxy(s.conns[0], s.conns[1])
  138. wg.Done()
  139. }()
  140. var err1 error
  141. go func() {
  142. err1 = s.proxy(s.conns[1], s.conns[0])
  143. wg.Done()
  144. }()
  145. sessionMut.Lock()
  146. activeSessions = append(activeSessions, s)
  147. sessionMut.Unlock()
  148. wg.Wait()
  149. if debug {
  150. log.Println("Session", s, "ended, outcomes:", err0, "and", err1)
  151. }
  152. goto done
  153. case <-timedout:
  154. if debug {
  155. log.Println("Session", s, "timed out")
  156. }
  157. goto done
  158. }
  159. }
  160. done:
  161. // We can end up here in 3 cases:
  162. // 1. Timeout joining, in which case there are potentially entries in pendingSessions
  163. // 2. General session end/timeout, in which case there are entries in activeSessions
  164. // 3. Protocol handler calls dropSession as one of its clients disconnects.
  165. sessionMut.Lock()
  166. delete(pendingSessions, string(s.serverkey))
  167. delete(pendingSessions, string(s.clientkey))
  168. for i, session := range activeSessions {
  169. if session == s {
  170. l := len(activeSessions) - 1
  171. activeSessions[i] = activeSessions[l]
  172. activeSessions[l] = nil
  173. activeSessions = activeSessions[:l]
  174. }
  175. }
  176. sessionMut.Unlock()
  177. // If we are here because of case 2 or 3, we are potentially closing some or
  178. // all connections a second time.
  179. s.CloseConns()
  180. if debug {
  181. log.Println("Session", s, "stopping")
  182. }
  183. }
  184. func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
  185. return protocol.SessionInvitation{
  186. From: s.serverid[:],
  187. Key: s.clientkey,
  188. Address: sessionAddress,
  189. Port: sessionPort,
  190. ServerSocket: false,
  191. }
  192. }
  193. func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
  194. return protocol.SessionInvitation{
  195. From: s.clientid[:],
  196. Key: s.serverkey,
  197. Address: sessionAddress,
  198. Port: sessionPort,
  199. ServerSocket: true,
  200. }
  201. }
  202. func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
  203. return s.clientid == id || s.serverid == id
  204. }
  205. func (s *session) CloseConns() {
  206. s.mut.Lock()
  207. for _, conn := range s.conns {
  208. conn.Close()
  209. }
  210. s.mut.Unlock()
  211. }
  212. func (s *session) proxy(c1, c2 net.Conn) error {
  213. if debug {
  214. log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
  215. }
  216. numProxies.Add(1)
  217. defer numProxies.Add(-1)
  218. buf := make([]byte, networkBufferSize)
  219. for {
  220. c1.SetReadDeadline(time.Now().Add(networkTimeout))
  221. n, err := c1.Read(buf)
  222. if err != nil {
  223. return err
  224. }
  225. bytesProxied.Add(int64(n))
  226. if debug {
  227. log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
  228. }
  229. if s.rateLimit != nil {
  230. s.rateLimit(n)
  231. }
  232. c2.SetWriteDeadline(time.Now().Add(networkTimeout))
  233. _, err = c2.Write(buf[:n])
  234. if err != nil {
  235. return err
  236. }
  237. }
  238. }
  239. func (s *session) String() string {
  240. return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
  241. }
  242. func makeRateLimitFunc(sessionRateLimit, globalRateLimit *rate.Limiter) func(int) {
  243. // This may be a case of super duper premature optimization... We build an
  244. // optimized function to do the rate limiting here based on what we need
  245. // to do and then use it in the loop.
  246. if sessionRateLimit == nil && globalRateLimit == nil {
  247. // No limiting needed. We could equally well return a func(int64){} and
  248. // not do a nil check were we use it, but I think the nil check there
  249. // makes it clear that there will be no limiting if none is
  250. // configured...
  251. return nil
  252. }
  253. if sessionRateLimit == nil {
  254. // We only have a global limiter
  255. return func(bytes int) {
  256. take(bytes, globalRateLimit)
  257. }
  258. }
  259. if globalRateLimit == nil {
  260. // We only have a session limiter
  261. return func(bytes int) {
  262. take(bytes, sessionRateLimit)
  263. }
  264. }
  265. // We have both. Queue the bytes on both the global and session specific
  266. // rate limiters.
  267. return func(bytes int) {
  268. take(bytes, sessionRateLimit, globalRateLimit)
  269. }
  270. }
  271. // take is a utility function to consume tokens from a set of rate.Limiters.
  272. // Tokens are consumed in parallel on all limiters, respecting their
  273. // individual burst sizes.
  274. func take(tokens int, ls ...*rate.Limiter) {
  275. // minBurst is the smallest burst size supported by all limiters.
  276. minBurst := int(math.MaxInt32)
  277. for _, l := range ls {
  278. if burst := l.Burst(); burst < minBurst {
  279. minBurst = burst
  280. }
  281. }
  282. for tokens > 0 {
  283. // chunk is how many tokens we can consume at a time
  284. chunk := tokens
  285. if chunk > minBurst {
  286. chunk = minBurst
  287. }
  288. // maxDelay is the longest delay mandated by any of the limiters for
  289. // the chosen chunk size.
  290. var maxDelay time.Duration
  291. for _, l := range ls {
  292. res := l.ReserveN(time.Now(), chunk)
  293. if del := res.Delay(); del > maxDelay {
  294. maxDelay = del
  295. }
  296. }
  297. time.Sleep(maxDelay)
  298. tokens -= chunk
  299. }
  300. }