Browse Source

Use channel for pool

风扇滑翔翼 4 weeks ago
parent
commit
591ed82441
1 changed files with 72 additions and 26 deletions
  1. 72 26
      proxy/vless/outbound/outbound.go

+ 72 - 26
proxy/vless/outbound/outbound.go

@@ -54,9 +54,11 @@ type Handler struct {
 	encryption    *encryption.ClientInstance
 	reverse       *Reverse
 
-	testpre uint32
-	locker  sync.Mutex
-	conns   []stat.Connection
+	testpre     uint32
+	initConns   sync.Once
+	preConns    chan stat.Connection
+	preConnWait chan struct{}
+	preConnStop chan struct{}
 }
 
 // New creates a new VLess outbound handler.
@@ -117,6 +119,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 
 // Close implements common.Closable.Close().
 func (h *Handler) Close() error {
+	if h.preConnStop != nil {
+		close(h.preConnStop)
+		for range h.testpre {
+			conn := <-h.preConns
+			common.CloseIfExists(conn)
+		}
+	}
 	if h.reverse != nil {
 		return h.reverse.Close()
 	}
@@ -136,30 +145,19 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	var conn stat.Connection
 
 	if h.testpre > 0 && h.reverse == nil {
-		h.locker.Lock()
-		if h.conns == nil {
-			h.conns = make([]stat.Connection, 0)
-			go func() {
-				for { // TODO: close & inactive
-					time.Sleep(100 * time.Millisecond) // TODO: customize & randomize
-					h.locker.Lock()
-					if len(h.conns) >= int(h.testpre) {
-						h.locker.Unlock()
-						continue
-					}
-					h.locker.Unlock()
-					if conn, err := dialer.Dial(context.Background(), rec.Destination); err == nil { // TODO: timeout & concurrency? & ctx mitm?
-						h.locker.Lock()
-						h.conns = append(h.conns, conn) // TODO: vision paddings
-						h.locker.Unlock()
-					}
-				}
-			}()
-		} else if len(h.conns) > 0 {
-			conn = h.conns[0]
-			h.conns = h.conns[1:]
+		h.initConns.Do(func() {
+			h.preConns = make(chan stat.Connection, h.testpre)
+			h.preConnStop = make(chan struct{})
+			go h.preConnWorker(dialer, rec.Destination)
+		})
+		select {
+		case h.preConnWait <- struct{}{}:
+		default:
+		}
+		select {
+		case conn = <-h.preConns:
+		default:
 		}
-		h.locker.Unlock()
 	}
 
 	if conn == nil {
@@ -464,3 +462,51 @@ func (r *Reverse) Start() error {
 func (r *Reverse) Close() error {
 	return r.monitorTask.Close()
 }
+
+func (h *Handler) preConnWorker(dialer internet.Dialer, dest net.Destination) {
+	// conn in conns may be nil
+	conns := make(chan stat.Connection)
+	dial := func() {
+		conn, err := dialer.Dial(context.Background(), dest)
+		if err != nil {
+			errors.LogError(context.Background(), "failed to dial VLESS pre connection: ", err)
+			common.CloseIfExists(conn)
+		}
+		conns <- conn
+	}
+	go func() {
+		go dial() // get a conn immediately
+		for range h.testpre - 1 {
+			select {
+			case <-h.preConnWait:
+				go dial()
+			case <-h.preConnStop:
+				return
+			}
+		}
+	}()
+	for {
+		select {
+		case conn := <-conns:
+			if conn != nil {
+				select {
+				case h.preConns <- conn:
+				case <-h.preConnStop:
+					common.CloseIfExists(conn)
+					return
+				}
+				go dial()
+			} else {
+				// sleep until next client try if dial failed
+				select {
+				case <-h.preConnWait:
+					go dial()
+				case <-h.preConnStop:
+					return
+				}
+			}
+		case <-h.preConnStop:
+			return
+		}
+	}
+}