Explorar o código

Refine Trojan packet reader & writer (#142)

maskedeken %!s(int64=4) %!d(string=hai) anos
pai
achega
d5aeb6c545
Modificáronse 3 ficheiros con 21 adicións e 50 borrados
  1. 1 36
      proxy/trojan/protocol.go
  2. 8 7
      proxy/trojan/protocol_test.go
  3. 12 7
      proxy/trojan/server.go

+ 1 - 36
proxy/trojan/protocol.go

@@ -146,26 +146,6 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	return nil
 }
 
-// WriteMultiBufferWithMetadata writes udp packet with destination specified
-func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
-	for {
-		mb2, b := buf.SplitFirst(mb)
-		mb = mb2
-		if b == nil {
-			break
-		}
-		source := &dest
-		if b.UDP != nil {
-			source = b.UDP
-		}
-		if _, err := w.writePacket(b.Bytes(), *source); err != nil {
-			buf.ReleaseMulti(mb)
-			return err
-		}
-	}
-	return nil
-}
-
 func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
 	buffer := buf.StackNew()
 	defer buffer.Release()
@@ -259,12 +239,6 @@ func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	return buf.MultiBuffer{b}, err
 }
 
-// PacketPayload combines udp payload and destination
-type PacketPayload struct {
-	Target net.Destination
-	Buffer buf.MultiBuffer
-}
-
 // PacketReader is UDP Connection Reader Wrapper for trojan protocol
 type PacketReader struct {
 	io.Reader
@@ -272,15 +246,6 @@ type PacketReader struct {
 
 // ReadMultiBuffer implements buf.Reader
 func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
-	p, err := r.ReadMultiBufferWithMetadata()
-	if p != nil {
-		return p.Buffer, err
-	}
-	return nil, err
-}
-
-// ReadMultiBufferWithMetadata reads udp packet with destination
-func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 	addr, port, err := addrParser.ReadAddressPort(nil, r)
 	if err != nil {
 		return nil, newError("failed to read address and port").Base(err)
@@ -321,7 +286,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 		remain -= int(n)
 	}
 
-	return &PacketPayload{Target: dest, Buffer: mb}, nil
+	return mb, nil
 }
 
 func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error {

+ 8 - 7
proxy/trojan/protocol_test.go

@@ -71,21 +71,22 @@ func TestUDPRequest(t *testing.T) {
 	common.Must(connReader.ParseHeader())
 
 	packetReader := &PacketReader{Reader: connReader}
-	p, err := packetReader.ReadMultiBufferWithMetadata()
+	mb, err := packetReader.ReadMultiBuffer()
 	common.Must(err)
 
-	if p.Buffer.IsEmpty() {
+	if mb.IsEmpty() {
 		t.Error("no request data")
 	}
 
-	if r := cmp.Diff(p.Target, destination); r != "" {
+	mb2, b := buf.SplitFirst(mb)
+	defer buf.ReleaseMulti(mb2)
+
+	dest := *b.UDP
+	if r := cmp.Diff(dest, destination); r != "" {
 		t.Error("destination: ", r)
 	}
 
-	mb, decoded := buf.SplitFirst(p.Buffer)
-	buf.ReleaseMulti(mb)
-
-	if r := cmp.Diff(decoded.Bytes(), payload); r != "" {
+	if r := cmp.Diff(b.Bytes(), payload); r != "" {
 		t.Error("data: ", r)
 	}
 }

+ 12 - 7
proxy/trojan/server.go

@@ -250,7 +250,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 
 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
-		common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source))
+		udpPayload := packet.Payload
+		udpPayload.UDP = &packet.Source
+		common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}))
 	})
 
 	inbound := session.InboundFromContext(ctx)
@@ -263,7 +265,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 		case <-ctx.Done():
 			return nil
 		default:
-			p, err := clientReader.ReadMultiBufferWithMetadata()
+			mb, err := clientReader.ReadMultiBuffer()
 			if err != nil {
 				if errors.Cause(err) != io.EOF {
 					return newError("unexpected EOF").Base(err)
@@ -271,21 +273,24 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 				return nil
 			}
 
+			mb2, b := buf.SplitFirst(mb)
+			destination := *b.UDP
 			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 				From:   inbound.Source,
-				To:     p.Target,
+				To:     destination,
 				Status: log.AccessAccepted,
 				Reason: "",
 				Email:  user.Email,
 			})
-			newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))
+			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
 
 			if !buf.Cone || dest == nil {
-				dest = &p.Target
+				dest = &destination
 			}
 
-			for _, b := range p.Buffer {
-				udpServer.Dispatch(ctx, *dest, b)
+			udpServer.Dispatch(ctx, *dest, b) // first packet
+			for _, payload := range mb2 {
+				udpServer.Dispatch(ctx, *dest, payload)
 			}
 		}
 	}