Selaa lähdekoodia

cmd/strelaypoolsrv: Expose check error to client, fix incorrect response code handling

Jakob Borg 5 vuotta sitten
vanhempi
sitoutus
362da59396

+ 3 - 4
cmd/strelaypoolsrv/main.go

@@ -10,7 +10,6 @@ import (
 	"context"
 	"crypto/tls"
 	"encoding/json"
-	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -491,11 +490,11 @@ func handleRelayTest(request request) {
 	if debug {
 		log.Println("Request for", request.relay)
 	}
-	if !client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) {
+	if err := client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3); err != nil {
 		if debug {
-			log.Println("Test for relay", request.relay, "failed")
+			log.Println("Test for relay", request.relay, "failed:", err)
 		}
-		request.result <- result{errors.New("connection test failed"), 0}
+		request.result <- result{err, 0}
 		return
 	}
 

+ 2 - 2
cmd/strelaysrv/testutil/main.go

@@ -107,10 +107,10 @@ func main() {
 		connectToStdio(stdin, conn)
 		log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
 	} else if test {
-		if client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) {
+		if err := client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4); err == nil {
 			log.Println("OK")
 		} else {
-			log.Println("FAIL")
+			log.Println("FAIL:", err)
 		}
 	} else {
 		log.Fatal("Requires either join or connect")

+ 19 - 9
lib/relay/client/methods.go

@@ -5,11 +5,11 @@ package client
 import (
 	"context"
 	"crypto/tls"
+	"errors"
 	"fmt"
 	"net"
 	"net/url"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/syncthing/syncthing/lib/dialer"
@@ -17,6 +17,15 @@ import (
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 )
 
+type incorrectResponseCodeErr struct {
+	code int32
+	msg  string
+}
+
+func (e incorrectResponseCodeErr) Error() string {
+	return fmt.Sprintf("incorrect response code %d: %s", e.code, e.msg)
+}
+
 func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) {
 	if uri.Scheme != "relay" {
 		return protocol.SessionInvitation{}, fmt.Errorf("unsupported relay scheme: %v", uri.Scheme)
@@ -53,7 +62,7 @@ func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingproto
 
 	switch msg := message.(type) {
 	case protocol.Response:
-		return protocol.SessionInvitation{}, fmt.Errorf("incorrect response code %d: %s", msg.Code, msg.Message)
+		return protocol.SessionInvitation{}, incorrectResponseCodeErr{msg.Code, msg.Message}
 	case protocol.SessionInvitation:
 		l.Debugln("Received invitation", msg, "via", conn.LocalAddr())
 		ip := net.IP(msg.Address)
@@ -104,13 +113,13 @@ func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (ne
 	}
 }
 
-func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool {
+func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error {
 	id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
 	invs := make(chan protocol.SessionInvitation, 1)
 	c, err := NewClient(uri, certs, invs, timeout)
 	if err != nil {
 		close(invs)
-		return false
+		return fmt.Errorf("creating client: %w", err)
 	}
 	go c.Serve()
 	defer func() {
@@ -119,16 +128,17 @@ func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep
 	}()
 
 	for i := 0; i < times; i++ {
-		_, err := GetInvitationFromRelay(ctx, uri, id, certs, timeout)
+		_, err = GetInvitationFromRelay(ctx, uri, id, certs, timeout)
 		if err == nil {
-			return true
+			return nil
 		}
-		if !strings.Contains(err.Error(), "Incorrect response code") {
-			return false
+		if !errors.As(err, &incorrectResponseCodeErr{}) {
+			return fmt.Errorf("getting invitation: %w", err)
 		}
 		time.Sleep(sleep)
 	}
-	return false
+
+	return fmt.Errorf("getting invitation: %w", err) // last of the above errors
 }
 
 func configForCerts(certs []tls.Certificate) *tls.Config {

+ 1 - 1
lib/relay/client/static.go

@@ -201,7 +201,7 @@ func (c *staticClient) join() error {
 	switch msg := message.(type) {
 	case protocol.Response:
 		if msg.Code != 0 {
-			return fmt.Errorf("incorrect response code %d: %s", msg.Code, msg.Message)
+			return incorrectResponseCodeErr{msg.Code, msg.Message}
 		}
 
 	case protocol.RelayFull: