dohelper.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/labring/aiproxy/core/common"
  14. "github.com/labring/aiproxy/core/common/conv"
  15. "github.com/labring/aiproxy/core/model"
  16. "github.com/labring/aiproxy/core/relay/adaptor"
  17. "github.com/labring/aiproxy/core/relay/meta"
  18. "github.com/labring/aiproxy/core/relay/mode"
  19. relaymodel "github.com/labring/aiproxy/core/relay/model"
  20. log "github.com/sirupsen/logrus"
  21. )
  22. const (
  23. // 0.5MB
  24. maxBufferSize = 512 * 1024
  25. )
  26. type responseWriter struct {
  27. gin.ResponseWriter
  28. body *bytes.Buffer
  29. firstByteAt time.Time
  30. }
  31. func (rw *responseWriter) Write(b []byte) (int, error) {
  32. if rw.firstByteAt.IsZero() {
  33. rw.firstByteAt = time.Now()
  34. }
  35. if rw.body.Len()+len(b) <= maxBufferSize {
  36. rw.body.Write(b)
  37. } else {
  38. rw.body.Write(b[:maxBufferSize-rw.body.Len()])
  39. }
  40. return rw.ResponseWriter.Write(b)
  41. }
  42. func (rw *responseWriter) WriteString(s string) (int, error) {
  43. return rw.Write(conv.StringToBytes(s))
  44. }
  45. var bufferPool = sync.Pool{
  46. New: func() any {
  47. return bytes.NewBuffer(make([]byte, 0, maxBufferSize))
  48. },
  49. }
  50. func getBuffer() *bytes.Buffer {
  51. v, ok := bufferPool.Get().(*bytes.Buffer)
  52. if !ok {
  53. panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
  54. }
  55. return v
  56. }
  57. func putBuffer(buf *bytes.Buffer) {
  58. buf.Reset()
  59. if buf.Cap() > maxBufferSize {
  60. return
  61. }
  62. bufferPool.Put(buf)
  63. }
  64. type RequestDetail struct {
  65. RequestBody string
  66. ResponseBody string
  67. FirstByteAt time.Time
  68. }
  69. func DoHelper(
  70. a adaptor.Adaptor,
  71. c *gin.Context,
  72. meta *meta.Meta,
  73. store adaptor.Store,
  74. ) (
  75. model.Usage,
  76. *RequestDetail,
  77. adaptor.Error,
  78. ) {
  79. detail := RequestDetail{}
  80. if err := storeRequestBody(meta, c, &detail); err != nil {
  81. return model.Usage{}, nil, err
  82. }
  83. // donot use c.Request.Context() because it will be canceled by the client
  84. ctx := context.Background()
  85. resp, err := prepareAndDoRequest(ctx, a, c, meta, store)
  86. if err != nil {
  87. return model.Usage{}, &detail, err
  88. }
  89. if resp == nil {
  90. relayErr := relaymodel.WrapperErrorWithMessage(
  91. meta.Mode,
  92. http.StatusInternalServerError,
  93. "response is nil",
  94. )
  95. respBody, _ := relayErr.MarshalJSON()
  96. detail.ResponseBody = conv.BytesToString(respBody)
  97. return model.Usage{}, &detail, relayErr
  98. }
  99. if resp.Body != nil {
  100. defer resp.Body.Close()
  101. }
  102. usage, relayErr := handleResponse(a, c, meta, store, resp, &detail)
  103. if relayErr != nil {
  104. return model.Usage{}, &detail, relayErr
  105. }
  106. log := common.GetLogger(c)
  107. updateUsageMetrics(usage, log)
  108. if !detail.FirstByteAt.IsZero() {
  109. ttfb := detail.FirstByteAt.Sub(meta.RequestAt)
  110. log.Data["ttfb"] = common.TruncateDuration(ttfb).String()
  111. }
  112. return usage, &detail, nil
  113. }
  114. func storeRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
  115. switch {
  116. case meta.Mode == mode.AudioTranscription,
  117. meta.Mode == mode.AudioTranslation,
  118. meta.Mode == mode.ImagesEdits:
  119. return nil
  120. case !common.IsJSONContentType(c.GetHeader("Content-Type")):
  121. return nil
  122. default:
  123. reqBody, err := common.GetRequestBodyReusable(c.Request)
  124. if err != nil {
  125. return relaymodel.WrapperErrorWithMessage(
  126. meta.Mode,
  127. http.StatusBadRequest,
  128. "get request body failed: "+err.Error(),
  129. )
  130. }
  131. detail.RequestBody = conv.BytesToString(reqBody)
  132. return nil
  133. }
  134. }
  135. func prepareAndDoRequest(
  136. ctx context.Context,
  137. a adaptor.Adaptor,
  138. c *gin.Context,
  139. meta *meta.Meta,
  140. store adaptor.Store,
  141. ) (*http.Response, adaptor.Error) {
  142. log := common.GetLogger(c)
  143. convertResult, err := a.ConvertRequest(meta, store, c.Request)
  144. if err != nil {
  145. return nil, relaymodel.WrapperErrorWithMessage(
  146. meta.Mode,
  147. http.StatusBadRequest,
  148. "convert request failed: "+err.Error(),
  149. )
  150. }
  151. if closer, ok := convertResult.Body.(io.Closer); ok {
  152. defer closer.Close()
  153. }
  154. if meta.Channel.BaseURL == "" {
  155. meta.Channel.BaseURL = a.DefaultBaseURL()
  156. }
  157. fullRequestURL, err := a.GetRequestURL(meta, store)
  158. if err != nil {
  159. return nil, relaymodel.WrapperErrorWithMessage(
  160. meta.Mode,
  161. http.StatusBadRequest,
  162. "get request url failed: "+err.Error(),
  163. )
  164. }
  165. log.Debugf("request url: %s %s", fullRequestURL.Method, fullRequestURL.URL)
  166. req, err := http.NewRequestWithContext(
  167. ctx,
  168. fullRequestURL.Method,
  169. fullRequestURL.URL,
  170. convertResult.Body,
  171. )
  172. if err != nil {
  173. return nil, relaymodel.WrapperErrorWithMessage(
  174. meta.Mode,
  175. http.StatusBadRequest,
  176. "new request failed: "+err.Error(),
  177. )
  178. }
  179. if err := setupRequestHeader(a, c, meta, store, req, convertResult.Header); err != nil {
  180. return nil, err
  181. }
  182. return doRequest(a, c, meta, store, req)
  183. }
  184. func setupRequestHeader(
  185. a adaptor.Adaptor,
  186. c *gin.Context,
  187. meta *meta.Meta,
  188. store adaptor.Store,
  189. req *http.Request,
  190. header http.Header,
  191. ) adaptor.Error {
  192. for key, value := range header {
  193. req.Header[key] = value
  194. }
  195. if err := a.SetupRequestHeader(meta, store, c, req); err != nil {
  196. return relaymodel.WrapperErrorWithMessage(
  197. meta.Mode,
  198. http.StatusInternalServerError,
  199. "setup request header failed: "+err.Error(),
  200. )
  201. }
  202. return nil
  203. }
  204. func doRequest(
  205. a adaptor.Adaptor,
  206. c *gin.Context,
  207. meta *meta.Meta,
  208. store adaptor.Store,
  209. req *http.Request,
  210. ) (*http.Response, adaptor.Error) {
  211. resp, err := a.DoRequest(meta, store, c, req)
  212. if err != nil {
  213. if errors.Is(err, context.Canceled) {
  214. return nil, relaymodel.WrapperErrorWithMessage(
  215. meta.Mode,
  216. http.StatusBadRequest,
  217. "request canceled by client: "+err.Error(),
  218. )
  219. }
  220. if errors.Is(err, context.DeadlineExceeded) {
  221. return nil, relaymodel.WrapperErrorWithMessage(
  222. meta.Mode,
  223. http.StatusRequestTimeout,
  224. "request timeout: "+err.Error(),
  225. )
  226. }
  227. if errors.Is(err, io.EOF) {
  228. return nil, relaymodel.WrapperErrorWithMessage(
  229. meta.Mode,
  230. http.StatusServiceUnavailable,
  231. "request eof: "+err.Error(),
  232. )
  233. }
  234. if errors.Is(err, io.ErrUnexpectedEOF) {
  235. return nil, relaymodel.WrapperErrorWithMessage(
  236. meta.Mode,
  237. http.StatusInternalServerError,
  238. "request unexpected eof: "+err.Error(),
  239. )
  240. }
  241. if strings.Contains(err.Error(), "timeout awaiting response headers") {
  242. return nil, relaymodel.WrapperErrorWithMessage(
  243. meta.Mode,
  244. http.StatusRequestTimeout,
  245. "request timeout: "+err.Error(),
  246. )
  247. }
  248. return nil, relaymodel.WrapperErrorWithMessage(
  249. meta.Mode,
  250. http.StatusInternalServerError,
  251. "request error: "+err.Error(),
  252. )
  253. }
  254. return resp, nil
  255. }
  256. func handleResponse(
  257. a adaptor.Adaptor,
  258. c *gin.Context,
  259. meta *meta.Meta,
  260. store adaptor.Store,
  261. resp *http.Response,
  262. detail *RequestDetail,
  263. ) (model.Usage, adaptor.Error) {
  264. buf := getBuffer()
  265. defer putBuffer(buf)
  266. rw := &responseWriter{
  267. ResponseWriter: c.Writer,
  268. body: buf,
  269. }
  270. rawWriter := c.Writer
  271. defer func() {
  272. c.Writer = rawWriter
  273. detail.FirstByteAt = rw.firstByteAt
  274. }()
  275. c.Writer = rw
  276. usage, relayErr := a.DoResponse(meta, store, c, resp)
  277. if relayErr != nil {
  278. respBody, _ := relayErr.MarshalJSON()
  279. detail.ResponseBody = conv.BytesToString(respBody)
  280. } else {
  281. // copy body buffer
  282. // do not use bytes conv
  283. detail.ResponseBody = rw.body.String()
  284. }
  285. return usage, relayErr
  286. }
  287. func updateUsageMetrics(usage model.Usage, log *log.Entry) {
  288. if usage.TotalTokens == 0 {
  289. usage.TotalTokens = usage.InputTokens + usage.OutputTokens
  290. }
  291. if usage.InputTokens > 0 {
  292. log.Data["t_input"] = usage.InputTokens
  293. }
  294. if usage.ImageInputTokens > 0 {
  295. log.Data["t_image_input"] = usage.ImageInputTokens
  296. }
  297. if usage.AudioInputTokens > 0 {
  298. log.Data["t_audio_input"] = usage.AudioInputTokens
  299. }
  300. if usage.OutputTokens > 0 {
  301. log.Data["t_output"] = usage.OutputTokens
  302. }
  303. if usage.TotalTokens > 0 {
  304. log.Data["t_total"] = usage.TotalTokens
  305. }
  306. if usage.CachedTokens > 0 {
  307. log.Data["t_cached"] = usage.CachedTokens
  308. }
  309. if usage.CacheCreationTokens > 0 {
  310. log.Data["t_cache_creation"] = usage.CacheCreationTokens
  311. }
  312. if usage.ReasoningTokens > 0 {
  313. log.Data["t_reason"] = usage.ReasoningTokens
  314. }
  315. if usage.WebSearchCount > 0 {
  316. log.Data["t_websearch"] = usage.WebSearchCount
  317. }
  318. }