Browse Source

DispatchLink(): Fix user stats

Fixes https://github.com/XTLS/Xray-core/pull/5076#issuecomment-3243431593
RPRX 3 months ago
parent
commit
d20397c15d

+ 42 - 0
app/dispatcher/default.go

@@ -196,6 +196,47 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran
 	return inboundLink, outboundLink
 }
 
+func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link {
+	sessionInbound := session.InboundFromContext(ctx)
+	var user *protocol.MemoryUser
+	if sessionInbound != nil {
+		user = sessionInbound.User
+	}
+
+	link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader}
+
+	if user != nil && len(user.Email) > 0 {
+		p := d.policy.ForLevel(user.Level)
+		if p.Stats.UserUplink {
+			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
+			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
+				link.Reader.(*buf.TimeoutWrapperReader).Counter = c
+			}
+		}
+		if p.Stats.UserDownlink {
+			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
+			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
+				link.Writer = &SizeStatWriter{
+					Counter: c,
+					Writer:  link.Writer,
+				}
+			}
+		}
+		if p.Stats.UserOnline {
+			name := "user>>>" + user.Email + ">>>online"
+			if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
+				sessionInbounds := session.InboundFromContext(ctx)
+				userIP := sessionInbounds.Source.Address.String()
+				om.AddIP(userIP)
+				// log Online user with ips
+				// errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List())
+			}
+		}
+	}
+
+	return link
+}
+
 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
 	domain := result.Domain()
 	if domain == "" {
@@ -316,6 +357,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		content = new(session.Content)
 		ctx = session.ContextWithContent(ctx, content)
 	}
+	outbound = d.WrapLink(ctx, outbound)
 	sniffingRequest := content.SniffingRequest
 	if !sniffingRequest.Enabled {
 		d.routedDispatch(ctx, outbound, destination)

+ 2 - 0
app/reverse/bridge.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"time"
 
+	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
@@ -200,6 +201,7 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l
 		return w.dispatcher.DispatchLink(ctx, dest, link)
 	}
 
+	link = w.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
 	w.handleInternalConn(link)
 
 	return nil

+ 12 - 3
common/buf/io.go

@@ -30,6 +30,7 @@ type TimeoutReader interface {
 
 type TimeoutWrapperReader struct {
 	Reader
+	stats.Counter
 	mb   MultiBuffer
 	err  error
 	done chan struct{}
@@ -39,11 +40,16 @@ func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
 	if r.done != nil {
 		<-r.done
 		r.done = nil
+		if r.Counter != nil {
+			r.Counter.Add(int64(r.mb.Len()))
+		}
 		return r.mb, r.err
 	}
-	r.mb = nil
-	r.err = nil
-	return r.Reader.ReadMultiBuffer()
+	r.mb, r.err = r.Reader.ReadMultiBuffer()
+	if r.Counter != nil {
+		r.Counter.Add(int64(r.mb.Len()))
+	}
+	return r.mb, r.err
 }
 
 func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
@@ -62,6 +68,9 @@ func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (M
 	select {
 	case <-r.done:
 		r.done = nil
+		if r.Counter != nil {
+			r.Counter.Add(int64(r.mb.Len()))
+		}
 		return r.mb, r.err
 	case <-timeout:
 		return nil, nil

+ 2 - 0
common/mux/server.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"io"
 
+	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
@@ -61,6 +62,7 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
 	if dest.Address != muxCoolAddress {
 		return s.dispatcher.DispatchLink(ctx, dest, link)
 	}
+	link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
 	_, err := NewServerWorker(ctx, s.dispatcher, link)
 	return err
 }

+ 1 - 1
proxy/dokodemo/dokodemo.go

@@ -182,7 +182,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 	}
 
 	if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
-		Reader: &buf.TimeoutWrapperReader{Reader: reader},
+		Reader: reader,
 		Writer: writer},
 	); err != nil {
 		return errors.New("failed to dispatch request").Base(err)

+ 1 - 1
proxy/http/server.go

@@ -193,7 +193,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *buf
 		inbound.CanSpliceCopy = 1
 	}
 	if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
-		Reader: &buf.TimeoutWrapperReader{Reader: reader},
+		Reader: reader,
 		Writer: buf.NewWriter(conn)},
 	); err != nil {
 		return errors.New("failed to dispatch request").Base(err)

+ 1 - 1
proxy/socks/server.go

@@ -161,7 +161,7 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
 			inbound.CanSpliceCopy = 1
 		}
 		if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
-			Reader: &buf.TimeoutWrapperReader{Reader: reader},
+			Reader: reader,
 			Writer: buf.NewWriter(conn)},
 		); err != nil {
 			return errors.New("failed to dispatch request").Base(err)

+ 1 - 1
proxy/vless/inbound/inbound.go

@@ -563,7 +563,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 	bufferWriter.SetFlushNext()
 
 	if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{
-		Reader: &buf.TimeoutWrapperReader{Reader: clientReader},
+		Reader: clientReader,
 		Writer: clientWriter},
 	); err != nil {
 		return errors.New("failed to dispatch request").Base(err)