浏览代码

Fix context in v2ray http transports

世界 7 月之前
父节点
当前提交
68ce9577c6
共有 5 个文件被更改,包括 20 次插入4 次删除
  1. 5 1
      log/id.go
  2. 10 0
      transport/v2rayhttp/conn.go
  3. 1 1
      transport/v2rayhttp/server.go
  4. 2 1
      transport/v2rayhttpupgrade/server.go
  5. 2 1
      transport/v2raywebsocket/server.go

+ 5 - 1
log/id.go

@@ -20,12 +20,16 @@ type ID struct {
 }
 
 func ContextWithNewID(ctx context.Context) context.Context {
-	return context.WithValue(ctx, (*idKey)(nil), ID{
+	return ContextWithID(ctx, ID{
 		ID:        rand.Uint32(),
 		CreatedAt: time.Now(),
 	})
 }
 
+func ContextWithID(ctx context.Context, id ID) context.Context {
+	return context.WithValue(ctx, (*idKey)(nil), id)
+}
+
 func IDFromContext(ctx context.Context) (ID, bool) {
 	id, loaded := ctx.Value((*idKey)(nil)).(ID)
 	return id, loaded

+ 10 - 0
transport/v2rayhttp/conn.go

@@ -2,6 +2,7 @@ package v2rayhttp
 
 import (
 	std_bufio "bufio"
+	"context"
 	"io"
 	"net"
 	"net/http"
@@ -10,6 +11,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/baderror"
 	"github.com/sagernet/sing/common/buf"
@@ -255,3 +257,11 @@ func (w *HTTP2ConnWrapper) Close() error {
 func (w *HTTP2ConnWrapper) Upstream() any {
 	return w.ExtendedConn
 }
+
+func DupContext(ctx context.Context) context.Context {
+	id, loaded := log.IDFromContext(ctx)
+	if !loaded {
+		return context.Background()
+	}
+	return log.ContextWithID(context.Background(), id)
+}

+ 1 - 1
transport/v2rayhttp/server.go

@@ -132,7 +132,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 		if requestBody != nil {
 			conn = bufio.NewCachedConn(conn, requestBody)
 		}
-		s.handler.NewConnectionEx(request.Context(), conn, source, M.Socksaddr{}, nil)
+		s.handler.NewConnectionEx(DupContext(request.Context()), conn, source, M.Socksaddr{}, nil)
 	} else {
 		writer.WriteHeader(http.StatusOK)
 		done := make(chan struct{})

+ 2 - 1
transport/v2rayhttpupgrade/server.go

@@ -12,6 +12,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/v2rayhttp"
 	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
@@ -110,7 +111,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 		s.invalidRequest(writer, request, http.StatusInternalServerError, E.Cause(err, "hijack failed"))
 		return
 	}
-	s.handler.NewConnectionEx(request.Context(), conn, sHttp.SourceAddress(request), M.Socksaddr{}, nil)
+	s.handler.NewConnectionEx(v2rayhttp.DupContext(request.Context()), conn, sHttp.SourceAddress(request), M.Socksaddr{}, nil)
 }
 
 func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) {

+ 2 - 1
transport/v2raywebsocket/server.go

@@ -13,6 +13,7 @@ import (
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/v2rayhttp"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
@@ -114,7 +115,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	if len(earlyData) > 0 {
 		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
 	}
-	s.handler.NewConnectionEx(request.Context(), conn, source, M.Socksaddr{}, nil)
+	s.handler.NewConnectionEx(v2rayhttp.DupContext(request.Context()), conn, source, M.Socksaddr{}, nil)
 }
 
 func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) {