websocket.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package relay
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/gorilla/websocket"
  7. "net/http"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. )
  13. func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
  14. relayInfo := relaycommon.GenRelayInfoWs(c, ws)
  15. // get & validate textRequest 获取并验证文本请求
  16. //realtimeEvent, err := getAndValidateWssRequest(c, ws)
  17. //if err != nil {
  18. // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
  19. // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
  20. //}
  21. // map model name
  22. modelMapping := c.GetString("model_mapping")
  23. //isModelMapped := false
  24. if modelMapping != "" && modelMapping != "{}" {
  25. modelMap := make(map[string]string)
  26. err := json.Unmarshal([]byte(modelMapping), &modelMap)
  27. if err != nil {
  28. return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
  29. }
  30. if modelMap[relayInfo.OriginModelName] != "" {
  31. relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
  32. // set upstream model name
  33. //isModelMapped = true
  34. }
  35. }
  36. priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
  37. if err != nil {
  38. return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
  39. }
  40. // pre-consume quota 预消耗配额
  41. preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
  42. if openaiErr != nil {
  43. return openaiErr
  44. }
  45. defer func() {
  46. if openaiErr != nil {
  47. returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
  48. }
  49. }()
  50. adaptor := GetAdaptor(relayInfo.ApiType)
  51. if adaptor == nil {
  52. return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
  53. }
  54. adaptor.Init(relayInfo)
  55. //var requestBody io.Reader
  56. //firstWssRequest, _ := c.Get("first_wss_request")
  57. //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
  58. statusCodeMappingStr := c.GetString("status_code_mapping")
  59. resp, err := adaptor.DoRequest(c, relayInfo, nil)
  60. if err != nil {
  61. return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  62. }
  63. if resp != nil {
  64. relayInfo.TargetWs = resp.(*websocket.Conn)
  65. defer relayInfo.TargetWs.Close()
  66. }
  67. usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
  68. if openaiErr != nil {
  69. // reset status code 重置状态码
  70. service.ResetStatusCode(openaiErr, statusCodeMappingStr)
  71. return openaiErr
  72. }
  73. service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
  74. userQuota, priceData, "")
  75. return nil
  76. }