1
0
Эх сурвалжийг харах

cmd/discosrv: Merge discosrv repo

* discosrv/master: (64 commits)
  Use atomics for statistics handling (fixes #45)
  Lower case JSON fields are nicer
  Change v13 to v2
  Remove explicit relay handling
  Update vendored github.com/cznic/ql (fixes #34)
  Defer fd.Close() (fixes #37)
  There is no "get dependencies" step
  Add vendor/golang.org/x/net/context
  Use Go 1.5 vendoring instead of Godeps
  Add debug performance logging per request
  Must close result sets
  Set Retry-After header
  Ignores
  lru.Cache is not concurrency safe
  We need a limit on the number of PostgreSQL connections
  Correct example DSN (fixes #29)
  Allow plain HTTP serving behind a proxy
  Fix Query/Answer stats
  Reduce our patience with slow clients somewhat
  Discovery server should print device ID of certificate at startup
  ...
Jakob Borg 9 жил өмнө
parent
commit
7035ea3ab7

+ 19 - 0
cmd/discosrv/LICENSE

@@ -0,0 +1,19 @@
+Copyright (C) 2014-2015 The Discosrv Authors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+- The above copyright notice and this permission notice shall be included in
+  all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 40 - 0
cmd/discosrv/README.md

@@ -0,0 +1,40 @@
+discosrv
+========
+
+[![Latest Build](http://img.shields.io/jenkins/s/http/build.syncthing.net/discosrv.svg?style=flat-square)](http://build.syncthing.net/job/discosrv/lastBuild/)
+
+This is the global discovery server for the `syncthing` project.
+
+To get it, run `go get github.com/syncthing/discosrv` or download the
+[latest build](http://build.syncthing.net/job/discosrv/lastSuccessfulBuild/artifact/)
+from the build server.
+
+Usage
+-----
+
+The discovery server supports `ql` and `postgres` backends.
+Specify the backend via `-db-backend` and the database DSN via `-db-dsn`.
+
+By default it will use in-memory `ql` backend. If you wish to persist the
+information on disk between restarts in `ql`, specify a file DSN:
+
+```bash
+$ discosrv -db-dsn="file:///var/run/discosrv.db"
+```
+
+For `postgres`, you will need to create a database and a user with permissions
+to create tables in it, then start the discosrv as follows:
+
+```bash
+$ export DISCOSRV_DB_DSN="postgres://user:password@localhost/databasename"
+$ discosrv -db-backend="postgres"
+```
+
+You can pass the DSN as command line option, but the value what you pass in will
+be visible in most process managers, potentially exposing the database password
+to other users.
+
+In all cases, the appropriate tables and indexes will be created at first
+startup. If it doesn't exit with an error, you're fine.
+
+See `discosrv -help` for other options.

+ 75 - 0
cmd/discosrv/clean.go

@@ -0,0 +1,75 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"database/sql"
+	"log"
+	"time"
+)
+
+type cleansrv struct {
+	intv time.Duration
+	db   *sql.DB
+	prep map[string]*sql.Stmt
+}
+
+func (s *cleansrv) Serve() {
+	for {
+		time.Sleep(next(s.intv))
+
+		err := s.cleanOldEntries()
+		if err != nil {
+			log.Println("Clean:", err)
+		}
+	}
+}
+
+func (s *cleansrv) Stop() {
+	panic("stop unimplemented")
+}
+
+func (s *cleansrv) cleanOldEntries() (err error) {
+	var tx *sql.Tx
+	tx, err = s.db.Begin()
+	if err != nil {
+		return err
+	}
+
+	defer func() {
+		if err == nil {
+			err = tx.Commit()
+		} else {
+			tx.Rollback()
+		}
+	}()
+
+	res, err := tx.Stmt(s.prep["cleanAddress"]).Exec()
+	if err != nil {
+		return err
+	}
+	if rows, _ := res.RowsAffected(); rows > 0 {
+		log.Printf("Clean: %d old addresses", rows)
+	}
+
+	res, err = tx.Stmt(s.prep["cleanDevice"]).Exec()
+	if err != nil {
+		return err
+	}
+	if rows, _ := res.RowsAffected(); rows > 0 {
+		log.Printf("Clean: %d old devices", rows)
+	}
+
+	var devs, addrs int
+	row := tx.Stmt(s.prep["countDevice"]).QueryRow()
+	if err = row.Scan(&devs); err != nil {
+		return err
+	}
+	row = tx.Stmt(s.prep["countAddress"]).QueryRow()
+	if err = row.Scan(&addrs); err != nil {
+		return err
+	}
+
+	log.Printf("Database: %d devices, %d addresses", devs, addrs)
+	return nil
+}

+ 32 - 0
cmd/discosrv/db.go

@@ -0,0 +1,32 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"database/sql"
+	"fmt"
+)
+
+type setupFunc func(db *sql.DB) error
+type compileFunc func(db *sql.DB) (map[string]*sql.Stmt, error)
+
+var (
+	setupFuncs   = make(map[string]setupFunc)
+	compileFuncs = make(map[string]compileFunc)
+)
+
+func register(name string, setup setupFunc, compile compileFunc) {
+	setupFuncs[name] = setup
+	compileFuncs[name] = compile
+}
+
+func setup(backend string, db *sql.DB) (map[string]*sql.Stmt, error) {
+	setup, ok := setupFuncs[backend]
+	if !ok {
+		return nil, fmt.Errorf("Unsupported backend")
+	}
+	if err := setup(db); err != nil {
+		return nil, err
+	}
+	return compileFuncs[backend](db)
+}

+ 118 - 0
cmd/discosrv/main.go

@@ -0,0 +1,118 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"crypto/tls"
+	"database/sql"
+	"flag"
+	"log"
+	"os"
+	"time"
+
+	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/thejerf/suture"
+)
+
+const (
+	minNegCache  = 60        // seconds
+	maxNegCache  = 3600      // seconds
+	maxDeviceAge = 7 * 86400 // one week, in seconds
+)
+
+var (
+	lruSize     = 10240
+	limitAvg    = 5
+	limitBurst  = 20
+	globalStats stats
+	statsFile   string
+	backend     = "ql"
+	dsn         = getEnvDefault("DISCOSRV_DB_DSN", "memory://discosrv")
+	certFile    = "cert.pem"
+	keyFile     = "key.pem"
+	debug       = false
+	useHttp     = false
+)
+
+func main() {
+	const (
+		cleanIntv = 1 * time.Hour
+		statsIntv = 5 * time.Minute
+	)
+
+	var listen string
+
+	log.SetOutput(os.Stdout)
+	log.SetFlags(0)
+
+	flag.StringVar(&listen, "listen", ":8443", "Listen address")
+	flag.IntVar(&lruSize, "limit-cache", lruSize, "Limiter cache entries")
+	flag.IntVar(&limitAvg, "limit-avg", limitAvg, "Allowed average package rate, per 10 s")
+	flag.IntVar(&limitBurst, "limit-burst", limitBurst, "Allowed burst size, packets")
+	flag.StringVar(&statsFile, "stats-file", statsFile, "File to write periodic operation stats to")
+	flag.StringVar(&backend, "db-backend", backend, "Database backend to use")
+	flag.StringVar(&dsn, "db-dsn", dsn, "Database DSN")
+	flag.StringVar(&certFile, "cert", certFile, "Certificate file")
+	flag.StringVar(&keyFile, "key", keyFile, "Key file")
+	flag.BoolVar(&debug, "debug", debug, "Debug")
+	flag.BoolVar(&useHttp, "http", useHttp, "Listen on HTTP (behind an HTTPS proxy)")
+	flag.Parse()
+
+	var cert tls.Certificate
+	var err error
+	if !useHttp {
+		cert, err = tls.LoadX509KeyPair(certFile, keyFile)
+		if err != nil {
+			log.Fatalln("Failed to load X509 key pair:", err)
+		}
+
+		devID := protocol.NewDeviceID(cert.Certificate[0])
+		log.Println("Server device ID is", devID)
+	}
+
+	db, err := sql.Open(backend, dsn)
+	if err != nil {
+		log.Fatalln("sql.Open:", err)
+	}
+	prep, err := setup(backend, db)
+	if err != nil {
+		log.Fatalln("Setup:", err)
+	}
+
+	main := suture.NewSimple("main")
+
+	main.Add(&querysrv{
+		addr: listen,
+		cert: cert,
+		db:   db,
+		prep: prep,
+	})
+
+	main.Add(&cleansrv{
+		intv: cleanIntv,
+		db:   db,
+		prep: prep,
+	})
+
+	main.Add(&statssrv{
+		intv: statsIntv,
+		file: statsFile,
+		db:   db,
+	})
+
+	globalStats.Reset()
+	main.Serve()
+}
+
+func getEnvDefault(key, def string) string {
+	if val := os.Getenv(key); val != "" {
+		return val
+	}
+	return def
+}
+
+func next(intv time.Duration) time.Duration {
+	t0 := time.Now()
+	t1 := t0.Add(intv).Truncate(intv)
+	return t1.Sub(t0)
+}

+ 97 - 0
cmd/discosrv/psql.go

@@ -0,0 +1,97 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"database/sql"
+	"fmt"
+
+	_ "github.com/lib/pq"
+)
+
+func init() {
+	register("postgres", postgresSetup, postgresCompile)
+}
+
+func postgresSetup(db *sql.DB) error {
+	var err error
+
+	db.SetMaxIdleConns(4)
+	db.SetMaxOpenConns(8)
+
+	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS Devices (
+		DeviceID CHAR(63) NOT NULL PRIMARY KEY,
+		Seen TIMESTAMP NOT NULL
+	)`)
+	if err != nil {
+		return err
+	}
+
+	row := db.QueryRow(`SELECT 'DevicesDeviceIDIndex'::regclass`)
+	if err := row.Scan(nil); err != nil {
+		_, err = db.Exec(`CREATE INDEX DevicesDeviceIDIndex ON Devices (DeviceID)`)
+	}
+	if err != nil {
+		return err
+	}
+
+	row = db.QueryRow(`SELECT 'DevicesSeenIndex'::regclass`)
+	if err := row.Scan(nil); err != nil {
+		_, err = db.Exec(`CREATE INDEX DevicesSeenIndex ON Devices (Seen)`)
+	}
+	if err != nil {
+		return err
+	}
+
+	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS Addresses (
+		DeviceID CHAR(63) NOT NULL,
+		Seen TIMESTAMP NOT NULL,
+		Address VARCHAR(256) NOT NULL
+	)`)
+	if err != nil {
+		return err
+	}
+
+	row = db.QueryRow(`SELECT 'AddressesDeviceIDSeenIndex'::regclass`)
+	if err := row.Scan(nil); err != nil {
+		_, err = db.Exec(`CREATE INDEX AddressesDeviceIDSeenIndex ON Addresses (DeviceID, Seen)`)
+	}
+	if err != nil {
+		return err
+	}
+
+	row = db.QueryRow(`SELECT 'AddressesDeviceIDAddressIndex'::regclass`)
+	if err := row.Scan(nil); err != nil {
+		_, err = db.Exec(`CREATE INDEX AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`)
+	}
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func postgresCompile(db *sql.DB) (map[string]*sql.Stmt, error) {
+	stmts := map[string]string{
+		"cleanAddress":  "DELETE FROM Addresses WHERE Seen < now() - '2 hour'::INTERVAL",
+		"cleanDevice":   fmt.Sprintf("DELETE FROM Devices WHERE Seen < now() - '%d hour'::INTERVAL", maxDeviceAge/3600),
+		"countAddress":  "SELECT count(*) FROM Addresses",
+		"countDevice":   "SELECT count(*) FROM Devices",
+		"insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)",
+		"insertDevice":  "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())",
+		"selectAddress": "SELECT Address FROM Addresses WHERE DeviceID=$1 AND Seen > now() - '1 hour'::INTERVAL ORDER BY random() LIMIT 16",
+		"selectDevice":  "SELECT Seen FROM Devices WHERE DeviceID=$1",
+		"updateAddress": "UPDATE Addresses SET Seen=now() WHERE DeviceID=$1 AND Address=$2",
+		"updateDevice":  "UPDATE Devices SET Seen=now() WHERE DeviceID=$1",
+	}
+
+	res := make(map[string]*sql.Stmt, len(stmts))
+	for key, stmt := range stmts {
+		prep, err := db.Prepare(stmt)
+		if err != nil {
+			return nil, err
+		}
+		res[key] = prep
+	}
+	return res, nil
+}

+ 81 - 0
cmd/discosrv/ql.go

@@ -0,0 +1,81 @@
+// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"database/sql"
+	"fmt"
+	"log"
+
+	"github.com/cznic/ql"
+)
+
+func init() {
+	ql.RegisterDriver()
+	register("ql", qlSetup, qlCompile)
+}
+
+func qlSetup(db *sql.DB) (err error) {
+	tx, err := db.Begin()
+	if err != nil {
+		return
+	}
+
+	defer func() {
+		if err == nil {
+			err = tx.Commit()
+		} else {
+			tx.Rollback()
+		}
+	}()
+
+	_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Devices (
+		DeviceID STRING NOT NULL,
+		Seen TIME NOT NULL
+	)`)
+	if err != nil {
+		return
+	}
+
+	if _, err = tx.Exec(`CREATE INDEX IF NOT EXISTS DevicesDeviceIDIndex ON Devices (DeviceID)`); err != nil {
+		return
+	}
+
+	_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Addresses (
+		DeviceID STRING NOT NULL,
+		Seen TIME NOT NULL,
+		Address STRING NOT NULL,
+	)`)
+	if err != nil {
+		return
+	}
+
+	_, err = tx.Exec(`CREATE INDEX IF NOT EXISTS AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`)
+	return
+}
+
+func qlCompile(db *sql.DB) (map[string]*sql.Stmt, error) {
+	stmts := map[string]string{
+		"cleanAddress":  `DELETE FROM Addresses WHERE Seen < now() - duration("2h")`,
+		"cleanDevice":   fmt.Sprintf(`DELETE FROM Devices WHERE Seen < now() - duration("%dh")`, maxDeviceAge/3600),
+		"countAddress":  "SELECT count(*) FROM Addresses",
+		"countDevice":   "SELECT count(*) FROM Devices",
+		"insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)",
+		"insertDevice":  "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())",
+		"selectAddress": `SELECT Address from Addresses WHERE DeviceID==$1 AND Seen > now() - duration("1h") LIMIT 16`,
+		"selectDevice":  "SELECT Seen FROM Devices WHERE DeviceID==$1",
+		"updateAddress": "UPDATE Addresses Seen=now() WHERE DeviceID==$1 AND Address==$2",
+		"updateDevice":  "UPDATE Devices Seen=now() WHERE DeviceID==$1",
+	}
+
+	res := make(map[string]*sql.Stmt, len(stmts))
+	for key, stmt := range stmts {
+		prep, err := db.Prepare(stmt)
+		if err != nil {
+			log.Println("Failed to compile", stmt)
+			return nil, err
+		}
+		res[key] = prep
+	}
+	return res, nil
+}

+ 476 - 0
cmd/discosrv/querysrv.go

@@ -0,0 +1,476 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"bytes"
+	"crypto/tls"
+	"database/sql"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"log"
+	"math/rand"
+	"net"
+	"net/http"
+	"net/url"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/golang/groupcache/lru"
+	"github.com/juju/ratelimit"
+	"github.com/syncthing/syncthing/lib/protocol"
+	"golang.org/x/net/context"
+)
+
+type querysrv struct {
+	addr     string
+	db       *sql.DB
+	prep     map[string]*sql.Stmt
+	limiter  *safeCache
+	cert     tls.Certificate
+	listener net.Listener
+}
+
+type announcement struct {
+	Seen      time.Time `json:"seen"`
+	Addresses []string  `json:"addresses"`
+}
+
+type safeCache struct {
+	*lru.Cache
+	mut sync.Mutex
+}
+
+func (s *safeCache) Get(key string) (val interface{}, ok bool) {
+	s.mut.Lock()
+	val, ok = s.Cache.Get(key)
+	s.mut.Unlock()
+	return
+}
+
+func (s *safeCache) Add(key string, val interface{}) {
+	s.mut.Lock()
+	s.Cache.Add(key, val)
+	s.mut.Unlock()
+}
+
+type requestID int64
+
+func (i requestID) String() string {
+	return fmt.Sprintf("%016x", int64(i))
+}
+
+func negCacheFor(lastSeen time.Time) int {
+	since := time.Since(lastSeen).Seconds()
+	if since >= maxDeviceAge {
+		return maxNegCache
+	}
+	if since < 0 {
+		// That's weird
+		return minNegCache
+	}
+
+	// Return a value linearly scaled from minNegCache (at zero seconds ago)
+	// to maxNegCache (at maxDeviceAge seconds ago).
+	r := since / maxDeviceAge
+	return int(minNegCache + r*(maxNegCache-minNegCache))
+}
+
+func (s *querysrv) Serve() {
+	s.limiter = &safeCache{
+		Cache: lru.New(lruSize),
+	}
+
+	if useHttp {
+		listener, err := net.Listen("tcp", s.addr)
+		if err != nil {
+			log.Println("Listen:", err)
+			return
+		}
+		s.listener = listener
+	} else {
+		tlsCfg := &tls.Config{
+			Certificates:           []tls.Certificate{s.cert},
+			ClientAuth:             tls.RequestClientCert,
+			SessionTicketsDisabled: true,
+			MinVersion:             tls.VersionTLS12,
+			CipherSuites: []uint16{
+				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+				tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+				tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+				tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+				tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+			},
+		}
+
+		tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg)
+		if err != nil {
+			log.Println("Listen:", err)
+			return
+		}
+		s.listener = tlsListener
+	}
+
+	http.HandleFunc("/v2/", s.handler)
+	http.HandleFunc("/ping", handlePing)
+
+	srv := &http.Server{
+		ReadTimeout:    5 * time.Second,
+		WriteTimeout:   5 * time.Second,
+		MaxHeaderBytes: 1 << 10,
+	}
+
+	if err := srv.Serve(s.listener); err != nil {
+		log.Println("Serve:", err)
+	}
+}
+
+var topCtx = context.Background()
+
+func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) {
+	reqID := requestID(rand.Int63())
+	ctx := context.WithValue(topCtx, "id", reqID)
+
+	if debug {
+		log.Println(reqID, req.Method, req.URL)
+	}
+
+	t0 := time.Now()
+	defer func() {
+		diff := time.Since(t0)
+		var comment string
+		if diff > time.Second {
+			comment = "(very slow request)"
+		} else if diff > 100*time.Millisecond {
+			comment = "(slow request)"
+		}
+		if comment != "" || debug {
+			log.Println(reqID, req.Method, req.URL, "completed in", diff, comment)
+		}
+	}()
+
+	var remoteIP net.IP
+	if useHttp {
+		remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
+	} else {
+		addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
+		if err != nil {
+			log.Println("remoteAddr:", err)
+			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+			return
+		}
+		remoteIP = addr.IP
+	}
+
+	if s.limit(remoteIP) {
+		if debug {
+			log.Println(remoteIP, "is limited")
+		}
+		w.Header().Set("Retry-After", "60")
+		http.Error(w, "Too Many Requests", 429)
+		return
+	}
+
+	switch req.Method {
+	case "GET":
+		s.handleGET(ctx, w, req)
+	case "POST":
+		s.handlePOST(ctx, remoteIP, w, req)
+	default:
+		globalStats.Error()
+		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
+	}
+}
+
+func (s *querysrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) {
+	reqID := ctx.Value("id").(requestID)
+
+	deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
+	if err != nil {
+		if debug {
+			log.Println(reqID, "bad device param")
+		}
+		globalStats.Error()
+		http.Error(w, "Bad Request", http.StatusBadRequest)
+		return
+	}
+
+	var ann announcement
+
+	ann.Seen, err = s.getDeviceSeen(deviceID)
+	negCache := strconv.Itoa(negCacheFor(ann.Seen))
+	w.Header().Set("Retry-After", negCache)
+	w.Header().Set("Cache-Control", "public, max-age="+negCache)
+
+	if err != nil {
+		// The device is not in the database.
+		globalStats.Query()
+		http.Error(w, "Not Found", http.StatusNotFound)
+		return
+	}
+
+	t0 := time.Now()
+	ann.Addresses, err = s.getAddresses(ctx, deviceID)
+	if err != nil {
+		log.Println(reqID, "getAddresses:", err)
+		globalStats.Error()
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+	if debug {
+		log.Println(reqID, "getAddresses in", time.Since(t0))
+	}
+
+	globalStats.Query()
+
+	if len(ann.Addresses) == 0 {
+		http.Error(w, "Not Found", http.StatusNotFound)
+		return
+	}
+
+	globalStats.Answer()
+
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(ann)
+}
+
+func (s *querysrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
+	reqID := ctx.Value("id").(requestID)
+
+	rawCert := certificateBytes(req)
+	if rawCert == nil {
+		if debug {
+			log.Println(reqID, "no certificates")
+		}
+		globalStats.Error()
+		http.Error(w, "Forbidden", http.StatusForbidden)
+		return
+	}
+
+	var ann announcement
+	if err := json.NewDecoder(req.Body).Decode(&ann); err != nil {
+		if debug {
+			log.Println(reqID, "decode:", err)
+		}
+		globalStats.Error()
+		http.Error(w, "Bad Request", http.StatusBadRequest)
+		return
+	}
+
+	deviceID := protocol.NewDeviceID(rawCert)
+
+	// handleAnnounce returns *two* errors. The first indicates a problem with
+	// something the client posted to us. We should return a 400 Bad Request
+	// and not worry about it. The second indicates that the request was fine,
+	// but something internal messed up. We should log it and respond with a
+	// more apologetic 500 Internal Server Error.
+	userErr, internalErr := s.handleAnnounce(ctx, remoteIP, deviceID, ann.Addresses)
+	if userErr != nil {
+		if debug {
+			log.Println(reqID, "handleAnnounce:", userErr)
+		}
+		globalStats.Error()
+		http.Error(w, "Bad Request", http.StatusBadRequest)
+		return
+	}
+	if internalErr != nil {
+		log.Println(reqID, "handleAnnounce:", internalErr)
+		globalStats.Error()
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	globalStats.Announce()
+
+	// TODO: Slowly increase this for stable clients
+	w.Header().Set("Reannounce-After", "1800")
+
+	// We could return the lookup result here, but it's kind of unnecessarily
+	// expensive to go query the database again so we let the client decide to
+	// do a lookup if they really care.
+	w.WriteHeader(http.StatusNoContent)
+}
+
+func (s *querysrv) Stop() {
+	s.listener.Close()
+}
+
+func (s *querysrv) handleAnnounce(ctx context.Context, remote net.IP, deviceID protocol.DeviceID, addresses []string) (userErr, internalErr error) {
+	reqID := ctx.Value("id").(requestID)
+
+	tx, err := s.db.Begin()
+	if err != nil {
+		internalErr = err
+		return
+	}
+
+	defer func() {
+		// Since we return from a bunch of different places, we handle
+		// rollback in the defer.
+		if internalErr != nil || userErr != nil {
+			tx.Rollback()
+		}
+	}()
+
+	for _, annAddr := range addresses {
+		uri, err := url.Parse(annAddr)
+		if err != nil {
+			userErr = err
+			return
+		}
+
+		host, port, err := net.SplitHostPort(uri.Host)
+		if err != nil {
+			userErr = err
+			return
+		}
+
+		ip := net.ParseIP(host)
+		if len(ip) == 0 || ip.IsUnspecified() {
+			uri.Host = net.JoinHostPort(remote.String(), port)
+		}
+
+		if err := s.updateAddress(ctx, tx, deviceID, uri.String()); err != nil {
+			internalErr = err
+			return
+		}
+	}
+
+	if err := s.updateDevice(ctx, tx, deviceID); err != nil {
+		internalErr = err
+		return
+	}
+
+	t0 := time.Now()
+	internalErr = tx.Commit()
+	if debug {
+		log.Println(reqID, "commit in", time.Since(t0))
+	}
+	return
+}
+
+func (s *querysrv) limit(remote net.IP) bool {
+	key := remote.String()
+
+	bkt, ok := s.limiter.Get(key)
+	if ok {
+		bkt := bkt.(*ratelimit.Bucket)
+		if bkt.TakeAvailable(1) != 1 {
+			// Rate limit exceeded; ignore packet
+			return true
+		}
+	} else {
+		// One packet per ten seconds average rate, burst ten packets
+		s.limiter.Add(key, ratelimit.NewBucket(10*time.Second/time.Duration(limitAvg), int64(limitBurst)))
+	}
+
+	return false
+}
+
+func (s *querysrv) updateDevice(ctx context.Context, tx *sql.Tx, device protocol.DeviceID) error {
+	reqID := ctx.Value("id").(requestID)
+	t0 := time.Now()
+	res, err := tx.Stmt(s.prep["updateDevice"]).Exec(device.String())
+	if err != nil {
+		return err
+	}
+	if debug {
+		log.Println(reqID, "updateDevice in", time.Since(t0))
+	}
+
+	if rows, _ := res.RowsAffected(); rows == 0 {
+		t0 = time.Now()
+		_, err := tx.Stmt(s.prep["insertDevice"]).Exec(device.String())
+		if err != nil {
+			return err
+		}
+		if debug {
+			log.Println(reqID, "insertDevice in", time.Since(t0))
+		}
+	}
+
+	return nil
+}
+
+func (s *querysrv) updateAddress(ctx context.Context, tx *sql.Tx, device protocol.DeviceID, uri string) error {
+	res, err := tx.Stmt(s.prep["updateAddress"]).Exec(device.String(), uri)
+	if err != nil {
+		return err
+	}
+
+	if rows, _ := res.RowsAffected(); rows == 0 {
+		_, err := tx.Stmt(s.prep["insertAddress"]).Exec(device.String(), uri)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *querysrv) getAddresses(ctx context.Context, device protocol.DeviceID) ([]string, error) {
+	rows, err := s.prep["selectAddress"].Query(device.String())
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var res []string
+	for rows.Next() {
+		var addr string
+
+		err := rows.Scan(&addr)
+		if err != nil {
+			log.Println("Scan:", err)
+			continue
+		}
+		res = append(res, addr)
+	}
+
+	return res, nil
+}
+
+func (s *querysrv) getDeviceSeen(device protocol.DeviceID) (time.Time, error) {
+	row := s.prep["selectDevice"].QueryRow(device.String())
+	var seen time.Time
+	if err := row.Scan(&seen); err != nil {
+		return time.Time{}, err
+	}
+	return seen, nil
+}
+
+func handlePing(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(204)
+}
+
+func certificateBytes(req *http.Request) []byte {
+	if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
+		return req.TLS.PeerCertificates[0].Raw
+	}
+
+	if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" {
+		bs := []byte(hdr)
+		// The certificate is in PEM format but with spaces for newlines. We
+		// need to reinstate the newlines for the PEM decoder. But we need to
+		// leave the spaces in the BEGIN and END lines - the first and last
+		// space - alone.
+		firstSpace := bytes.Index(bs, []byte(" "))
+		lastSpace := bytes.LastIndex(bs, []byte(" "))
+		for i := firstSpace + 1; i < lastSpace; i++ {
+			if bs[i] == ' ' {
+				bs[i] = '\n'
+			}
+		}
+		block, _ := pem.Decode(bs)
+		if block == nil {
+			// Decoding failed
+			return nil
+		}
+		return block.Bytes
+	}
+
+	return nil
+}

+ 141 - 0
cmd/discosrv/stats.go

@@ -0,0 +1,141 @@
+// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
+
+package main
+
+import (
+	"bytes"
+	"database/sql"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"os"
+	"sync/atomic"
+	"time"
+)
+
+type stats struct {
+	// Incremented atomically
+	announces int64
+	queries   int64
+	answers   int64
+	errors    int64
+}
+
+func (s *stats) Announce() {
+	atomic.AddInt64(&s.announces, 1)
+}
+
+func (s *stats) Query() {
+	atomic.AddInt64(&s.queries, 1)
+}
+
+func (s *stats) Answer() {
+	atomic.AddInt64(&s.answers, 1)
+}
+
+func (s *stats) Error() {
+	atomic.AddInt64(&s.errors, 1)
+}
+
+// Reset returns a copy of the current stats and resets the counters to
+// zero.
+func (s *stats) Reset() stats {
+	// Create a copy of the stats using atomic reads
+	copy := stats{
+		announces: atomic.LoadInt64(&s.announces),
+		queries:   atomic.LoadInt64(&s.queries),
+		answers:   atomic.LoadInt64(&s.answers),
+		errors:    atomic.LoadInt64(&s.errors),
+	}
+
+	// Reset the stats by subtracting the values that we copied
+	atomic.AddInt64(&s.announces, -copy.announces)
+	atomic.AddInt64(&s.queries, -copy.queries)
+	atomic.AddInt64(&s.answers, -copy.answers)
+	atomic.AddInt64(&s.errors, -copy.errors)
+
+	return copy
+}
+
+type statssrv struct {
+	intv time.Duration
+	file string
+	db   *sql.DB
+}
+
+func (s *statssrv) Serve() {
+	lastReset := time.Now()
+	for {
+		time.Sleep(next(s.intv))
+
+		stats := globalStats.Reset()
+		d := time.Since(lastReset).Seconds()
+		lastReset = time.Now()
+
+		log.Printf("Stats: %.02f announces/s, %.02f queries/s, %.02f answers/s, %.02f errors/s",
+			float64(stats.announces)/d, float64(stats.queries)/d, float64(stats.answers)/d, float64(stats.errors)/d)
+
+		if s.file != "" {
+			s.writeToFile(stats, d)
+		}
+	}
+}
+
+func (s *statssrv) Stop() {
+	panic("stop unimplemented")
+}
+
+func (s *statssrv) writeToFile(stats stats, secs float64) {
+	newLine := []byte("\n")
+
+	var addrs int
+	row := s.db.QueryRow("SELECT COUNT(*) FROM Addresses")
+	if err := row.Scan(&addrs); err != nil {
+		log.Println("stats query:", err)
+		return
+	}
+
+	fd, err := os.OpenFile(s.file, os.O_RDWR|os.O_CREATE, 0666)
+	if err != nil {
+		log.Println("stats file:", err)
+		return
+	}
+	defer func() {
+		err = fd.Close()
+		if err != nil {
+			log.Println("stats file:", err)
+		}
+	}()
+
+	bs, err := ioutil.ReadAll(fd)
+	if err != nil {
+		log.Println("stats file:", err)
+		return
+	}
+	lines := bytes.Split(bytes.TrimSpace(bs), newLine)
+	if len(lines) > 12 {
+		lines = lines[len(lines)-12:]
+	}
+
+	latest := fmt.Sprintf("%v: %6d addresses, %8.02f announces/s, %8.02f queries/s, %8.02f answers/s, %8.02f errors/s\n",
+		time.Now().UTC().Format(time.RFC3339), addrs,
+		float64(stats.announces)/secs, float64(stats.queries)/secs, float64(stats.answers)/secs, float64(stats.errors)/secs)
+	lines = append(lines, []byte(latest))
+
+	_, err = fd.Seek(0, 0)
+	if err != nil {
+		log.Println("stats file:", err)
+		return
+	}
+	err = fd.Truncate(0)
+	if err != nil {
+		log.Println("stats file:", err)
+		return
+	}
+
+	_, err = fd.Write(bytes.Join(lines, newLine))
+	if err != nil {
+		log.Println("stats file:", err)
+		return
+	}
+}