Browse Source

Luhn error checking

Jakob Borg 11 years ago
parent
commit
3d7d4d845a
3 changed files with 70 additions and 11 deletions
  1. 29 4
      luhn/luhn.go
  2. 26 2
      luhn/luhn_test.go
  3. 15 5
      protocol/nodeid.go

+ 29 - 4
luhn/luhn.go

@@ -1,7 +1,10 @@
 // Package luhn generates and validates Luhn mod N check digits.
 package luhn
 
-import "strings"
+import (
+	"fmt"
+	"strings"
+)
 
 // An alphabet is a string of N characters, representing the digits of a given
 // base N.
@@ -13,13 +16,20 @@ var (
 
 // Generate returns a check digit for the string s, which should be composed
 // of characters from the Alphabet a.
-func (a Alphabet) Generate(s string) rune {
+func (a Alphabet) Generate(s string) (rune, error) {
+	if err:=a.check();err!=nil{
+		return 0,err
+	}
+
 	factor := 1
 	sum := 0
 	n := len(a)
 
 	for i := range s {
 		codepoint := strings.IndexByte(string(a), s[i])
+		if codepoint == -1 {
+			return 0, fmt.Errorf("Digit %q not valid in alphabet %q", s[i], a)
+		}
 		addend := factor * codepoint
 		if factor == 2 {
 			factor = 1
@@ -31,13 +41,28 @@ func (a Alphabet) Generate(s string) rune {
 	}
 	remainder := sum % n
 	checkCodepoint := (n - remainder) % n
-	return rune(a[checkCodepoint])
+	return rune(a[checkCodepoint]), nil
 }
 
 // Validate returns true if the last character of the string s is correct, for
 // a string s composed of characters in the alphabet a.
 func (a Alphabet) Validate(s string) bool {
 	t := s[:len(s)-1]
-	c := a.Generate(t)
+	c, err := a.Generate(t)
+	if err != nil {
+		return false
+	}
 	return rune(s[len(s)-1]) == c
 }
+
+// check returns an error if the given alphabet does not consist of unique characters
+func (a Alphabet) check() error {
+	cm := make(map[byte]bool, len(a))
+	for i := range a {
+		if cm[a[i]] {
+			return fmt.Errorf("Digit %q non-unique in alphabet %q", a[i], a)
+		}
+		cm[a[i]] = true
+	}
+	return nil
+}

+ 26 - 2
luhn/luhn_test.go

@@ -9,19 +9,43 @@ import (
 func TestGenerate(t *testing.T) {
 	// Base 6 Luhn
 	a := luhn.Alphabet("abcdef")
-	c := a.Generate("abcdef")
+	c, err := a.Generate("abcdef")
+	if err != nil {
+		t.Fatal(err)
+	}
 	if c != 'e' {
 		t.Errorf("Incorrect check digit %c != e", c)
 	}
 
 	// Base 10 Luhn
 	a = luhn.Alphabet("0123456789")
-	c = a.Generate("7992739871")
+	c, err = a.Generate("7992739871")
+	if err != nil {
+		t.Fatal(err)
+	}
 	if c != '3' {
 		t.Errorf("Incorrect check digit %c != 3", c)
 	}
 }
 
+func TestInvalidString(t *testing.T) {
+	a := luhn.Alphabet("ABC")
+	_, err := a.Generate("7992739871")
+	t.Log(err)
+	if err == nil {
+		t.Error("Unexpected nil error")
+	}
+}
+
+func TestBadAlphabet(t *testing.T) {
+	a := luhn.Alphabet("01234566789")
+	_, err := a.Generate("7992739871")
+	t.Log(err)
+	if err == nil {
+		t.Error("Unexpected nil error")
+	}
+}
+
 func TestValidate(t *testing.T) {
 	a := luhn.Alphabet("abcdef")
 	if !a.Validate("abcdefe") {

+ 15 - 5
protocol/nodeid.go

@@ -34,7 +34,11 @@ func NodeIDFromString(s string) (NodeID, error) {
 func (n NodeID) String() string {
 	id := base32.StdEncoding.EncodeToString(n[:])
 	id = strings.Trim(id, "=")
-	id = luhnify(id)
+	id, err := luhnify(id)
+	if err != nil {
+		// Should never happen
+		panic(err)
+	}
 	id = chunkify(id)
 	return id
 }
@@ -84,7 +88,7 @@ func (n *NodeID) UnmarshalText(bs []byte) error {
 	}
 }
 
-func luhnify(s string) string {
+func luhnify(s string) (string, error) {
 	if len(s) != 52 {
 		panic("unsupported string length")
 	}
@@ -92,10 +96,13 @@ func luhnify(s string) string {
 	res := make([]string, 0, 4)
 	for i := 0; i < 4; i++ {
 		p := s[i*13 : (i+1)*13]
-		l := luhn.Base32.Generate(p)
+		l, err := luhn.Base32.Generate(p)
+		if err != nil {
+			return "", err
+		}
 		res = append(res, fmt.Sprintf("%s%c", p, l))
 	}
-	return res[0] + res[1] + res[2] + res[3]
+	return res[0] + res[1] + res[2] + res[3], nil
 }
 
 func unluhnify(s string) (string, error) {
@@ -106,7 +113,10 @@ func unluhnify(s string) (string, error) {
 	res := make([]string, 0, 4)
 	for i := 0; i < 4; i++ {
 		p := s[i*14 : (i+1)*14-1]
-		l := luhn.Base32.Generate(p)
+		l, err := luhn.Base32.Generate(p)
+		if err != nil {
+			return "", err
+		}
 		if g := fmt.Sprintf("%s%c", p, l); g != s[i*14:(i+1)*14] {
 			log.Printf("%q; %q", g, s[i*14:(i+1)*14])
 			return "", errors.New("check digit incorrect")