Ver código fonte

lib/protocol: faster Luhn algorithm and better testing (#6475)

The previous implementation was very generic; its tests didn't cover the
actual alphabet for device IDs.

Benchmark results on amd64:

name         old time/op    new time/op     delta
Luhnify-8      1.00µs ± 1%     0.28µs ± 4%   -72.38%  (p=0.000 n=9+10)
Unluhnify-8     992ns ± 2%      274ns ± 1%   -72.39%  (p=0.000 n=10+9)
greatroar 5 anos atrás
pai
commit
1e2379df1b
3 arquivos alterados com 29 adições e 60 exclusões
  1. 2 2
      lib/protocol/deviceid.go
  2. 19 28
      lib/protocol/luhn.go
  3. 8 30
      lib/protocol/luhn_test.go

+ 2 - 2
lib/protocol/deviceid.go

@@ -165,7 +165,7 @@ func luhnify(s string) (string, error) {
 	for i := 0; i < 4; i++ {
 	for i := 0; i < 4; i++ {
 		p := s[i*13 : (i+1)*13]
 		p := s[i*13 : (i+1)*13]
 		copy(res[i*(13+1):], p)
 		copy(res[i*(13+1):], p)
-		l, err := luhnBase32.generate(p)
+		l, err := luhn32(p)
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
@@ -183,7 +183,7 @@ func unluhnify(s string) (string, error) {
 	for i := 0; i < 4; i++ {
 	for i := 0; i < 4; i++ {
 		p := s[i*(13+1) : (i+1)*(13+1)-1]
 		p := s[i*(13+1) : (i+1)*(13+1)-1]
 		copy(res[i*13:], p)
 		copy(res[i*13:], p)
-		l, err := luhnBase32.generate(p)
+		l, err := luhn32(p)
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}

+ 19 - 28
lib/protocol/luhn.go

@@ -2,32 +2,34 @@
 
 
 package protocol
 package protocol
 
 
-import (
-	"fmt"
-	"strings"
-)
+import "fmt"
 
 
-// An alphabet is a string of N characters, representing the digits of a given
-// base N.
-type luhnAlphabet string
+var luhnBase32 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
 
 
-var (
-	luhnBase32 luhnAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
-)
+func codepoint32(b byte) int {
+	switch {
+	case 'A' <= b && b <= 'Z':
+		return int(b - 'A')
+	case '2' <= b && b <= '7':
+		return int(b + 26 - '2')
+	default:
+		return -1
+	}
+}
 
 
-// generate returns a check digit for the string s, which should be composed
-// of characters from the Alphabet a.
+// luhn32 returns a check digit for the string s, which should be composed
+// of characters from the alphabet luhnBase32.
 // Doesn't follow the actual Luhn algorithm
 // Doesn't follow the actual Luhn algorithm
 // see https://forum.syncthing.net/t/v0-9-0-new-node-id-format/478/6 for more.
 // see https://forum.syncthing.net/t/v0-9-0-new-node-id-format/478/6 for more.
-func (a luhnAlphabet) generate(s string) (rune, error) {
+func luhn32(s string) (rune, error) {
 	factor := 1
 	factor := 1
 	sum := 0
 	sum := 0
-	n := len(a)
+	const n = 32
 
 
 	for i := range s {
 	for i := range s {
-		codepoint := strings.IndexByte(string(a), s[i])
+		codepoint := codepoint32(s[i])
 		if codepoint == -1 {
 		if codepoint == -1 {
-			return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], a)
+			return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], luhnBase32)
 		}
 		}
 		addend := factor * codepoint
 		addend := factor * codepoint
 		if factor == 2 {
 		if factor == 2 {
@@ -40,16 +42,5 @@ func (a luhnAlphabet) generate(s string) (rune, error) {
 	}
 	}
 	remainder := sum % n
 	remainder := sum % n
 	checkCodepoint := (n - remainder) % n
 	checkCodepoint := (n - remainder) % n
-	return rune(a[checkCodepoint]), nil
-}
-
-// luhnValidate returns true if the last character of the string s is correct, for
-// a string s composed of characters in the alphabet a.
-func (a luhnAlphabet) luhnValidate(s string) bool {
-	t := s[:len(s)-1]
-	c, err := a.generate(t)
-	if err != nil {
-		return false
-	}
-	return rune(s[len(s)-1]) == c
+	return rune(luhnBase32[checkCodepoint]), nil
 }
 }

+ 8 - 30
lib/protocol/luhn_test.go

@@ -3,46 +3,24 @@
 package protocol
 package protocol
 
 
 import (
 import (
+	"strings"
 	"testing"
 	"testing"
 )
 )
 
 
-func TestGenerate(t *testing.T) {
-	// Base 6 Luhn
-	a := luhnAlphabet("abcdef")
-	c, err := a.generate("abcdef")
+func TestLuhn32(t *testing.T) {
+	c, err := luhn32("AB725E4GHIQPL3ZFGT")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	if c != 'e' {
-		t.Errorf("Incorrect check digit %c != e", c)
+	if c != 'G' {
+		t.Errorf("Incorrect check digit %c != G", c)
 	}
 	}
 
 
-	// Base 10 Luhn
-	a = luhnAlphabet("0123456789")
-	c, err = a.generate("7992739871")
-	if err != nil {
-		t.Fatal(err)
-	}
-	if c != '3' {
-		t.Errorf("Incorrect check digit %c != 3", c)
-	}
-}
-
-func TestInvalidString(t *testing.T) {
-	a := luhnAlphabet("ABC")
-	_, err := a.generate("7992739871")
-	t.Log(err)
+	_, err = luhn32("3734EJEKMRHWPZQTWYQ1")
 	if err == nil {
 	if err == nil {
 		t.Error("Unexpected nil error")
 		t.Error("Unexpected nil error")
 	}
 	}
-}
-
-func TestValidate(t *testing.T) {
-	a := luhnAlphabet("abcdef")
-	if !a.luhnValidate("abcdefe") {
-		t.Errorf("Incorrect validation response for abcdefe")
-	}
-	if a.luhnValidate("abcdefd") {
-		t.Errorf("Incorrect validation response for abcdefd")
+	if !strings.Contains(err.Error(), "'1'") {
+		t.Errorf("luhn32 should have errored on digit '1', got %v", err)
 	}
 	}
 }
 }