ソースを参照

lib/connections: Actually fix LAN detection, for real (ref #4534)

Jakob Borg 8 年 前
コミット
1e9769cdd7
2 ファイル変更71 行追加11 行削除
  1. 47 0
      lib/connections/lan_test.go
  2. 24 11
      lib/connections/service.go

+ 47 - 0
lib/connections/lan_test.go

@@ -0,0 +1,47 @@
+// Copyright (C) 2017 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package connections
+
+import (
+	"testing"
+
+	"github.com/syncthing/syncthing/lib/config"
+)
+
+func TestIsLANHost(t *testing.T) {
+	cases := []struct {
+		addr string
+		lan  bool
+	}{
+		// loopback
+		{"127.0.0.1:22000", true},
+		{"127.0.0.1", true},
+		// local nets
+		{"10.20.30.40:22000", true},
+		{"10.20.30.40", true},
+		// neither
+		{"192.0.2.1:22000", false},
+		{"192.0.2.1", false},
+		// doesn't resolve
+		{"[banana::phone]:hello", false},
+		{"„‹›fl´fi·‰ˇ¨Á˝", false},
+	}
+
+	cfg := config.Wrap("/dev/null", config.Configuration{
+		Options: config.OptionsConfiguration{
+			AlwaysLocalNets: []string{"10.20.30.0/24"},
+		},
+	})
+	s := &Service{cfg: cfg}
+
+	for _, tc := range cases {
+		res := s.isLANHost(tc.addr)
+		if res != tc.lan {
+			t.Errorf("isLANHost(%q) => %v, expected %v", tc.addr, res, tc.lan)
+		}
+	}
+}

+ 24 - 11
lib/connections/service.go

@@ -424,23 +424,36 @@ func (s *Service) connect() {
 }
 
 func (s *Service) isLANHost(host string) bool {
-	if noPort, _, err := net.SplitHostPort(host); err == nil && noPort != "" {
-		host = noPort
+	// Probably we are called with an ip:port combo which we can resolve as
+	// a TCP address.
+	if addr, err := net.ResolveTCPAddr("tcp", host); err == nil {
+		return s.isLAN(addr)
 	}
-	addr, err := net.ResolveIPAddr("ip", host)
-	if err != nil {
-		return false
+	// ... but this function looks general enough that someone might try
+	// with just an IP as well in the future so lets allow that.
+	if addr, err := net.ResolveIPAddr("ip", host); err == nil {
+		return s.isLAN(addr)
 	}
-	return s.isLAN(addr)
+	return false
 }
 
 func (s *Service) isLAN(addr net.Addr) bool {
-	tcpaddr, ok := addr.(*net.TCPAddr)
-	if !ok {
+	var ip net.IP
+
+	switch addr := addr.(type) {
+	case *net.IPAddr:
+		ip = addr.IP
+	case *net.TCPAddr:
+		ip = addr.IP
+	case *net.UDPAddr:
+		ip = addr.IP
+	default:
+		// From the standard library, just Unix sockets.
+		// If you invent your own, handle it.
 		return false
 	}
 
-	if tcpaddr.IP.IsLoopback() {
+	if ip.IsLoopback() {
 		return true
 	}
 
@@ -450,14 +463,14 @@ func (s *Service) isLAN(addr net.Addr) bool {
 			l.Debugln("Network", lan, "is malformed:", err)
 			continue
 		}
-		if ipnet.Contains(tcpaddr.IP) {
+		if ipnet.Contains(ip) {
 			return true
 		}
 	}
 
 	lans, _ := osutil.GetLans()
 	for _, lan := range lans {
-		if lan.Contains(tcpaddr.IP) {
+		if lan.Contains(ip) {
 			return true
 		}
 	}