main.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // Copyright (C) 2014 Jakob Borg and Contributors (see the CONTRIBUTORS file).
  2. // All rights reserved. Use of this source code is governed by an MIT-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "encoding/binary"
  7. "encoding/hex"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "os"
  14. "sync"
  15. "time"
  16. "github.com/calmh/syncthing/discover"
  17. "github.com/calmh/syncthing/protocol"
  18. "github.com/golang/groupcache/lru"
  19. "github.com/juju/ratelimit"
  20. )
  // Registry record types: what an announcement stores and a query returns.
type (
	// node is the registry entry for one announcing device: every address it
	// advertised in its latest announcement and when that announcement arrived.
	node struct {
		addresses []address
		updated   time.Time
	}

	// address is a single advertised endpoint: raw IP bytes plus a port.
	address struct {
		ip   []byte
		port uint16
	}
)
  29. var (
  30. nodes = make(map[protocol.NodeID]node)
  31. lock sync.Mutex
  32. queries = 0
  33. announces = 0
  34. answered = 0
  35. limited = 0
  36. unknowns = 0
  37. debug = false
  38. lruSize = 1024
  39. limitAvg = 1
  40. limitBurst = 10
  41. limiter *lru.Cache
  42. )
  43. func main() {
  44. var listen string
  45. var timestamp bool
  46. var statsIntv int
  47. var statsFile string
  48. flag.StringVar(&listen, "listen", ":22026", "Listen address")
  49. flag.BoolVar(&debug, "debug", false, "Enable debug output")
  50. flag.BoolVar(&timestamp, "timestamp", true, "Timestamp the log output")
  51. flag.IntVar(&statsIntv, "stats-intv", 0, "Statistics output interval (s)")
  52. flag.StringVar(&statsFile, "stats-file", "/var/log/discosrv.stats", "Statistics file name")
  53. flag.IntVar(&lruSize, "limit-cache", lruSize, "Limiter cache entries")
  54. flag.IntVar(&limitAvg, "limit-avg", limitAvg, "Allowed average package rate, per 10 s")
  55. flag.IntVar(&limitBurst, "limit-burst", limitBurst, "Allowed burst size, packets")
  56. flag.Parse()
  57. limiter = lru.New(lruSize)
  58. log.SetOutput(os.Stdout)
  59. if !timestamp {
  60. log.SetFlags(0)
  61. }
  62. addr, _ := net.ResolveUDPAddr("udp", listen)
  63. conn, err := net.ListenUDP("udp", addr)
  64. if err != nil {
  65. log.Fatal(err)
  66. }
  67. if statsIntv > 0 {
  68. go logStats(statsFile, statsIntv)
  69. }
  70. var buf = make([]byte, 1024)
  71. for {
  72. buf = buf[:cap(buf)]
  73. n, addr, err := conn.ReadFromUDP(buf)
  74. if limit(addr) {
  75. // Rate limit in effect for source
  76. continue
  77. }
  78. if err != nil {
  79. log.Fatal(err)
  80. }
  81. if n < 4 {
  82. log.Printf("Received short packet (%d bytes)", n)
  83. continue
  84. }
  85. buf = buf[:n]
  86. magic := binary.BigEndian.Uint32(buf)
  87. switch magic {
  88. case discover.AnnouncementMagic:
  89. handleAnnounceV2(addr, buf)
  90. case discover.QueryMagic:
  91. handleQueryV2(conn, addr, buf)
  92. default:
  93. lock.Lock()
  94. unknowns++
  95. lock.Unlock()
  96. }
  97. }
  98. }
  99. func limit(addr *net.UDPAddr) bool {
  100. key := addr.IP.String()
  101. lock.Lock()
  102. defer lock.Unlock()
  103. bkt, ok := limiter.Get(key)
  104. if ok {
  105. bkt := bkt.(*ratelimit.Bucket)
  106. if bkt.TakeAvailable(1) != 1 {
  107. // Rate limit exceeded; ignore packet
  108. if debug {
  109. log.Println("Rate limit exceeded for", key)
  110. }
  111. limited++
  112. return true
  113. }
  114. } else {
  115. if debug {
  116. log.Println("New limiter for", key)
  117. }
  118. // One packet per ten seconds average rate, burst ten packets
  119. limiter.Add(key, ratelimit.NewBucket(10*time.Second/time.Duration(limitAvg), int64(limitBurst)))
  120. }
  121. return false
  122. }
  123. func handleAnnounceV2(addr *net.UDPAddr, buf []byte) {
  124. var pkt discover.Announce
  125. err := pkt.UnmarshalXDR(buf)
  126. if err != nil && err != io.EOF {
  127. log.Println("AnnounceV2 Unmarshal:", err)
  128. log.Println(hex.Dump(buf))
  129. return
  130. }
  131. if debug {
  132. log.Printf("<- %v %#v", addr, pkt)
  133. }
  134. lock.Lock()
  135. announces++
  136. lock.Unlock()
  137. ip := addr.IP.To4()
  138. if ip == nil {
  139. ip = addr.IP.To16()
  140. }
  141. var addrs []address
  142. for _, addr := range pkt.This.Addresses {
  143. tip := addr.IP
  144. if len(tip) == 0 {
  145. tip = ip
  146. }
  147. addrs = append(addrs, address{
  148. ip: tip,
  149. port: addr.Port,
  150. })
  151. }
  152. node := node{
  153. addresses: addrs,
  154. updated: time.Now(),
  155. }
  156. var id protocol.NodeID
  157. if len(pkt.This.ID) == 32 {
  158. // Raw node ID
  159. copy(id[:], pkt.This.ID)
  160. } else {
  161. id.UnmarshalText(pkt.This.ID)
  162. }
  163. lock.Lock()
  164. nodes[id] = node
  165. lock.Unlock()
  166. }
  167. func handleQueryV2(conn *net.UDPConn, addr *net.UDPAddr, buf []byte) {
  168. var pkt discover.Query
  169. err := pkt.UnmarshalXDR(buf)
  170. if err != nil {
  171. log.Println("QueryV2 Unmarshal:", err)
  172. log.Println(hex.Dump(buf))
  173. return
  174. }
  175. if debug {
  176. log.Printf("<- %v %#v", addr, pkt)
  177. }
  178. var id protocol.NodeID
  179. if len(pkt.NodeID) == 32 {
  180. // Raw node ID
  181. copy(id[:], pkt.NodeID)
  182. } else {
  183. id.UnmarshalText(pkt.NodeID)
  184. }
  185. lock.Lock()
  186. node, ok := nodes[id]
  187. queries++
  188. lock.Unlock()
  189. if ok && len(node.addresses) > 0 {
  190. ann := discover.Announce{
  191. Magic: discover.AnnouncementMagic,
  192. This: discover.Node{
  193. ID: pkt.NodeID,
  194. },
  195. }
  196. for _, addr := range node.addresses {
  197. ann.This.Addresses = append(ann.This.Addresses, discover.Address{IP: addr.ip, Port: addr.port})
  198. }
  199. if debug {
  200. log.Printf("-> %v %#v", addr, pkt)
  201. }
  202. tb := ann.MarshalXDR()
  203. _, _, err = conn.WriteMsgUDP(tb, nil, addr)
  204. if err != nil {
  205. log.Println("QueryV2 response write:", err)
  206. }
  207. lock.Lock()
  208. answered++
  209. lock.Unlock()
  210. }
  211. }
  // next blocks until the first multiple of intv seconds (as computed by
// time.Truncate) that lies after the current time, then returns that instant.
// Stats lines therefore carry round timestamps. Callers in this file only
// pass positive intervals.
func next(intv int) time.Time {
	now := time.Now()
	period := time.Second * time.Duration(intv)
	boundary := now.Add(period).Truncate(period)
	wait := boundary.Sub(now)
	time.Sleep(wait)
	return boundary
}
  219. func logStats(file string, intv int) {
  220. f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
  221. if err != nil {
  222. log.Fatal(err)
  223. }
  224. for {
  225. t := next(intv)
  226. lock.Lock()
  227. var deleted = 0
  228. for id, node := range nodes {
  229. if time.Since(node.updated) > 60*time.Minute {
  230. delete(nodes, id)
  231. deleted++
  232. }
  233. }
  234. fmt.Fprintf(f, "%d Nr:%d Ne:%d Qt:%d Qa:%d A:%d U:%d Lq:%d Lc:%d\n",
  235. t.Unix(), len(nodes), deleted, queries, answered, announces, unknowns, limited, limiter.Len())
  236. f.Sync()
  237. queries = 0
  238. announces = 0
  239. answered = 0
  240. limited = 0
  241. unknowns = 0
  242. lock.Unlock()
  243. }
  244. }