浏览代码

all: Refactor relay invitations (#7646)

Simon Frei 4 年之前
父节点
当前提交
713527facf

+ 1 - 1
cmd/strelaysrv/testutil/main.go

@@ -57,7 +57,7 @@ func main() {
 
 	if join {
 		log.Println("Creating client")
-		relay, err := client.NewClient(uri, []tls.Certificate{cert}, nil, 10*time.Second)
+		relay, err := client.NewClient(uri, []tls.Certificate{cert}, 10*time.Second)
 		if err != nil {
 			log.Fatal(err)
 		}

+ 17 - 16
lib/connections/relay_listen.go

@@ -44,38 +44,39 @@ type relayListener struct {
 }
 
 func (t *relayListener) serve(ctx context.Context) error {
-	clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, nil, 10*time.Second)
+	clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, 10*time.Second)
 	if err != nil {
 		l.Infoln("Listen (BEP/relay):", err)
 		return err
 	}
-	invitations := clnt.Invitations()
 
 	t.mut.Lock()
 	t.client = clnt
-	go clnt.Serve(ctx)
 	t.mut.Unlock()
 
-	// Start with nil, so that we send a addresses changed notification as soon as we connect somewhere.
-	var oldURI *url.URL
-
 	l.Infof("Relay listener (%v) starting", t)
 	defer l.Infof("Relay listener (%v) shutting down", t)
 	defer t.clearAddresses(t)
 
+	invitationCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	go t.handleInvitations(invitationCtx, clnt)
+
+	return clnt.Serve(ctx)
+}
+
+func (t *relayListener) handleInvitations(ctx context.Context, clnt client.RelayClient) {
+	invitations := clnt.Invitations()
+
+	// Start with nil, so that we send a addresses changed notification as soon as we connect somewhere.
+	var oldURI *url.URL
+
 	for {
 		select {
-		case inv, ok := <-invitations:
-			if !ok {
-				if err := clnt.Error(); err != nil {
-					l.Infoln("Listen (BEP/relay):", err)
-				}
-				return err
-			}
-
+		case inv := <-invitations:
 			conn, err := client.JoinSession(ctx, inv)
 			if err != nil {
-				if errors.Cause(err) != context.Canceled {
+				if !errors.Is(err, context.Canceled) {
 					l.Infoln("Listen (BEP/relay): joining session:", err)
 				}
 				continue
@@ -119,7 +120,7 @@ func (t *relayListener) serve(ctx context.Context) error {
 			}
 
 		case <-ctx.Done():
-			return ctx.Err()
+			return
 		}
 	}
 }

+ 5 - 21
lib/relay/client/client.go

@@ -35,21 +35,21 @@ type RelayClient interface {
 	URI() *url.URL
 }
 
-func NewClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) (RelayClient, error) {
+func NewClient(uri *url.URL, certs []tls.Certificate, timeout time.Duration) (RelayClient, error) {
 	factory, ok := supportedSchemes[uri.Scheme]
 	if !ok {
 		return nil, fmt.Errorf("unsupported scheme: %s", uri.Scheme)
 	}
 
+	invitations := make(chan protocol.SessionInvitation)
 	return factory(uri, certs, invitations, timeout), nil
 }
 
 type commonClient struct {
 	svcutil.ServiceWithError
 
-	invitations              chan protocol.SessionInvitation
-	closeInvitationsOnFinish bool
-	mut                      sync.RWMutex
+	invitations chan protocol.SessionInvitation
+	mut         sync.RWMutex
 }
 
 func newCommonClient(invitations chan protocol.SessionInvitation, serve func(context.Context) error, creator string) commonClient {
@@ -57,26 +57,10 @@ func newCommonClient(invitations chan protocol.SessionInvitation, serve func(con
 		invitations: invitations,
 		mut:         sync.NewRWMutex(),
 	}
-	newServe := func(ctx context.Context) error {
-		defer c.cleanup()
-		return serve(ctx)
-	}
-	c.ServiceWithError = svcutil.AsService(newServe, creator)
-	if c.invitations == nil {
-		c.closeInvitationsOnFinish = true
-		c.invitations = make(chan protocol.SessionInvitation)
-	}
+	c.ServiceWithError = svcutil.AsService(serve, creator)
 	return c
 }
 
-func (c *commonClient) cleanup() {
-	c.mut.Lock()
-	if c.closeInvitationsOnFinish {
-		close(c.invitations)
-	}
-	c.mut.Unlock()
-}
-
 func (c *commonClient) Invitations() chan protocol.SessionInvitation {
 	c.mut.RLock()
 	defer c.mut.RUnlock()

+ 9 - 5
lib/relay/client/methods.go

@@ -114,16 +114,20 @@ func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (ne
 
 func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error {
 	id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
-	invs := make(chan protocol.SessionInvitation, 1)
-	c, err := NewClient(uri, certs, invs, timeout)
+	c, err := NewClient(uri, certs, timeout)
 	if err != nil {
-		close(invs)
 		return fmt.Errorf("creating client: %w", err)
 	}
 	ctx, cancel := context.WithCancel(context.Background())
+	go c.Serve(ctx)
 	go func() {
-		c.Serve(ctx)
-		close(invs)
+		for {
+			select {
+			case <-c.Invitations():
+			case <-ctx.Done():
+				return
+			}
+		}
 	}()
 	defer cancel()
 

+ 6 - 1
lib/relay/client/static.go

@@ -98,7 +98,12 @@ func (c *staticClient) serve(ctx context.Context) error {
 				if len(ip) == 0 || ip.IsUnspecified() {
 					msg.Address = remoteIPBytes(c.conn)
 				}
-				c.invitations <- msg
+				select {
+				case c.invitations <- msg:
+				case <-ctx.Done():
+					l.Debugln(c, "stopping")
+					return ctx.Err()
+				}
 
 			case protocol.RelayFull:
 				l.Infof("Disconnected from relay %s due to it becoming full.", c.uri)