Browse Source

Cleanup gun conn code

Hellojack 2 years ago
parent
commit
01b4769852
2 changed files with 28 additions and 42 deletions
  1. 16 27
      transport/v2raygrpclite/conn.go
  2. 12 15
      transport/v2rayhttp/conn.go

+ 16 - 27
transport/v2raygrpclite/conn.go

@@ -106,31 +106,22 @@ func (c *GunConn) Write(b []byte) (n int, err error) {
 	_, err = bufio.Copy(c.writer, io.MultiReader(bytes.NewReader(grpcHeader), bytes.NewReader(protobufHeader[:varuintLen+1]), bytes.NewReader(b)))
 	c.writeAccess.Unlock()
 	buf.Put(grpcHeader)
-	if c.flusher != nil {
+	if err == nil && c.flusher != nil {
 		c.flusher.Flush()
 	}
 	return len(b), baderror.WrapH2(err)
 }
 
-func uLen(x uint64) int {
-	i := 0
-	for x >= 0x80 {
-		x >>= 7
-		i++
-	}
-	return i + 1
-}
-
 func (c *GunConn) WriteBuffer(buffer *buf.Buffer) error {
 	defer buffer.Release()
 	dataLen := buffer.Len()
-	varLen := uLen(uint64(dataLen))
+	varLen := rw.UVariantLen(uint64(dataLen))
 	header := buffer.ExtendHeader(6 + varLen)
 	binary.BigEndian.PutUint32(header[1:5], uint32(1+varLen+dataLen))
 	header[5] = 0x0A
 	binary.PutUvarint(header[6:], uint64(dataLen))
 	err := rw.WriteBytes(c.writer, buffer.Bytes())
-	if c.flusher != nil {
+	if err == nil && c.flusher != nil {
 		c.flusher.Flush()
 	}
 	return baderror.WrapH2(err)
@@ -153,31 +144,29 @@ func (c *GunConn) RemoteAddr() net.Addr {
 }
 
 func (c *GunConn) SetDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetWriteDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetWriteDeadline(t)
 	}
-	return responseWriter.SetWriteDeadline(t)
+	return os.ErrInvalid
+
 }
 
 func (c *GunConn) SetReadDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetReadDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetReadDeadline(t)
 	}
-	return responseWriter.SetReadDeadline(t)
+	return os.ErrInvalid
 }
 
 func (c *GunConn) SetWriteDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetWriteDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetWriteDeadline(t)
 	}
-	return responseWriter.SetWriteDeadline(t)
+	return os.ErrInvalid
 }

+ 12 - 15
transport/v2rayhttp/conn.go

@@ -67,33 +67,30 @@ func (c *HTTPConn) RemoteAddr() net.Addr {
 }
 
 func (c *HTTPConn) SetDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetWriteDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetWriteDeadline(t)
 	}
-	return responseWriter.SetWriteDeadline(t)
+	return os.ErrInvalid
 }
 
 func (c *HTTPConn) SetReadDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetReadDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetReadDeadline(t)
 	}
-	return responseWriter.SetReadDeadline(t)
+	return os.ErrInvalid
 }
 
 func (c *HTTPConn) SetWriteDeadline(t time.Time) error {
-	responseWriter, loaded := c.writer.(interface {
+	if responseWriter, loaded := c.writer.(interface {
 		SetWriteDeadline(time.Time) error
-	})
-	if !loaded {
-		return os.ErrInvalid
+	}); loaded {
+		return responseWriter.SetWriteDeadline(t)
 	}
-	return responseWriter.SetWriteDeadline(t)
+	return os.ErrInvalid
 }
 
 type ServerHTTPConn struct {