Преглед изворни кода

fix(transport): correctly release UDS locker file (#2305)

* fix(transport): correctly release UDS locker file

* use callback function to do some jobs after create listener
A1lo пре 2 година
родитељ
комит
10d6b06578

+ 0 - 5
transport/internet/grpc/hub.go

@@ -23,7 +23,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	handler internet.ConnHandler
 	local   net.Addr
 	local   net.Addr
 	config  *Config
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 
 
 	s *grpc.Server
 	s *grpc.Server
 }
 }
@@ -110,10 +109,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 				return
 			}
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),
 				IP:   address.IP(),

+ 0 - 8
transport/internet/http/hub.go

@@ -27,7 +27,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	handler internet.ConnHandler
 	local   net.Addr
 	local   net.Addr
 	config  *Config
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 }
 }
 
 
 func (l *Listener) Addr() net.Addr {
 func (l *Listener) Addr() net.Addr {
@@ -35,9 +34,6 @@ func (l *Listener) Addr() net.Addr {
 }
 }
 
 
 func (l *Listener) Close() error {
 func (l *Listener) Close() error {
-	if l.locker != nil {
-		l.locker.Release()
-	}
 	return l.server.Close()
 	return l.server.Close()
 }
 }
 
 
@@ -180,10 +176,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 				return
 			}
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),
 				IP:   address.IP(),

+ 49 - 20
transport/internet/system_listener.go

@@ -21,6 +21,19 @@ type DefaultListener struct {
 	controllers []control.Func
 	controllers []control.Func
 }
 }
 
 
+type combinedListener struct {
+	net.Listener
+	locker *FileLocker // for unix domain socket
+}
+
+func (cl *combinedListener) Close() error {
+	if cl.locker != nil {
+		cl.locker.Release()
+		cl.locker = nil
+	}
+	return cl.Listener.Close()
+}
+
 func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error {
 func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error {
 	return func(network, address string, c syscall.RawConn) error {
 	return func(network, address string, c syscall.RawConn) error {
 		return c.Control(func(fd uintptr) {
 		return c.Control(func(fd uintptr) {
@@ -44,6 +57,10 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
 func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
 func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
 	var lc net.ListenConfig
 	var lc net.ListenConfig
 	var network, address string
 	var network, address string
+	// callback is called after the Listen function returns
+	callback := func(l net.Listener, err error) (net.Listener, error) {
+		return l, err
+	}
 
 
 	switch addr := addr.(type) {
 	switch addr := addr.(type) {
 	case *net.TCPAddr:
 	case *net.TCPAddr:
@@ -58,23 +75,6 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 		network = addr.Network()
 		network = addr.Network()
 		address = addr.Name
 		address = addr.Name
 
 
-		if s := strings.Split(address, ","); len(s) == 2 {
-			address = s[0]
-			perm, perr := strconv.ParseUint(s[1], 8, 32)
-			if perr != nil {
-				return nil, newError("failed to parse permission: " + s[1]).Base(perr)
-			}
-
-			defer func(file string, permission os.FileMode) {
-				if err == nil {
-					cerr := os.Chmod(address, permission)
-					if cerr != nil {
-						err = newError("failed to set permission for " + file).Base(cerr)
-					}
-				}
-			}(address, os.FileMode(perm))
-		}
-
 		if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' {
 		if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' {
 			// linux abstract unix domain socket is lockfree
 			// linux abstract unix domain socket is lockfree
 			if len(address) > 1 && address[1] == '@' {
 			if len(address) > 1 && address[1] == '@' {
@@ -84,19 +84,48 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 				address = string(fullAddr)
 				address = string(fullAddr)
 			}
 			}
 		} else {
 		} else {
+			// split permission from address
+			var filePerm *os.FileMode
+			if s := strings.Split(address, ","); len(s) == 2 {
+				address = s[0]
+				perm, perr := strconv.ParseUint(s[1], 8, 32)
+				if perr != nil {
+					return nil, newError("failed to parse permission: " + s[1]).Base(perr)
+				}
+
+				mode := os.FileMode(perm)
+				filePerm = &mode
+			}
 			// normal unix domain socket needs lock
 			// normal unix domain socket needs lock
 			locker := &FileLocker{
 			locker := &FileLocker{
 				path: address + ".lock",
 				path: address + ".lock",
 			}
 			}
-			err := locker.Acquire()
-			if err != nil {
+			if err := locker.Acquire(); err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-			ctx = context.WithValue(ctx, address, locker)
+
+			// set callback to combine listener and set permission
+			callback = func(l net.Listener, err error) (net.Listener, error) {
+				if err != nil {
+					locker.Release()
+					return l, err
+				}
+				l = &combinedListener{Listener: l, locker: locker}
+				if filePerm == nil {
+					return l, nil
+				}
+				err = os.Chmod(address, *filePerm)
+				if err != nil {
+					l.Close()
+					return nil, newError("failed to set permission for " + address).Base(err)
+				}
+				return l, nil
+			}
 		}
 		}
 	}
 	}
 
 
 	l, err = lc.Listen(ctx, network, address)
 	l, err = lc.Listen(ctx, network, address)
+	l, err = callback(l, err)
 	if sockopt != nil && sockopt.AcceptProxyProtocol {
 	if sockopt != nil && sockopt.AcceptProxyProtocol {
 		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
 		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
 		l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
 		l = &proxyproto.Listener{Listener: l, Policy: policyFunc}

+ 0 - 8
transport/internet/tcp/hub.go

@@ -24,7 +24,6 @@ type Listener struct {
 	authConfig    internet.ConnectionAuthenticator
 	authConfig    internet.ConnectionAuthenticator
 	config        *Config
 	config        *Config
 	addConn       internet.ConnHandler
 	addConn       internet.ConnHandler
-	locker        *internet.FileLocker // for unix domain socket
 }
 }
 
 
 // ListenTCP creates a new Listener based on configurations.
 // ListenTCP creates a new Listener based on configurations.
@@ -51,10 +50,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe
 			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
 			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
 		}
 		}
 		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else {
 	} else {
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			IP:   address.IP(),
@@ -133,9 +128,6 @@ func (v *Listener) Addr() net.Addr {
 
 
 // Close implements internet.Listener.Close.
 // Close implements internet.Listener.Close.
 func (v *Listener) Close() error {
 func (v *Listener) Close() error {
-	if v.locker != nil {
-		v.locker.Release()
-	}
 	return v.listener.Close()
 	return v.listener.Close()
 }
 }
 
 

+ 0 - 8
transport/internet/websocket/hub.go

@@ -75,7 +75,6 @@ type Listener struct {
 	listener net.Listener
 	listener net.Listener
 	config   *Config
 	config   *Config
 	addConn  internet.ConnHandler
 	addConn  internet.ConnHandler
-	locker   *internet.FileLocker // for unix domain socket
 }
 }
 
 
 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
@@ -101,10 +100,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 			return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
 			return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
 		}
 		}
 		newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else { // tcp
 	} else { // tcp
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			IP:   address.IP(),
@@ -153,9 +148,6 @@ func (ln *Listener) Addr() net.Addr {
 
 
 // Close implements net.Listener.Close().
 // Close implements net.Listener.Close().
 func (ln *Listener) Close() error {
 func (ln *Listener) Close() error {
-	if ln.locker != nil {
-		ln.locker.Release()
-	}
 	return ln.listener.Close()
 	return ln.listener.Close()
 }
 }