瀏覽代碼

cmd/strelaysrv: Add optional auth token (fixes #3987) (#8561)

* implement authentication via token for relaysrv

Make replaysrv check for a token before allowing clients to
join. The token can be set via the replay-uri.

* fix formatting

* key composite literal

* do not error out if auth material is provided but not needed

* remove unused method receiver

* clean up unused parameter in functions

* cleaner token handling, disable joining the pool if token is set.

* Keep backwards compatibility with older clients.

In prior versions of the protocol JoinRelayRequest did not have a
token field. Trying to unmarshal such a request will result in
an error. Return an empty JoinRelayRequest, that is a request
without token, instead.

Co-authored-by: entity0xfe <[email protected]>
entity0xfe 3 年之前
父節點
當前提交
ad986f372d

+ 12 - 3
cmd/strelaysrv/listener.go

@@ -23,7 +23,7 @@ var (
 	numConnections int64
 )
 
-func listener(_, addr string, config *tls.Config) {
+func listener(_, addr string, config *tls.Config, token string) {
 	tcpListener, err := net.Listen("tcp", addr)
 	if err != nil {
 		log.Fatalln(err)
@@ -49,7 +49,7 @@ func listener(_, addr string, config *tls.Config) {
 		}
 
 		if isTLS {
-			go protocolConnectionHandler(conn, config)
+			go protocolConnectionHandler(conn, config, token)
 		} else {
 			go sessionConnectionHandler(conn)
 		}
@@ -57,7 +57,7 @@ func listener(_, addr string, config *tls.Config) {
 	}
 }
 
-func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
+func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config, token string) {
 	conn := tls.Server(tcpConn, config)
 	if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
 		if debug {
@@ -119,6 +119,15 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
 
 			switch msg := message.(type) {
 			case protocol.JoinRelayRequest:
+				if token != "" && msg.Token != token {
+					if debug {
+						log.Printf("invalid token %s\n", msg.Token)
+					}
+					protocol.WriteMessage(conn, protocol.ResponseWrongToken)
+					conn.Close()
+					continue
+				}
+
 				if atomic.LoadInt32(&overLimit) > 0 {
 					protocol.WriteMessage(conn, protocol.RelayFull{})
 					if debug {

+ 7 - 1
cmd/strelaysrv/main.go

@@ -56,6 +56,7 @@ var (
 	networkBufferSize int
 
 	statusAddr       string
+	token            string
 	poolAddrs        string
 	pools            []string
 	providedBy       string
@@ -89,6 +90,7 @@ func main() {
 	flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s")
 	flag.BoolVar(&debug, "debug", debug, "Enable debug output")
 	flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)")
+	flag.StringVar(&token, "token", "", "Token to restrict access to the relay (optional). Disables joining any pools.")
 	flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join")
 	flag.StringVar(&providedBy, "provided-by", "", "An optional description about who provides the relay")
 	flag.StringVar(&extAddress, "ext-address", "", "An optional address to advertise as being available on.\n\tAllows listening on an unprivileged port with port forwarding from e.g. 443, and be connected to on port 443.")
@@ -256,6 +258,10 @@ func main() {
 
 	log.Println("URI:", uri.String())
 
+	if token != "" {
+		poolAddrs = ""
+	}
+
 	if poolAddrs == defaultPoolAddrs {
 		log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
 		log.Println("!!  Joining default relay pools, this relay will be available for public use. !!")
@@ -271,7 +277,7 @@ func main() {
 		}
 	}
 
-	go listener(proto, listen, tlsCfg)
+	go listener(proto, listen, tlsCfg, token)
 
 	sigs := make(chan os.Signal, 1)
 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

+ 5 - 2
lib/relay/client/static.go

@@ -27,7 +27,8 @@ type staticClient struct {
 	messageTimeout time.Duration
 	connectTimeout time.Duration
 
-	conn *tls.Conn
+	conn  *tls.Conn
+	token string
 }
 
 func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) *staticClient {
@@ -38,6 +39,8 @@ func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan pro
 
 		messageTimeout: time.Minute * 2,
 		connectTimeout: timeout,
+
+		token: uri.Query().Get("token"),
 	}
 	c.commonClient = newCommonClient(invitations, c.serve, c.String())
 	return c
@@ -173,7 +176,7 @@ func (c *staticClient) disconnect() {
 }
 
 func (c *staticClient) join() error {
-	if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil {
+	if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{Token: c.token}); err != nil {
 		return err
 	}
 

+ 4 - 1
lib/relay/protocol/packets.go

@@ -31,9 +31,12 @@ type header struct {
 
 type Ping struct{}
 type Pong struct{}
-type JoinRelayRequest struct{}
 type RelayFull struct{}
 
+type JoinRelayRequest struct {
+	Token string
+}
+
 type JoinSessionRequest struct {
 	Key []byte // max:32
 }

+ 41 - 24
lib/relay/protocol/packets_xdr.go

@@ -137,70 +137,87 @@ func (*Pong) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error {
 
 /*
 
-JoinRelayRequest Structure:
+RelayFull Structure:
 (contains no fields)
 
 
-struct JoinRelayRequest {
+struct RelayFull {
 }
 
 */
 
-func (JoinRelayRequest) XDRSize() int {
+func (RelayFull) XDRSize() int {
 	return 0
 }
-func (JoinRelayRequest) MarshalXDR() ([]byte, error) {
+func (RelayFull) MarshalXDR() ([]byte, error) {
 	return nil, nil
 }
 
-func (JoinRelayRequest) MustMarshalXDR() []byte {
+func (RelayFull) MustMarshalXDR() []byte {
 	return nil
 }
 
-func (JoinRelayRequest) MarshalXDRInto(_ *xdr.Marshaller) error {
+func (RelayFull) MarshalXDRInto(_ *xdr.Marshaller) error {
 	return nil
 }
 
-func (*JoinRelayRequest) UnmarshalXDR(_ []byte) error {
+func (*RelayFull) UnmarshalXDR(_ []byte) error {
 	return nil
 }
 
-func (*JoinRelayRequest) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error {
+func (*RelayFull) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error {
 	return nil
 }
 
 /*
 
-RelayFull Structure:
-(contains no fields)
+JoinRelayRequest Structure:
+
+ 0                   1                   2                   3
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/                                                               /
+\                 Token (length + padded data)                  \
+/                                                               /
++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 
 
-struct RelayFull {
+struct JoinRelayRequest {
+	string Token<>;
 }
 
 */
 
-func (RelayFull) XDRSize() int {
-	return 0
-}
-func (RelayFull) MarshalXDR() ([]byte, error) {
-	return nil, nil
+func (o JoinRelayRequest) XDRSize() int {
+	return 4 + len(o.Token) + xdr.Padding(len(o.Token))
 }
 
-func (RelayFull) MustMarshalXDR() []byte {
-	return nil
+func (o JoinRelayRequest) MarshalXDR() ([]byte, error) {
+	buf := make([]byte, o.XDRSize())
+	m := &xdr.Marshaller{Data: buf}
+	return buf, o.MarshalXDRInto(m)
 }
 
-func (RelayFull) MarshalXDRInto(_ *xdr.Marshaller) error {
-	return nil
+func (o JoinRelayRequest) MustMarshalXDR() []byte {
+	bs, err := o.MarshalXDR()
+	if err != nil {
+		panic(err)
+	}
+	return bs
 }
 
-func (*RelayFull) UnmarshalXDR(_ []byte) error {
-	return nil
+func (o JoinRelayRequest) MarshalXDRInto(m *xdr.Marshaller) error {
+	m.MarshalString(o.Token)
+	return m.Error
 }
 
-func (*RelayFull) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error {
-	return nil
+func (o *JoinRelayRequest) UnmarshalXDR(bs []byte) error {
+	u := &xdr.Unmarshaller{Data: bs}
+	return o.UnmarshalXDRFrom(u)
+}
+func (o *JoinRelayRequest) UnmarshalXDRFrom(u *xdr.Unmarshaller) error {
+	o.Token = u.UnmarshalString()
+	return u.Error
 }
 
 /*

+ 9 - 0
lib/relay/protocol/protocol.go

@@ -17,6 +17,7 @@ var (
 	ResponseSuccess           = Response{0, "success"}
 	ResponseNotFound          = Response{1, "not found"}
 	ResponseAlreadyConnected  = Response{2, "already connected"}
+	ResponseWrongToken        = Response{3, "wrong token"}
 	ResponseUnexpectedMessage = Response{100, "unexpected message"}
 )
 
@@ -107,6 +108,14 @@ func ReadMessage(r io.Reader) (interface{}, error) {
 		return msg, err
 	case messageTypeJoinRelayRequest:
 		var msg JoinRelayRequest
+
+		// In prior versions of the protocol JoinRelayRequest did not have a
+		// token field. Trying to unmarshal such a request will result in
+		// an error, return msg with an empty token instead.
+		if header.messageLength == 0 {
+			return msg, nil
+		}
+
 		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeJoinSessionRequest: