dohelper.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "maps"
  9. "net/http"
  10. "strings"
  11. "sync"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/conv"
  16. "github.com/labring/aiproxy/core/model"
  17. "github.com/labring/aiproxy/core/relay/adaptor"
  18. "github.com/labring/aiproxy/core/relay/meta"
  19. "github.com/labring/aiproxy/core/relay/mode"
  20. relaymodel "github.com/labring/aiproxy/core/relay/model"
  21. log "github.com/sirupsen/logrus"
  22. )
  // maxBufferSize is both the maximum number of response-body bytes copied
  // into the capture buffer by responseWriter.Write and the initial capacity
  // of every buffer created by bufferPool: 512 KiB (0.5 MiB).
  const maxBufferSize = 512 << 10
  27. type responseWriter struct {
  28. gin.ResponseWriter
  29. body *bytes.Buffer
  30. firstByteAt time.Time
  31. }
  32. func (rw *responseWriter) Write(b []byte) (int, error) {
  33. if rw.firstByteAt.IsZero() {
  34. rw.firstByteAt = time.Now()
  35. }
  36. if rw.body.Len()+len(b) <= maxBufferSize {
  37. rw.body.Write(b)
  38. } else {
  39. rw.body.Write(b[:maxBufferSize-rw.body.Len()])
  40. }
  41. return rw.ResponseWriter.Write(b)
  42. }
  43. func (rw *responseWriter) WriteString(s string) (int, error) {
  44. return rw.Write(conv.StringToBytes(s))
  45. }
  46. var bufferPool = sync.Pool{
  47. New: func() any {
  48. return bytes.NewBuffer(make([]byte, 0, maxBufferSize))
  49. },
  50. }
  51. func getBuffer() *bytes.Buffer {
  52. v, ok := bufferPool.Get().(*bytes.Buffer)
  53. if !ok {
  54. panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
  55. }
  56. return v
  57. }
  58. func putBuffer(buf *bytes.Buffer) {
  59. buf.Reset()
  60. if buf.Cap() > maxBufferSize {
  61. return
  62. }
  63. bufferPool.Put(buf)
  64. }
  // RequestDetail records what was observed while relaying one request.
  //
  // RequestBody holds the raw request body when it was captured; it is left
  // empty for requests whose body is not recorded. ResponseBody holds either
  // the captured response bytes (truncated to maxBufferSize) or the
  // JSON-encoded relay error. FirstByteAt is when the first response bytes
  // were written to the client, and is the zero time if nothing was written.
  type RequestDetail struct {
  	RequestBody, ResponseBody string
  	FirstByteAt               time.Time
  }
  70. func DoHelper(
  71. a adaptor.Adaptor,
  72. c *gin.Context,
  73. meta *meta.Meta,
  74. store adaptor.Store,
  75. ) (
  76. model.Usage,
  77. *RequestDetail,
  78. adaptor.Error,
  79. ) {
  80. detail := RequestDetail{}
  81. if err := storeRequestBody(meta, c, &detail); err != nil {
  82. return model.Usage{}, nil, err
  83. }
  84. // donot use c.Request.Context() because it will be canceled by the client
  85. ctx := context.Background()
  86. resp, err := prepareAndDoRequest(ctx, a, c, meta, store)
  87. if err != nil {
  88. return model.Usage{}, &detail, err
  89. }
  90. if resp == nil {
  91. relayErr := relaymodel.WrapperErrorWithMessage(
  92. meta.Mode,
  93. http.StatusInternalServerError,
  94. "response is nil",
  95. )
  96. respBody, _ := relayErr.MarshalJSON()
  97. detail.ResponseBody = conv.BytesToString(respBody)
  98. return model.Usage{}, &detail, relayErr
  99. }
  100. if resp.Body != nil {
  101. defer resp.Body.Close()
  102. }
  103. usage, relayErr := handleResponse(a, c, meta, store, resp, &detail)
  104. if relayErr != nil {
  105. return model.Usage{}, &detail, relayErr
  106. }
  107. log := common.GetLogger(c)
  108. updateUsageMetrics(usage, log)
  109. if !detail.FirstByteAt.IsZero() {
  110. ttfb := detail.FirstByteAt.Sub(meta.RequestAt)
  111. log.Data["ttfb"] = common.TruncateDuration(ttfb).String()
  112. }
  113. return usage, &detail, nil
  114. }
  115. func storeRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
  116. switch {
  117. case meta.Mode == mode.AudioTranscription,
  118. meta.Mode == mode.AudioTranslation,
  119. meta.Mode == mode.ImagesEdits:
  120. return nil
  121. case !common.IsJSONContentType(c.GetHeader("Content-Type")):
  122. return nil
  123. default:
  124. reqBody, err := common.GetRequestBodyReusable(c.Request)
  125. if err != nil {
  126. return relaymodel.WrapperErrorWithMessage(
  127. meta.Mode,
  128. http.StatusBadRequest,
  129. "get request body failed: "+err.Error(),
  130. )
  131. }
  132. detail.RequestBody = conv.BytesToString(reqBody)
  133. return nil
  134. }
  135. }
  136. func prepareAndDoRequest(
  137. ctx context.Context,
  138. a adaptor.Adaptor,
  139. c *gin.Context,
  140. meta *meta.Meta,
  141. store adaptor.Store,
  142. ) (*http.Response, adaptor.Error) {
  143. log := common.GetLogger(c)
  144. convertResult, err := a.ConvertRequest(meta, store, c.Request)
  145. if err != nil {
  146. return nil, relaymodel.WrapperErrorWithMessage(
  147. meta.Mode,
  148. http.StatusBadRequest,
  149. "convert request failed: "+err.Error(),
  150. )
  151. }
  152. if closer, ok := convertResult.Body.(io.Closer); ok {
  153. defer closer.Close()
  154. }
  155. if meta.Channel.BaseURL == "" {
  156. meta.Channel.BaseURL = a.DefaultBaseURL()
  157. }
  158. fullRequestURL, err := a.GetRequestURL(meta, store, c)
  159. if err != nil {
  160. return nil, relaymodel.WrapperErrorWithMessage(
  161. meta.Mode,
  162. http.StatusBadRequest,
  163. "get request url failed: "+err.Error(),
  164. )
  165. }
  166. log.Debugf("request url: %s %s", fullRequestURL.Method, fullRequestURL.URL)
  167. req, err := http.NewRequestWithContext(
  168. ctx,
  169. fullRequestURL.Method,
  170. fullRequestURL.URL,
  171. convertResult.Body,
  172. )
  173. if err != nil {
  174. return nil, relaymodel.WrapperErrorWithMessage(
  175. meta.Mode,
  176. http.StatusBadRequest,
  177. "new request failed: "+err.Error(),
  178. )
  179. }
  180. if err := setupRequestHeader(a, c, meta, store, req, convertResult.Header); err != nil {
  181. return nil, err
  182. }
  183. return doRequest(a, c, meta, store, req)
  184. }
  185. func setupRequestHeader(
  186. a adaptor.Adaptor,
  187. c *gin.Context,
  188. meta *meta.Meta,
  189. store adaptor.Store,
  190. req *http.Request,
  191. header http.Header,
  192. ) adaptor.Error {
  193. maps.Copy(req.Header, header)
  194. if err := a.SetupRequestHeader(meta, store, c, req); err != nil {
  195. return relaymodel.WrapperErrorWithMessage(
  196. meta.Mode,
  197. http.StatusInternalServerError,
  198. "setup request header failed: "+err.Error(),
  199. )
  200. }
  201. return nil
  202. }
  203. func doRequest(
  204. a adaptor.Adaptor,
  205. c *gin.Context,
  206. meta *meta.Meta,
  207. store adaptor.Store,
  208. req *http.Request,
  209. ) (*http.Response, adaptor.Error) {
  210. resp, err := a.DoRequest(meta, store, c, req)
  211. if err != nil {
  212. var adaptorErr adaptor.Error
  213. ok := errors.As(err, &adaptorErr)
  214. if ok {
  215. return nil, adaptorErr
  216. }
  217. if errors.Is(err, context.Canceled) {
  218. return nil, relaymodel.WrapperErrorWithMessage(
  219. meta.Mode,
  220. http.StatusBadRequest,
  221. "request canceled by client: "+err.Error(),
  222. )
  223. }
  224. if errors.Is(err, context.DeadlineExceeded) {
  225. return nil, relaymodel.WrapperErrorWithMessage(
  226. meta.Mode,
  227. http.StatusRequestTimeout,
  228. "request timeout: "+err.Error(),
  229. )
  230. }
  231. if errors.Is(err, io.EOF) {
  232. return nil, relaymodel.WrapperErrorWithMessage(
  233. meta.Mode,
  234. http.StatusServiceUnavailable,
  235. "request eof: "+err.Error(),
  236. )
  237. }
  238. if errors.Is(err, io.ErrUnexpectedEOF) {
  239. return nil, relaymodel.WrapperErrorWithMessage(
  240. meta.Mode,
  241. http.StatusInternalServerError,
  242. "request unexpected eof: "+err.Error(),
  243. )
  244. }
  245. if strings.Contains(err.Error(), "timeout awaiting response headers") {
  246. return nil, relaymodel.WrapperErrorWithMessage(
  247. meta.Mode,
  248. http.StatusRequestTimeout,
  249. "request timeout: "+err.Error(),
  250. )
  251. }
  252. return nil, relaymodel.WrapperErrorWithMessage(
  253. meta.Mode,
  254. http.StatusInternalServerError,
  255. "request error: "+err.Error(),
  256. )
  257. }
  258. return resp, nil
  259. }
  260. func handleResponse(
  261. a adaptor.Adaptor,
  262. c *gin.Context,
  263. meta *meta.Meta,
  264. store adaptor.Store,
  265. resp *http.Response,
  266. detail *RequestDetail,
  267. ) (model.Usage, adaptor.Error) {
  268. buf := getBuffer()
  269. defer putBuffer(buf)
  270. rw := &responseWriter{
  271. ResponseWriter: c.Writer,
  272. body: buf,
  273. }
  274. rawWriter := c.Writer
  275. defer func() {
  276. c.Writer = rawWriter
  277. detail.FirstByteAt = rw.firstByteAt
  278. }()
  279. c.Writer = rw
  280. usage, relayErr := a.DoResponse(meta, store, c, resp)
  281. if relayErr != nil {
  282. respBody, _ := relayErr.MarshalJSON()
  283. detail.ResponseBody = conv.BytesToString(respBody)
  284. } else {
  285. // copy body buffer
  286. // do not use bytes conv
  287. detail.ResponseBody = rw.body.String()
  288. }
  289. return usage, relayErr
  290. }
  291. func updateUsageMetrics(usage model.Usage, log *log.Entry) {
  292. if usage.TotalTokens == 0 {
  293. usage.TotalTokens = usage.InputTokens + usage.OutputTokens
  294. }
  295. if usage.InputTokens > 0 {
  296. log.Data["t_input"] = usage.InputTokens
  297. }
  298. if usage.ImageInputTokens > 0 {
  299. log.Data["t_image_input"] = usage.ImageInputTokens
  300. }
  301. if usage.AudioInputTokens > 0 {
  302. log.Data["t_audio_input"] = usage.AudioInputTokens
  303. }
  304. if usage.OutputTokens > 0 {
  305. log.Data["t_output"] = usage.OutputTokens
  306. }
  307. if usage.TotalTokens > 0 {
  308. log.Data["t_total"] = usage.TotalTokens
  309. }
  310. if usage.CachedTokens > 0 {
  311. log.Data["t_cached"] = usage.CachedTokens
  312. }
  313. if usage.CacheCreationTokens > 0 {
  314. log.Data["t_cache_creation"] = usage.CacheCreationTokens
  315. }
  316. if usage.ReasoningTokens > 0 {
  317. log.Data["t_reason"] = usage.ReasoningTokens
  318. }
  319. if usage.WebSearchCount > 0 {
  320. log.Data["t_websearch"] = usage.WebSearchCount
  321. }
  322. }