浏览代码

lib/connections: Dialer code deduplication (#6187)

Simon Frei 5 年之前
父节点
当前提交
33258b06f4
共有 5 个文件被更改,包括 34 次插入37 次删除
  1. 6 11
      lib/connections/quic_dial.go
  2. 8 12
      lib/connections/relay_dial.go
  3. 1 1
      lib/connections/service.go
  4. 11 1
      lib/connections/structs.go
  5. 8 12
      lib/connections/tcp_dial.go

+ 6 - 11
lib/connections/quic_dial.go

@@ -39,8 +39,7 @@ func init() {
 }
 
 type quicDialer struct {
-	cfg    config.Wrapper
-	tlsCfg *tls.Config
+	commonDialer
 }
 
 func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
@@ -91,20 +90,16 @@ func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, erro
 	return internalConn{&quicTlsConn{session, stream, createdConn}, connTypeQUICClient, quicPriority}, nil
 }
 
-func (d *quicDialer) RedialFrequency() time.Duration {
-	return time.Duration(d.cfg.Options().ReconnectIntervalS) * time.Second
-}
-
 type quicDialerFactory struct {
 	cfg    config.Wrapper
 	tlsCfg *tls.Config
 }
 
-func (quicDialerFactory) New(cfg config.Wrapper, tlsCfg *tls.Config) genericDialer {
-	return &quicDialer{
-		cfg:    cfg,
-		tlsCfg: tlsCfg,
-	}
+func (quicDialerFactory) New(opts config.OptionsConfiguration, tlsCfg *tls.Config) genericDialer {
+	return &quicDialer{commonDialer{
+		reconnectInterval: time.Duration(opts.ReconnectIntervalS) * time.Second,
+		tlsCfg:            tlsCfg,
+	}}
 }
 
 func (quicDialerFactory) Priority() int {

+ 8 - 12
lib/connections/relay_dial.go

@@ -24,8 +24,7 @@ func init() {
 }
 
 type relayDialer struct {
-	cfg    config.Wrapper
-	tlsCfg *tls.Config
+	commonDialer
 }
 
 func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
@@ -45,7 +44,7 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, er
 		return internalConn{}, err
 	}
 
-	err = dialer.SetTrafficClass(conn, d.cfg.Options().TrafficClass)
+	err = dialer.SetTrafficClass(conn, d.trafficClass)
 	if err != nil {
 		l.Debugln("Dial (BEP/relay): setting traffic class:", err)
 	}
@@ -66,17 +65,14 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, er
 	return internalConn{tc, connTypeRelayClient, relayPriority}, nil
 }
 
-func (d *relayDialer) RedialFrequency() time.Duration {
-	return time.Duration(d.cfg.Options().RelayReconnectIntervalM) * time.Minute
-}
-
 type relayDialerFactory struct{}
 
-func (relayDialerFactory) New(cfg config.Wrapper, tlsCfg *tls.Config) genericDialer {
-	return &relayDialer{
-		cfg:    cfg,
-		tlsCfg: tlsCfg,
-	}
+func (relayDialerFactory) New(opts config.OptionsConfiguration, tlsCfg *tls.Config) genericDialer {
+	return &relayDialer{commonDialer{
+		trafficClass:      opts.TrafficClass,
+		reconnectInterval: time.Duration(opts.RelayReconnectIntervalM) * time.Minute,
+		tlsCfg:            tlsCfg,
+	}}
 }
 
 func (relayDialerFactory) Priority() int {

+ 1 - 1
lib/connections/service.go

@@ -442,7 +442,7 @@ func (s *service) connect(ctx context.Context) {
 					continue
 				}
 
-				dialer := dialerFactory.New(s.cfg, s.tlsCfg)
+				dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
 				nextDial[nextDialKey] = now.Add(dialer.RedialFrequency())
 
 				// For LAN addresses, increase the priority so that we

+ 11 - 1
lib/connections/structs.go

@@ -146,13 +146,23 @@ func (c internalConn) String() string {
 }
 
 type dialerFactory interface {
-	New(config.Wrapper, *tls.Config) genericDialer
+	New(config.OptionsConfiguration, *tls.Config) genericDialer
 	Priority() int
 	AlwaysWAN() bool
 	Valid(config.Configuration) error
 	String() string
 }
 
+type commonDialer struct {
+	trafficClass      int
+	reconnectInterval time.Duration
+	tlsCfg            *tls.Config
+}
+
+func (d *commonDialer) RedialFrequency() time.Duration {
+	return d.reconnectInterval
+}
+
 type genericDialer interface {
 	Dial(protocol.DeviceID, *url.URL) (internalConn, error)
 	RedialFrequency() time.Duration

+ 8 - 12
lib/connections/tcp_dial.go

@@ -26,8 +26,7 @@ func init() {
 }
 
 type tcpDialer struct {
-	cfg    config.Wrapper
-	tlsCfg *tls.Config
+	commonDialer
 }
 
 func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
@@ -43,7 +42,7 @@ func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error
 		l.Debugln("Dial (BEP/tcp): setting tcp options:", err)
 	}
 
-	err = dialer.SetTrafficClass(conn, d.cfg.Options().TrafficClass)
+	err = dialer.SetTrafficClass(conn, d.trafficClass)
 	if err != nil {
 		l.Debugln("Dial (BEP/tcp): setting traffic class:", err)
 	}
@@ -58,17 +57,14 @@ func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error
 	return internalConn{tc, connTypeTCPClient, tcpPriority}, nil
 }
 
-func (d *tcpDialer) RedialFrequency() time.Duration {
-	return time.Duration(d.cfg.Options().ReconnectIntervalS) * time.Second
-}
-
 type tcpDialerFactory struct{}
 
-func (tcpDialerFactory) New(cfg config.Wrapper, tlsCfg *tls.Config) genericDialer {
-	return &tcpDialer{
-		cfg:    cfg,
-		tlsCfg: tlsCfg,
-	}
+func (tcpDialerFactory) New(opts config.OptionsConfiguration, tlsCfg *tls.Config) genericDialer {
+	return &tcpDialer{commonDialer{
+		trafficClass:      opts.TrafficClass,
+		reconnectInterval: time.Duration(opts.ReconnectIntervalS) * time.Second,
+		tlsCfg:            tlsCfg,
+	}}
 }
 
 func (tcpDialerFactory) Priority() int {