| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- package wireguard
- import (
- "io"
- "github.com/sagernet/sing/common/buf"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/wireguard-go/conn"
- )
- var _ conn.Bind = (*ServerBind)(nil)
- type ServerBind struct {
- inbound chan serverPacket
- done chan struct{}
- writeBack N.PacketWriter
- }
- func NewServerBind(writeBack N.PacketWriter) *ServerBind {
- return &ServerBind{
- inbound: make(chan serverPacket, 256),
- done: make(chan struct{}),
- writeBack: writeBack,
- }
- }
- func (s *ServerBind) Abort() error {
- select {
- case <-s.done:
- return io.ErrClosedPipe
- default:
- close(s.done)
- }
- return nil
- }
- type serverPacket struct {
- buffer *buf.Buffer
- source M.Socksaddr
- }
- func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
- fns = []conn.ReceiveFunc{s.receive}
- return
- }
- func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
- select {
- case packet := <-s.inbound:
- defer packet.buffer.Release()
- n = copy(b, packet.buffer.Bytes())
- ep = Endpoint(packet.source)
- return
- case <-s.done:
- err = io.ErrClosedPipe
- return
- }
- }
- func (s *ServerBind) WriteIsThreadUnsafe() {
- }
- func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
- select {
- case s.inbound <- serverPacket{
- buffer: buffer,
- source: destination,
- }:
- return nil
- case <-s.done:
- return io.ErrClosedPipe
- }
- }
- func (s *ServerBind) Close() error {
- return nil
- }
- func (s *ServerBind) SetMark(mark uint32) error {
- return nil
- }
- func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
- return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint)))
- }
- func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
- destination := M.ParseSocksaddr(addr)
- if !destination.IsValid() || destination.Port == 0 {
- return nil, E.New("invalid endpoint: ", addr)
- }
- return Endpoint(destination), nil
- }
|