|
|
@@ -12,88 +12,54 @@ import (
|
|
|
"github.com/syncthing/syncthing/lib/dialer"
|
|
|
syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
|
|
|
"github.com/syncthing/syncthing/lib/relay/protocol"
|
|
|
- "github.com/syncthing/syncthing/lib/sync"
|
|
|
)
|
|
|
|
|
|
type staticClient struct {
|
|
|
- uri *url.URL
|
|
|
- invitations chan protocol.SessionInvitation
|
|
|
+ commonClient
|
|
|
|
|
|
- closeInvitationsOnFinish bool
|
|
|
+ uri *url.URL
|
|
|
|
|
|
config *tls.Config
|
|
|
|
|
|
messageTimeout time.Duration
|
|
|
connectTimeout time.Duration
|
|
|
|
|
|
- stop chan struct{}
|
|
|
- stopped chan struct{}
|
|
|
- stopMut sync.RWMutex
|
|
|
-
|
|
|
conn *tls.Conn
|
|
|
|
|
|
- mut sync.RWMutex
|
|
|
- err error
|
|
|
connected bool
|
|
|
latency time.Duration
|
|
|
}
|
|
|
|
|
|
func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) RelayClient {
|
|
|
- closeInvitationsOnFinish := false
|
|
|
- if invitations == nil {
|
|
|
- closeInvitationsOnFinish = true
|
|
|
- invitations = make(chan protocol.SessionInvitation)
|
|
|
- }
|
|
|
-
|
|
|
- stopped := make(chan struct{})
|
|
|
- close(stopped) // not yet started, don't block on Stop()
|
|
|
- return &staticClient{
|
|
|
- uri: uri,
|
|
|
- invitations: invitations,
|
|
|
-
|
|
|
- closeInvitationsOnFinish: closeInvitationsOnFinish,
|
|
|
+ c := &staticClient{
|
|
|
+ uri: uri,
|
|
|
|
|
|
config: configForCerts(certs),
|
|
|
|
|
|
messageTimeout: time.Minute * 2,
|
|
|
connectTimeout: timeout,
|
|
|
-
|
|
|
- stop: make(chan struct{}),
|
|
|
- stopped: stopped,
|
|
|
- stopMut: sync.NewRWMutex(),
|
|
|
-
|
|
|
- mut: sync.NewRWMutex(),
|
|
|
}
|
|
|
+ c.commonClient = newCommonClient(invitations, c.serve)
|
|
|
+ return c
|
|
|
}
|
|
|
|
|
|
-func (c *staticClient) Serve() {
|
|
|
- defer c.cleanup()
|
|
|
- c.stopMut.Lock()
|
|
|
- c.stop = make(chan struct{})
|
|
|
- c.stopped = make(chan struct{})
|
|
|
- c.stopMut.Unlock()
|
|
|
- defer close(c.stopped)
|
|
|
-
|
|
|
+func (c *staticClient) serve(stop chan struct{}) error {
|
|
|
if err := c.connect(); err != nil {
|
|
|
l.Infof("Could not connect to relay %s: %s", c.uri, err)
|
|
|
- c.setError(err)
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
l.Debugln(c, "connected", c.conn.RemoteAddr())
|
|
|
+ defer c.disconnect()
|
|
|
|
|
|
if err := c.join(); err != nil {
|
|
|
- c.conn.Close()
|
|
|
l.Infof("Could not join relay %s: %s", c.uri, err)
|
|
|
- c.setError(err)
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
|
|
- c.conn.Close()
|
|
|
l.Infoln("Relay set deadline:", err)
|
|
|
- c.setError(err)
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
l.Infof("Joined relay %s://%s", c.uri.Scheme, c.uri.Host)
|
|
|
@@ -106,12 +72,10 @@ func (c *staticClient) Serve() {
|
|
|
messages := make(chan interface{})
|
|
|
errors := make(chan error, 1)
|
|
|
|
|
|
- go messageReader(c.conn, messages, errors)
|
|
|
+ go messageReader(c.conn, messages, errors, stop)
|
|
|
|
|
|
timeout := time.NewTimer(c.messageTimeout)
|
|
|
|
|
|
- c.stopMut.RLock()
|
|
|
- defer c.stopMut.RUnlock()
|
|
|
for {
|
|
|
select {
|
|
|
case message := <-messages:
|
|
|
@@ -122,11 +86,9 @@ func (c *staticClient) Serve() {
|
|
|
case protocol.Ping:
|
|
|
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
|
|
|
l.Infoln("Relay write:", err)
|
|
|
- c.setError(err)
|
|
|
- c.disconnect()
|
|
|
- } else {
|
|
|
- l.Debugln(c, "sent pong")
|
|
|
+ return err
|
|
|
}
|
|
|
+ l.Debugln(c, "sent pong")
|
|
|
|
|
|
case protocol.SessionInvitation:
|
|
|
ip := net.IP(msg.Address)
|
|
|
@@ -137,52 +99,28 @@ func (c *staticClient) Serve() {
|
|
|
|
|
|
case protocol.RelayFull:
|
|
|
l.Infof("Disconnected from relay %s due to it becoming full.", c.uri)
|
|
|
- c.setError(fmt.Errorf("Relay full"))
|
|
|
- c.disconnect()
|
|
|
+ return fmt.Errorf("relay full")
|
|
|
|
|
|
default:
|
|
|
l.Infoln("Relay: protocol error: unexpected message %v", msg)
|
|
|
- c.setError(fmt.Errorf("protocol error: unexpected message %v", msg))
|
|
|
- c.disconnect()
|
|
|
+ return fmt.Errorf("protocol error: unexpected message %v", msg)
|
|
|
}
|
|
|
|
|
|
- case <-c.stop:
|
|
|
+ case <-stop:
|
|
|
l.Debugln(c, "stopping")
|
|
|
- c.setError(nil)
|
|
|
- c.disconnect()
|
|
|
+ return nil
|
|
|
|
|
|
- // We always exit via this branch of the select, to make sure the
|
|
|
- // the reader routine exits.
|
|
|
case err := <-errors:
|
|
|
- close(errors)
|
|
|
- close(messages)
|
|
|
- c.mut.Lock()
|
|
|
- if c.connected {
|
|
|
- c.conn.Close()
|
|
|
- c.connected = false
|
|
|
- l.Infof("Disconnecting from relay %s due to error: %s", c.uri, err)
|
|
|
- c.err = err
|
|
|
- } else {
|
|
|
- c.err = nil
|
|
|
- }
|
|
|
- c.mut.Unlock()
|
|
|
- return
|
|
|
+ l.Infof("Disconnecting from relay %s due to error: %s", c.uri, err)
|
|
|
+ return err
|
|
|
|
|
|
case <-timeout.C:
|
|
|
l.Debugln(c, "timed out")
|
|
|
- c.disconnect()
|
|
|
- c.setError(fmt.Errorf("timed out"))
|
|
|
+ return fmt.Errorf("timed out")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *staticClient) Stop() {
|
|
|
- c.stopMut.RLock()
|
|
|
- close(c.stop)
|
|
|
- <-c.stopped
|
|
|
- c.stopMut.RUnlock()
|
|
|
-}
|
|
|
-
|
|
|
func (c *staticClient) StatusOK() bool {
|
|
|
c.mut.RLock()
|
|
|
con := c.connected
|
|
|
@@ -205,22 +143,6 @@ func (c *staticClient) URI() *url.URL {
|
|
|
return c.uri
|
|
|
}
|
|
|
|
|
|
-func (c *staticClient) Invitations() chan protocol.SessionInvitation {
|
|
|
- c.mut.RLock()
|
|
|
- inv := c.invitations
|
|
|
- c.mut.RUnlock()
|
|
|
- return inv
|
|
|
-}
|
|
|
-
|
|
|
-func (c *staticClient) cleanup() {
|
|
|
- c.mut.Lock()
|
|
|
- if c.closeInvitationsOnFinish {
|
|
|
- close(c.invitations)
|
|
|
- c.invitations = make(chan protocol.SessionInvitation)
|
|
|
- }
|
|
|
- c.mut.Unlock()
|
|
|
-}
|
|
|
-
|
|
|
func (c *staticClient) connect() error {
|
|
|
if c.uri.Scheme != "relay" {
|
|
|
return fmt.Errorf("Unsupported relay schema: %v", c.uri.Scheme)
|
|
|
@@ -261,19 +183,6 @@ func (c *staticClient) disconnect() {
|
|
|
c.conn.Close()
|
|
|
}
|
|
|
|
|
|
-func (c *staticClient) setError(err error) {
|
|
|
- c.mut.Lock()
|
|
|
- c.err = err
|
|
|
- c.mut.Unlock()
|
|
|
-}
|
|
|
-
|
|
|
-func (c *staticClient) Error() error {
|
|
|
- c.mut.RLock()
|
|
|
- err := c.err
|
|
|
- c.mut.RUnlock()
|
|
|
- return err
|
|
|
-}
|
|
|
-
|
|
|
func (c *staticClient) join() error {
|
|
|
if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil {
|
|
|
return err
|
|
|
@@ -332,13 +241,17 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
|
|
|
+func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error, stop chan struct{}) {
|
|
|
for {
|
|
|
msg, err := protocol.ReadMessage(conn)
|
|
|
if err != nil {
|
|
|
errors <- err
|
|
|
return
|
|
|
}
|
|
|
- messages <- msg
|
|
|
+ select {
|
|
|
+ case messages <- msg:
|
|
|
+ case <-stop:
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
}
|