Преглед изворни кода

lib/protocol: Avoid panic in DeviceIDFromBytes (#6714)

André Colomb пре 5 година
родитељ
комит
46536509d7

+ 9 - 5
cmd/stindex/dump.go

@@ -73,12 +73,16 @@ func dump(ldb backend.Backend) {
 		case db.KeyTypeDeviceIdx:
 		case db.KeyTypeDeviceIdx:
 			key := binary.BigEndian.Uint32(key[1:])
 			key := binary.BigEndian.Uint32(key[1:])
 			val := it.Value()
 			val := it.Value()
-			if len(val) == 0 {
-				fmt.Printf("[deviceidx] K:%d V:<nil>\n", key)
-			} else {
-				dev := protocol.DeviceIDFromBytes(val)
-				fmt.Printf("[deviceidx] K:%d V:%s\n", key, dev)
+			device := "<nil>"
+			if len(val) > 0 {
+				dev, err := protocol.DeviceIDFromBytes(val)
+				if err != nil {
+					device = fmt.Sprintf("<invalid %d bytes>", len(val))
+				} else {
+					device = dev.String()
+				}
 			}
 			}
+			fmt.Printf("[deviceidx] K:%d V:%s\n", key, device)
 
 
 		case db.KeyTypeIndexID:
 		case db.KeyTypeIndexID:
 			device := binary.BigEndian.Uint32(key[1:])
 			device := binary.BigEndian.Uint32(key[1:])

+ 9 - 1
cmd/strelaysrv/listener.go

@@ -149,7 +149,15 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
 				protocol.WriteMessage(conn, protocol.ResponseSuccess)
 				protocol.WriteMessage(conn, protocol.ResponseSuccess)
 
 
 			case protocol.ConnectRequest:
 			case protocol.ConnectRequest:
-				requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
+				requestedPeer, err := syncthingprotocol.DeviceIDFromBytes(msg.ID)
+				if err != nil {
+					if debug {
+						log.Println(id, "is looking for an invalid peer ID")
+					}
+					protocol.WriteMessage(conn, protocol.ResponseNotFound)
+					conn.Close()
+					continue
+				}
 				outboxesMut.RLock()
 				outboxesMut.RLock()
 				peerOutbox, ok := outboxes[requestedPeer]
 				peerOutbox, ok := outboxes[requestedPeer]
 				outboxesMut.RUnlock()
 				outboxesMut.RUnlock()

+ 5 - 1
lib/db/db_test.go

@@ -264,7 +264,11 @@ func TestUpdate0to3(t *testing.T) {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
 		if !ok {
 		if !ok {
-			t.Fatal("surprise missing global file", string(name), protocol.DeviceIDFromBytes(vl.Versions[0].Device))
+			device := "<invalid>"
+			if dev, err := protocol.DeviceIDFromBytes(vl.Versions[0].Device); err != nil {
+				device = dev.String()
+			}
+			t.Fatal("surprise missing global file", string(name), device)
 		}
 		}
 		e, ok := need[fi.FileName()]
 		e, ok := need[fi.FileName()]
 		if !ok {
 		if !ok {

+ 4 - 1
lib/db/lowlevel.go

@@ -121,7 +121,10 @@ func (db *Lowlevel) updateRemoteFiles(folder, device []byte, fs []protocol.FileI
 	defer t.close()
 	defer t.close()
 
 
 	var dk, gk, keyBuf []byte
 	var dk, gk, keyBuf []byte
-	devID := protocol.DeviceIDFromBytes(device)
+	devID, err := protocol.DeviceIDFromBytes(device)
+	if err != nil {
+		return err
+	}
 	for _, f := range fs {
 	for _, f := range fs {
 		name := []byte(f.Name)
 		name := []byte(f.Name)
 		dk, err = db.keyer.GenerateDeviceFileKey(dk, folder, device, name)
 		dk, err = db.keyer.GenerateDeviceFileKey(dk, folder, device, name)

+ 9 - 2
lib/db/meta.go

@@ -52,7 +52,11 @@ func (m *metadataTracker) Unmarshal(bs []byte) error {
 
 
 	// Initialize the index map
 	// Initialize the index map
 	for i, c := range m.counts.Counts {
 	for i, c := range m.counts.Counts {
-		m.indexes[metaKey{protocol.DeviceIDFromBytes(c.DeviceID), c.LocalFlags}] = i
+		dev, err := protocol.DeviceIDFromBytes(c.DeviceID)
+		if err != nil {
+			return err
+		}
+		m.indexes[metaKey{dev, c.LocalFlags}] = i
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -392,7 +396,10 @@ func (m *countsMap) devices() []protocol.DeviceID {
 
 
 	for _, dev := range m.counts.Counts {
 	for _, dev := range m.counts.Counts {
 		if dev.Sequence > 0 {
 		if dev.Sequence > 0 {
-			id := protocol.DeviceIDFromBytes(dev.DeviceID)
+			id, err := protocol.DeviceIDFromBytes(dev.DeviceID)
+			if err != nil {
+				panic(err)
+			}
 			if id == protocol.GlobalDeviceID || id == protocol.LocalDeviceID {
 			if id == protocol.GlobalDeviceID || id == protocol.LocalDeviceID {
 				continue
 				continue
 			}
 			}

+ 17 - 4
lib/db/transactions.go

@@ -414,7 +414,11 @@ func (t *readOnlyTransaction) availability(folder, file []byte) ([]protocol.Devi
 	}
 	}
 	devices := make([]protocol.DeviceID, len(fv.Devices))
 	devices := make([]protocol.DeviceID, len(fv.Devices))
 	for i, dev := range fv.Devices {
 	for i, dev := range fv.Devices {
-		devices[i] = protocol.DeviceIDFromBytes(dev)
+		n, err := protocol.DeviceIDFromBytes(dev)
+		if err != nil {
+			return nil, err
+		}
+		devices[i] = n
 	}
 	}
 
 
 	return devices, nil
 	return devices, nil
@@ -436,7 +440,10 @@ func (t *readOnlyTransaction) withNeed(folder, device []byte, truncate bool, fn
 	defer dbi.Release()
 	defer dbi.Release()
 
 
 	var dk []byte
 	var dk []byte
-	devID := protocol.DeviceIDFromBytes(device)
+	devID, err := protocol.DeviceIDFromBytes(device)
+	if err != nil {
+		return err
+	}
 	for dbi.Next() {
 	for dbi.Next() {
 		var vl VersionList
 		var vl VersionList
 		if err := vl.Unmarshal(dbi.Value()); err != nil {
 		if err := vl.Unmarshal(dbi.Value()); err != nil {
@@ -592,7 +599,10 @@ func (t readWriteTransaction) putFile(fkey []byte, fi protocol.FileInfo, truncat
 // file. If the device is already present in the list, the version is updated.
 // file. If the device is already present in the list, the version is updated.
 // If the file does not have an entry in the global list, it is created.
 // If the file does not have an entry in the global list, it is created.
 func (t readWriteTransaction) updateGlobal(gk, keyBuf, folder, device []byte, file protocol.FileInfo, meta *metadataTracker) ([]byte, bool, error) {
 func (t readWriteTransaction) updateGlobal(gk, keyBuf, folder, device []byte, file protocol.FileInfo, meta *metadataTracker) ([]byte, bool, error) {
-	deviceID := protocol.DeviceIDFromBytes(device)
+	deviceID, err := protocol.DeviceIDFromBytes(device)
+	if err != nil {
+		return nil, false, err
+	}
 
 
 	l.Debugf("update global; folder=%q device=%v file=%q version=%v invalid=%v", folder, deviceID, file.Name, file.Version, file.IsInvalid())
 	l.Debugf("update global; folder=%q device=%v file=%q version=%v invalid=%v", folder, deviceID, file.Name, file.Version, file.IsInvalid())
 
 
@@ -768,7 +778,10 @@ func need(global FileVersion, haveLocal bool, localVersion protocol.Vector) bool
 // given file. If the version list is empty after this, the file entry is
 // given file. If the version list is empty after this, the file entry is
 // removed entirely.
 // removed entirely.
 func (t readWriteTransaction) removeFromGlobal(gk, keyBuf, folder, device, file []byte, meta *metadataTracker) ([]byte, error) {
 func (t readWriteTransaction) removeFromGlobal(gk, keyBuf, folder, device, file []byte, meta *metadataTracker) ([]byte, error) {
-	deviceID := protocol.DeviceIDFromBytes(device)
+	deviceID, err := protocol.DeviceIDFromBytes(device)
+	if err != nil {
+		return nil, err
+	}
 
 
 	l.Debugf("remove from global; folder=%q device=%v file=%q", folder, deviceID, file)
 	l.Debugf("remove from global; folder=%q device=%v file=%q", folder, deviceID, file)
 
 

+ 3 - 3
lib/protocol/deviceid.go

@@ -46,13 +46,13 @@ func DeviceIDFromString(s string) (DeviceID, error) {
 	return n, err
 	return n, err
 }
 }
 
 
-func DeviceIDFromBytes(bs []byte) DeviceID {
+func DeviceIDFromBytes(bs []byte) (DeviceID, error) {
 	var n DeviceID
 	var n DeviceID
 	if len(bs) != len(n) {
 	if len(bs) != len(n) {
-		panic("incorrect length of byte slice representing device ID")
+		return n, fmt.Errorf("incorrect length of byte slice representing device ID")
 	}
 	}
 	copy(n[:], bs)
 	copy(n[:], bs)
-	return n
+	return n, nil
 }
 }
 
 
 // String returns the canonical string representation of the device ID
 // String returns the canonical string representation of the device ID

+ 8 - 3
lib/protocol/deviceid_test.go

@@ -99,8 +99,10 @@ func TestShortIDString(t *testing.T) {
 
 
 func TestDeviceIDFromBytes(t *testing.T) {
 func TestDeviceIDFromBytes(t *testing.T) {
 	id0, _ := DeviceIDFromString(formatted)
 	id0, _ := DeviceIDFromString(formatted)
-	id1 := DeviceIDFromBytes(id0[:])
-	if id1.String() != formatted {
+	id1, err := DeviceIDFromBytes(id0[:])
+	if err != nil {
+		t.Fatal(err)
+	} else if id1.String() != formatted {
 		t.Errorf("Wrong device ID, got %q, want %q", id1, formatted)
 		t.Errorf("Wrong device ID, got %q, want %q", id1, formatted)
 	}
 	}
 }
 }
@@ -150,7 +152,10 @@ func TestNewDeviceIDMarshalling(t *testing.T) {
 
 
 	// Verify it's the same
 	// Verify it's the same
 
 
-	if DeviceIDFromBytes(msg2.Test) != id0 {
+	id1, err := DeviceIDFromBytes(msg2.Test)
+	if err != nil {
+		t.Fatal(err)
+	} else if id1 != id0 {
 		t.Error("Mismatch in old -> new direction")
 		t.Error("Mismatch in old -> new direction")
 	}
 	}
 }
 }

+ 5 - 1
lib/relay/protocol/packets.go

@@ -56,7 +56,11 @@ type SessionInvitation struct {
 }
 }
 
 
 func (i SessionInvitation) String() string {
 func (i SessionInvitation) String() string {
-	return fmt.Sprintf("%s@%s", syncthingprotocol.DeviceIDFromBytes(i.From), i.AddressString())
+	device := "<invalid>"
+	if address, err := syncthingprotocol.DeviceIDFromBytes(i.From); err == nil {
+		device = address.String()
+	}
+	return fmt.Sprintf("%s@%s", device, i.AddressString())
 }
 }
 
 
 func (i SessionInvitation) GoString() string {
 func (i SessionInvitation) GoString() string {