|
|
@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
|
|
sessionPolicy = s.policyManager.ForLevel(user.Level)
|
|
|
|
|
|
if destination.Network == net.Network_UDP { // handle udp request
|
|
|
- return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
|
|
|
+ return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
|
|
|
}
|
|
|
|
|
|
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
|
|
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
|
|
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
|
|
|
}
|
|
|
|
|
|
-func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
|
|
|
+func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
|
|
|
+ ctx, cancel := context.WithCancel(ctx)
|
|
|
+ defer cancel()
|
|
|
+ timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
|
|
|
+ defer timer.SetTimeout(0)
|
|
|
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
|
|
|
udpPayload := packet.Payload
|
|
|
if udpPayload.UDP == nil {
|
|
|
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
|
|
|
|
|
|
if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
|
|
|
errors.LogWarningInner(ctx, err, "failed to write response")
|
|
|
+ cancel()
|
|
|
+ } else {
|
|
|
+ timer.Update()
|
|
|
}
|
|
|
})
|
|
|
defer udpServer.RemoveRay()
|
|
|
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
|
|
|
|
|
|
var dest *net.Destination
|
|
|
|
|
|
- for {
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- return nil
|
|
|
- default:
|
|
|
- mb, err := clientReader.ReadMultiBuffer()
|
|
|
- if err != nil {
|
|
|
- if errors.Cause(err) != io.EOF {
|
|
|
- return errors.New("unexpected EOF").Base(err)
|
|
|
- }
|
|
|
+ requestDone := func() error {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
return nil
|
|
|
- }
|
|
|
+ default:
|
|
|
+ mb, err := clientReader.ReadMultiBuffer()
|
|
|
+ if err != nil {
|
|
|
+ if errors.Cause(err) != io.EOF {
|
|
|
+ return errors.New("unexpected EOF").Base(err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
|
|
|
- mb2, b := buf.SplitFirst(mb)
|
|
|
- if b == nil {
|
|
|
- continue
|
|
|
- }
|
|
|
- destination := *b.UDP
|
|
|
-
|
|
|
- currentPacketCtx := ctx
|
|
|
- if inbound.Source.IsValid() {
|
|
|
- currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
|
|
- From: inbound.Source,
|
|
|
- To: destination,
|
|
|
- Status: log.AccessAccepted,
|
|
|
- Reason: "",
|
|
|
- Email: user.Email,
|
|
|
- })
|
|
|
- }
|
|
|
- errors.LogInfo(ctx, "tunnelling request to ", destination)
|
|
|
+ mb2, b := buf.SplitFirst(mb)
|
|
|
+ if b == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ timer.Update()
|
|
|
+ destination := *b.UDP
|
|
|
+
|
|
|
+ currentPacketCtx := ctx
|
|
|
+ if inbound.Source.IsValid() {
|
|
|
+ currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
|
|
+ From: inbound.Source,
|
|
|
+ To: destination,
|
|
|
+ Status: log.AccessAccepted,
|
|
|
+ Reason: "",
|
|
|
+ Email: user.Email,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ errors.LogInfo(ctx, "tunnelling request to ", destination)
|
|
|
|
|
|
- if !s.cone || dest == nil {
|
|
|
- dest = &destination
|
|
|
- }
|
|
|
+ if !s.cone || dest == nil {
|
|
|
+ dest = &destination
|
|
|
+ }
|
|
|
|
|
|
- udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
|
|
|
- for _, payload := range mb2 {
|
|
|
- udpServer.Dispatch(currentPacketCtx, *dest, payload)
|
|
|
+ udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
|
|
|
+ for _, payload := range mb2 {
|
|
|
+ udpServer.Dispatch(currentPacketCtx, *dest, payload)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := task.Run(ctx, requestDone); err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
|