|
@@ -3,14 +3,12 @@
|
|
|
package main
|
|
|
|
|
|
import (
|
|
|
- "compress/gzip"
|
|
|
"context"
|
|
|
"crypto/tls"
|
|
|
"crypto/x509"
|
|
|
"encoding/json"
|
|
|
"flag"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
"log"
|
|
|
"net"
|
|
|
"net/http"
|
|
@@ -19,11 +17,13 @@ import (
|
|
|
"path/filepath"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
+ lru "github.com/hashicorp/golang-lru/v2"
|
|
|
+ "github.com/syncthing/syncthing/lib/httpcache"
|
|
|
"github.com/syncthing/syncthing/lib/protocol"
|
|
|
|
|
|
- "github.com/golang/groupcache/lru"
|
|
|
"github.com/oschwald/geoip2-golang"
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
@@ -33,7 +33,6 @@ import (
|
|
|
"github.com/syncthing/syncthing/lib/relay/client"
|
|
|
"github.com/syncthing/syncthing/lib/sync"
|
|
|
"github.com/syncthing/syncthing/lib/tlsutil"
|
|
|
- "golang.org/x/time/rate"
|
|
|
)
|
|
|
|
|
|
type location struct {
|
|
@@ -99,27 +98,13 @@ var (
|
|
|
dir string
|
|
|
evictionTime = time.Hour
|
|
|
debug bool
|
|
|
- getLRUSize = 10 << 10
|
|
|
- getLimitBurst = 10
|
|
|
- getLimitAvg = 2
|
|
|
- postLRUSize = 1 << 10
|
|
|
- postLimitBurst = 2
|
|
|
- postLimitAvg = 2
|
|
|
- getLimit time.Duration
|
|
|
- postLimit time.Duration
|
|
|
permRelaysFile string
|
|
|
ipHeader string
|
|
|
geoipPath string
|
|
|
proto string
|
|
|
- statsRefresh = time.Minute / 2
|
|
|
- requestQueueLen = 10
|
|
|
- requestProcessors = 1
|
|
|
-
|
|
|
- getMut = sync.NewMutex()
|
|
|
- getLRUCache *lru.Cache
|
|
|
-
|
|
|
- postMut = sync.NewMutex()
|
|
|
- postLRUCache *lru.Cache
|
|
|
+ statsRefresh = time.Minute
|
|
|
+ requestQueueLen = 64
|
|
|
+ requestProcessors = 8
|
|
|
|
|
|
requests chan request
|
|
|
|
|
@@ -127,6 +112,7 @@ var (
|
|
|
knownRelays = make([]*relay, 0)
|
|
|
permanentRelays = make([]*relay, 0)
|
|
|
evictionTimers = make(map[string]*time.Timer)
|
|
|
+ globalBlocklist = newErrorTracker(1000)
|
|
|
)
|
|
|
|
|
|
const (
|
|
@@ -141,13 +127,8 @@ func main() {
|
|
|
flag.StringVar(&dir, "keys", dir, "Directory where http-cert.pem and http-key.pem is stored for TLS listening")
|
|
|
flag.BoolVar(&debug, "debug", debug, "Enable debug output")
|
|
|
flag.DurationVar(&evictionTime, "eviction", evictionTime, "After how long the relay is evicted")
|
|
|
- flag.IntVar(&getLRUSize, "get-limit-cache", getLRUSize, "Get request limiter cache size")
|
|
|
- flag.IntVar(&getLimitAvg, "get-limit-avg", getLimitAvg, "Allowed average get request rate, per 10 s")
|
|
|
- flag.IntVar(&getLimitBurst, "get-limit-burst", getLimitBurst, "Allowed burst get requests")
|
|
|
- flag.IntVar(&postLRUSize, "post-limit-cache", postLRUSize, "Post request limiter cache size")
|
|
|
- flag.IntVar(&postLimitAvg, "post-limit-avg", postLimitAvg, "Allowed average post request rate, per minute")
|
|
|
- flag.IntVar(&postLimitBurst, "post-limit-burst", postLimitBurst, "Allowed burst post requests")
|
|
|
flag.StringVar(&permRelaysFile, "perm-relays", "", "Path to list of permanent relays")
|
|
|
+ flag.StringVar(&knownRelaysFile, "known-relays", knownRelaysFile, "Path to list of current relays")
|
|
|
flag.StringVar(&ipHeader, "ip-header", "", "Name of header which holds clients ip:port. Only meaningful when running behind a reverse proxy.")
|
|
|
flag.StringVar(&geoipPath, "geoip", "GeoLite2-City.mmdb", "Path to GeoLite2-City database")
|
|
|
flag.StringVar(&proto, "protocol", "tcp", "Protocol used for listening. 'tcp' for IPv4 and IPv6, 'tcp4' for IPv4, 'tcp6' for IPv6")
|
|
@@ -159,12 +140,6 @@ func main() {
|
|
|
|
|
|
requests = make(chan request, requestQueueLen)
|
|
|
|
|
|
- getLimit = 10 * time.Second / time.Duration(getLimitAvg)
|
|
|
- postLimit = time.Minute / time.Duration(postLimitAvg)
|
|
|
-
|
|
|
- getLRUCache = lru.New(getLRUSize)
|
|
|
- postLRUCache = lru.New(postLRUSize)
|
|
|
-
|
|
|
var listener net.Listener
|
|
|
var err error
|
|
|
|
|
@@ -240,7 +215,7 @@ func main() {
|
|
|
|
|
|
handler := http.NewServeMux()
|
|
|
handler.HandleFunc("/", handleAssets)
|
|
|
- handler.HandleFunc("/endpoint", handleRequest)
|
|
|
+ handler.Handle("/endpoint", httpcache.SinglePath(http.HandlerFunc(handleRequest), 15*time.Second))
|
|
|
handler.HandleFunc("/metrics", handleMetrics)
|
|
|
|
|
|
srv := http.Server{
|
|
@@ -291,21 +266,17 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
}()
|
|
|
|
|
|
if ipHeader != "" {
|
|
|
- r.RemoteAddr = r.Header.Get(ipHeader)
|
|
|
+ hdr := r.Header.Get(ipHeader)
|
|
|
+ fields := strings.Split(hdr, ",")
|
|
|
+ if len(fields) > 0 {
|
|
|
+ r.RemoteAddr = strings.TrimSpace(fields[len(fields)-1])
|
|
|
+ }
|
|
|
}
|
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
switch r.Method {
|
|
|
case "GET":
|
|
|
- if limit(r.RemoteAddr, getLRUCache, getMut, getLimit, getLimitBurst) {
|
|
|
- w.WriteHeader(httpStatusEnhanceYourCalm)
|
|
|
- return
|
|
|
- }
|
|
|
handleGetRequest(w, r)
|
|
|
case "POST":
|
|
|
- if limit(r.RemoteAddr, postLRUCache, postMut, postLimit, postLimitBurst) {
|
|
|
- w.WriteHeader(httpStatusEnhanceYourCalm)
|
|
|
- return
|
|
|
- }
|
|
|
handlePostRequest(w, r)
|
|
|
default:
|
|
|
if debug {
|
|
@@ -327,20 +298,28 @@ func handleGetRequest(rw http.ResponseWriter, r *http.Request) {
|
|
|
// Shuffle
|
|
|
rand.Shuffle(relays)
|
|
|
|
|
|
- w := io.Writer(rw)
|
|
|
- if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
|
|
- rw.Header().Set("Content-Encoding", "gzip")
|
|
|
- gw := gzip.NewWriter(rw)
|
|
|
- defer gw.Close()
|
|
|
- w = gw
|
|
|
- }
|
|
|
-
|
|
|
- _ = json.NewEncoder(w).Encode(map[string][]*relay{
|
|
|
+ _ = json.NewEncoder(rw).Encode(map[string][]*relay{
|
|
|
"relays": relays,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
func handlePostRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Get the IP address of the client
|
|
|
+ rhost := r.RemoteAddr
|
|
|
+ if host, _, err := net.SplitHostPort(rhost); err == nil {
|
|
|
+ rhost = host
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check the black list. A client is blacklisted if their last 10
|
|
|
+ // attempts to join have all failed. The "Unauthorized" status return
|
|
|
+ // causes strelaysrv to cease attempting to join.
|
|
|
+ if globalBlocklist.IsBlocked(rhost) {
|
|
|
+ log.Println("Rejected blocked client", rhost)
|
|
|
+ http.Error(w, "Too many errors", http.StatusUnauthorized)
|
|
|
+ globalBlocklist.ClearErrors(rhost)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
var relayCert *x509.Certificate
|
|
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
|
|
relayCert = r.TLS.PeerCertificates[0]
|
|
@@ -392,12 +371,6 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // Get the IP address of the client
|
|
|
- rhost := r.RemoteAddr
|
|
|
- if host, _, err := net.SplitHostPort(rhost); err == nil {
|
|
|
- rhost = host
|
|
|
- }
|
|
|
-
|
|
|
ip := net.ParseIP(host)
|
|
|
// The client did not provide an IP address, use the IP address of the client.
|
|
|
if ip == nil || ip.IsUnspecified() {
|
|
@@ -429,10 +402,14 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
case requests <- request{&newRelay, reschan, prometheus.NewTimer(relayTestActionsSeconds.WithLabelValues("queue"))}:
|
|
|
result := <-reschan
|
|
|
if result.err != nil {
|
|
|
+ log.Println("Join from", r.RemoteAddr, "failed:", result.err)
|
|
|
+ globalBlocklist.AddError(rhost)
|
|
|
relayTestsTotal.WithLabelValues("failed").Inc()
|
|
|
http.Error(w, result.err.Error(), http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
+ log.Println("Join from", r.RemoteAddr, "succeeded")
|
|
|
+ globalBlocklist.ClearErrors(rhost)
|
|
|
relayTestsTotal.WithLabelValues("success").Inc()
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
|
json.NewEncoder(w).Encode(map[string]time.Duration{
|
|
@@ -546,23 +523,6 @@ func evict(relay *relay) func() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func limit(addr string, cache *lru.Cache, lock sync.Mutex, intv time.Duration, burst int) bool {
|
|
|
- if host, _, err := net.SplitHostPort(addr); err == nil {
|
|
|
- addr = host
|
|
|
- }
|
|
|
-
|
|
|
- lock.Lock()
|
|
|
- v, _ := cache.Get(addr)
|
|
|
- bkt, ok := v.(*rate.Limiter)
|
|
|
- if !ok {
|
|
|
- bkt = rate.NewLimiter(rate.Every(intv), burst)
|
|
|
- cache.Add(addr, bkt)
|
|
|
- }
|
|
|
- lock.Unlock()
|
|
|
-
|
|
|
- return !bkt.Allow()
|
|
|
-}
|
|
|
-
|
|
|
func loadRelays(file string) []*relay {
|
|
|
content, err := os.ReadFile(file)
|
|
|
if err != nil {
|
|
@@ -602,7 +562,7 @@ func saveRelays(file string, relays []*relay) error {
|
|
|
for _, relay := range relays {
|
|
|
content += relay.uri.String() + "\n"
|
|
|
}
|
|
|
- return os.WriteFile(file, []byte(content), 0777)
|
|
|
+ return os.WriteFile(file, []byte(content), 0o777)
|
|
|
}
|
|
|
|
|
|
func createTestCertificate() tls.Certificate {
|
|
@@ -661,3 +621,42 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
|
|
|
lrw.statusCode = code
|
|
|
lrw.ResponseWriter.WriteHeader(code)
|
|
|
}
|
|
|
+
|
|
|
+type errorTracker struct {
|
|
|
+ errors *lru.TwoQueueCache[string, *errorCounter]
|
|
|
+}
|
|
|
+
|
|
|
+type errorCounter struct {
|
|
|
+ count atomic.Int32
|
|
|
+}
|
|
|
+
|
|
|
+func newErrorTracker(size int) *errorTracker {
|
|
|
+ cache, err := lru.New2Q[string, *errorCounter](size)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+ return &errorTracker{
|
|
|
+ errors: cache,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (b *errorTracker) AddError(host string) {
|
|
|
+ entry, ok := b.errors.Get(host)
|
|
|
+ if !ok {
|
|
|
+ entry = &errorCounter{}
|
|
|
+ b.errors.Add(host, entry)
|
|
|
+ }
|
|
|
+ c := entry.count.Add(1)
|
|
|
+ log.Printf("Error count for %s is now %d", host, c)
|
|
|
+}
|
|
|
+
|
|
|
+func (b *errorTracker) ClearErrors(host string) {
|
|
|
+ b.errors.Remove(host)
|
|
|
+}
|
|
|
+
|
|
|
+func (b *errorTracker) IsBlocked(host string) bool {
|
|
|
+ if be, ok := b.errors.Get(host); ok {
|
|
|
+ return be.count.Load() > 10
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|