|
@@ -132,11 +132,19 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
|
|
|
log.Println(reqID, req.Method, req.URL)
|
|
|
}
|
|
|
|
|
|
- var remoteIP net.IP
|
|
|
+ remoteAddr := &net.TCPAddr{
|
|
|
+ IP: nil,
|
|
|
+ Port: -1,
|
|
|
+ }
|
|
|
+
|
|
|
if s.useHTTP {
|
|
|
- remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
|
|
|
+ remoteAddr.IP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
|
|
|
+ if parsedPort, err := strconv.ParseInt(req.Header.Get("X-Forwarded-Port"), 10, 0); err == nil {
|
|
|
+ remoteAddr.Port = int(parsedPort)
|
|
|
+ }
|
|
|
} else {
|
|
|
- addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
|
|
|
+ var err error
|
|
|
+ remoteAddr, err = net.ResolveTCPAddr("tcp", req.RemoteAddr)
|
|
|
if err != nil {
|
|
|
log.Println("remoteAddr:", err)
|
|
|
lw.Header().Set("Retry-After", errorRetryAfterString())
|
|
@@ -144,14 +152,13 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
|
|
|
apiRequestsTotal.WithLabelValues("no_remote_addr").Inc()
|
|
|
return
|
|
|
}
|
|
|
- remoteIP = addr.IP
|
|
|
}
|
|
|
|
|
|
switch req.Method {
|
|
|
case "GET":
|
|
|
s.handleGET(ctx, lw, req)
|
|
|
case "POST":
|
|
|
- s.handlePOST(ctx, remoteIP, lw, req)
|
|
|
+ s.handlePOST(ctx, remoteAddr, lw, req)
|
|
|
default:
|
|
|
http.Error(lw, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
|
}
|
|
@@ -217,7 +224,7 @@ func (s *apiSrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http
|
|
|
w.Write(bs)
|
|
|
}
|
|
|
|
|
|
-func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
|
|
|
+func (s *apiSrv) handlePOST(ctx context.Context, remoteAddr *net.TCPAddr, w http.ResponseWriter, req *http.Request) {
|
|
|
reqID := ctx.Value(idKey).(requestID)
|
|
|
|
|
|
rawCert := certificateBytes(req)
|
|
@@ -244,7 +251,7 @@ func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.Respons
|
|
|
|
|
|
deviceID := protocol.NewDeviceID(rawCert)
|
|
|
|
|
|
- addresses := fixupAddresses(remoteIP, ann.Addresses)
|
|
|
+ addresses := fixupAddresses(remoteAddr, ann.Addresses)
|
|
|
if len(addresses) == 0 {
|
|
|
announceRequestsTotal.WithLabelValues("bad_request").Inc()
|
|
|
w.Header().Set("Retry-After", errorRetryAfterString())
|
|
@@ -252,7 +259,7 @@ func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.Respons
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if err := s.handleAnnounce(remoteIP, deviceID, addresses); err != nil {
|
|
|
+ if err := s.handleAnnounce(deviceID, addresses); err != nil {
|
|
|
announceRequestsTotal.WithLabelValues("internal_error").Inc()
|
|
|
w.Header().Set("Retry-After", errorRetryAfterString())
|
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
@@ -269,7 +276,7 @@ func (s *apiSrv) Stop() {
|
|
|
s.listener.Close()
|
|
|
}
|
|
|
|
|
|
-func (s *apiSrv) handleAnnounce(remote net.IP, deviceID protocol.DeviceID, addresses []string) error {
|
|
|
+func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) error {
|
|
|
key := deviceID.String()
|
|
|
now := time.Now()
|
|
|
expire := now.Add(addressExpiryTime).UnixNano()
|
|
@@ -364,7 +371,7 @@ func certificateBytes(req *http.Request) []byte {
|
|
|
|
|
|
// fixupAddresses checks the list of addresses, removing invalid ones and
|
|
|
// replacing unspecified IPs with the given remote IP.
|
|
|
-func fixupAddresses(remote net.IP, addresses []string) []string {
|
|
|
+func fixupAddresses(remote *net.TCPAddr, addresses []string) []string {
|
|
|
fixed := make([]string, 0, len(addresses))
|
|
|
for _, annAddr := range addresses {
|
|
|
uri, err := url.Parse(annAddr)
|
|
@@ -384,27 +391,34 @@ func fixupAddresses(remote net.IP, addresses []string) []string {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- if host == "" || ip.IsUnspecified() {
|
|
|
- // Replace the unspecified IP with the request source.
|
|
|
+ if remote != nil {
|
|
|
+ if host == "" || ip.IsUnspecified() {
|
|
|
+ // Replace the unspecified IP with the request source.
|
|
|
|
|
|
- // ... unless the request source is the loopback address or
|
|
|
- // multicast/unspecified (can't happen, really).
|
|
|
- if remote.IsLoopback() || remote.IsMulticast() || remote.IsUnspecified() {
|
|
|
- continue
|
|
|
- }
|
|
|
+ // ... unless the request source is the loopback address or
|
|
|
+ // multicast/unspecified (can't happen, really).
|
|
|
+ if remote.IP.IsLoopback() || remote.IP.IsMulticast() || remote.IP.IsUnspecified() {
|
|
|
+ continue
|
|
|
+ }
|
|
|
|
|
|
- // Do not use IPv6 remote address if requested scheme is ...4
|
|
|
- // (i.e., tcp4, etc.)
|
|
|
- if strings.HasSuffix(uri.Scheme, "4") && remote.To4() == nil {
|
|
|
- continue
|
|
|
- }
|
|
|
+ // Do not use IPv6 remote address if requested scheme is ...4
|
|
|
+ // (i.e., tcp4, etc.)
|
|
|
+ if strings.HasSuffix(uri.Scheme, "4") && remote.IP.To4() == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
|
|
|
- // Do not use IPv4 remote address if requested scheme is ...6
|
|
|
- if strings.HasSuffix(uri.Scheme, "6") && remote.To4() != nil {
|
|
|
- continue
|
|
|
+ // Do not use IPv4 remote address if requested scheme is ...6
|
|
|
+ if strings.HasSuffix(uri.Scheme, "6") && remote.IP.To4() != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ host = remote.IP.String()
|
|
|
}
|
|
|
|
|
|
- host = remote.String()
|
|
|
+ // If zero port was specified, use remote port.
|
|
|
+ if port == "0" && remote.Port > 0 {
|
|
|
+ port = fmt.Sprintf("%d", remote.Port)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
uri.Host = net.JoinHostPort(host, port)
|