websocket.go 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. package relay
  2. import (
  3. "fmt"
  4. "github.com/QuantumNous/new-api/dto"
  5. relaycommon "github.com/QuantumNous/new-api/relay/common"
  6. "github.com/QuantumNous/new-api/service"
  7. "github.com/QuantumNous/new-api/types"
  8. "github.com/gin-gonic/gin"
  9. "github.com/gorilla/websocket"
  10. )
  11. func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
  12. info.InitChannelMeta(c)
  13. adaptor := GetAdaptor(info.ApiType)
  14. if adaptor == nil {
  15. return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
  16. }
  17. adaptor.Init(info)
  18. //var requestBody io.Reader
  19. //firstWssRequest, _ := c.Get("first_wss_request")
  20. //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
  21. statusCodeMappingStr := c.GetString("status_code_mapping")
  22. resp, err := adaptor.DoRequest(c, info, nil)
  23. if err != nil {
  24. return types.NewError(err, types.ErrorCodeDoRequestFailed)
  25. }
  26. if resp != nil {
  27. info.TargetWs = resp.(*websocket.Conn)
  28. defer info.TargetWs.Close()
  29. }
  30. usage, newAPIError := adaptor.DoResponse(c, nil, info)
  31. if newAPIError != nil {
  32. // reset status code 重置状态码
  33. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  34. return newAPIError
  35. }
  36. service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
  37. return nil
  38. }