dohelper.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/labring/aiproxy/core/common"
  14. "github.com/labring/aiproxy/core/common/conv"
  15. "github.com/labring/aiproxy/core/model"
  16. "github.com/labring/aiproxy/core/relay/adaptor"
  17. "github.com/labring/aiproxy/core/relay/meta"
  18. "github.com/labring/aiproxy/core/relay/mode"
  19. relaymodel "github.com/labring/aiproxy/core/relay/model"
  20. log "github.com/sirupsen/logrus"
  21. )
  22. const (
  23. // 0.5MB
  24. maxBufferSize = 512 * 1024
  25. )
  26. type responseWriter struct {
  27. gin.ResponseWriter
  28. body *bytes.Buffer
  29. firstByteAt time.Time
  30. }
  31. func (rw *responseWriter) Write(b []byte) (int, error) {
  32. if rw.firstByteAt.IsZero() {
  33. rw.firstByteAt = time.Now()
  34. }
  35. if rw.body.Len()+len(b) <= maxBufferSize {
  36. rw.body.Write(b)
  37. } else {
  38. rw.body.Write(b[:maxBufferSize-rw.body.Len()])
  39. }
  40. return rw.ResponseWriter.Write(b)
  41. }
  42. func (rw *responseWriter) WriteString(s string) (int, error) {
  43. return rw.Write(conv.StringToBytes(s))
  44. }
  45. var bufferPool = sync.Pool{
  46. New: func() any {
  47. return bytes.NewBuffer(make([]byte, 0, maxBufferSize))
  48. },
  49. }
  50. func getBuffer() *bytes.Buffer {
  51. v, ok := bufferPool.Get().(*bytes.Buffer)
  52. if !ok {
  53. panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
  54. }
  55. return v
  56. }
  57. func putBuffer(buf *bytes.Buffer) {
  58. buf.Reset()
  59. if buf.Cap() > maxBufferSize {
  60. return
  61. }
  62. bufferPool.Put(buf)
  63. }
  64. type RequestDetail struct {
  65. RequestBody string
  66. ResponseBody string
  67. FirstByteAt time.Time
  68. }
  69. func DoHelper(
  70. a adaptor.Adaptor,
  71. c *gin.Context,
  72. meta *meta.Meta,
  73. store adaptor.Store,
  74. ) (
  75. model.Usage,
  76. *RequestDetail,
  77. adaptor.Error,
  78. ) {
  79. detail := RequestDetail{}
  80. if err := storeRequestBody(meta, c, &detail); err != nil {
  81. return model.Usage{}, nil, err
  82. }
  83. // donot use c.Request.Context() because it will be canceled by the client
  84. ctx := context.Background()
  85. resp, err := prepareAndDoRequest(ctx, a, c, meta, store)
  86. if err != nil {
  87. return model.Usage{}, &detail, err
  88. }
  89. if resp == nil {
  90. relayErr := relaymodel.WrapperErrorWithMessage(
  91. meta.Mode,
  92. http.StatusInternalServerError,
  93. "response is nil",
  94. )
  95. respBody, _ := relayErr.MarshalJSON()
  96. detail.ResponseBody = conv.BytesToString(respBody)
  97. return model.Usage{}, &detail, relayErr
  98. }
  99. if resp.Body != nil {
  100. defer resp.Body.Close()
  101. }
  102. usage, relayErr := handleResponse(a, c, meta, store, resp, &detail)
  103. if relayErr != nil {
  104. return model.Usage{}, &detail, relayErr
  105. }
  106. log := common.GetLogger(c)
  107. updateUsageMetrics(usage, log)
  108. if !detail.FirstByteAt.IsZero() {
  109. ttfb := detail.FirstByteAt.Sub(meta.RequestAt)
  110. log.Data["ttfb"] = common.TruncateDuration(ttfb).String()
  111. }
  112. return usage, &detail, nil
  113. }
  114. func storeRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
  115. switch {
  116. case meta.Mode == mode.AudioTranscription,
  117. meta.Mode == mode.AudioTranslation,
  118. meta.Mode == mode.ImagesEdits:
  119. return nil
  120. case !common.IsJSONContentType(c.GetHeader("Content-Type")):
  121. return nil
  122. default:
  123. reqBody, err := common.GetRequestBodyReusable(c.Request)
  124. if err != nil {
  125. return relaymodel.WrapperErrorWithMessage(
  126. meta.Mode,
  127. http.StatusBadRequest,
  128. "get request body failed: "+err.Error(),
  129. )
  130. }
  131. detail.RequestBody = conv.BytesToString(reqBody)
  132. return nil
  133. }
  134. }
  135. func prepareAndDoRequest(
  136. ctx context.Context,
  137. a adaptor.Adaptor,
  138. c *gin.Context,
  139. meta *meta.Meta,
  140. store adaptor.Store,
  141. ) (*http.Response, adaptor.Error) {
  142. log := common.GetLogger(c)
  143. convertResult, err := a.ConvertRequest(meta, store, c.Request)
  144. if err != nil {
  145. return nil, relaymodel.WrapperErrorWithMessage(
  146. meta.Mode,
  147. http.StatusBadRequest,
  148. "convert request failed: "+err.Error(),
  149. )
  150. }
  151. if closer, ok := convertResult.Body.(io.Closer); ok {
  152. defer closer.Close()
  153. }
  154. if meta.Channel.BaseURL == "" {
  155. meta.Channel.BaseURL = a.DefaultBaseURL()
  156. }
  157. fullRequestURL, err := a.GetRequestURL(meta, store)
  158. if err != nil {
  159. return nil, relaymodel.WrapperErrorWithMessage(
  160. meta.Mode,
  161. http.StatusBadRequest,
  162. "get request url failed: "+err.Error(),
  163. )
  164. }
  165. log.Debugf("request url: %s %s", fullRequestURL.Method, fullRequestURL.URL)
  166. req, err := http.NewRequestWithContext(
  167. ctx,
  168. fullRequestURL.Method,
  169. fullRequestURL.URL,
  170. convertResult.Body,
  171. )
  172. if err != nil {
  173. return nil, relaymodel.WrapperErrorWithMessage(
  174. meta.Mode,
  175. http.StatusBadRequest,
  176. "new request failed: "+err.Error(),
  177. )
  178. }
  179. if err := setupRequestHeader(a, c, meta, store, req, convertResult.Header); err != nil {
  180. return nil, err
  181. }
  182. return doRequest(a, c, meta, store, req)
  183. }
  184. func setupRequestHeader(
  185. a adaptor.Adaptor,
  186. c *gin.Context,
  187. meta *meta.Meta,
  188. store adaptor.Store,
  189. req *http.Request,
  190. header http.Header,
  191. ) adaptor.Error {
  192. for key, value := range header {
  193. req.Header[key] = value
  194. }
  195. if err := a.SetupRequestHeader(meta, store, c, req); err != nil {
  196. return relaymodel.WrapperErrorWithMessage(
  197. meta.Mode,
  198. http.StatusInternalServerError,
  199. "setup request header failed: "+err.Error(),
  200. )
  201. }
  202. return nil
  203. }
  204. func doRequest(
  205. a adaptor.Adaptor,
  206. c *gin.Context,
  207. meta *meta.Meta,
  208. store adaptor.Store,
  209. req *http.Request,
  210. ) (*http.Response, adaptor.Error) {
  211. resp, err := a.DoRequest(meta, store, c, req)
  212. if err != nil {
  213. var adaptorErr adaptor.Error
  214. ok := errors.As(err, &adaptorErr)
  215. if ok {
  216. return nil, adaptorErr
  217. }
  218. if errors.Is(err, context.Canceled) {
  219. return nil, relaymodel.WrapperErrorWithMessage(
  220. meta.Mode,
  221. http.StatusBadRequest,
  222. "request canceled by client: "+err.Error(),
  223. )
  224. }
  225. if errors.Is(err, context.DeadlineExceeded) {
  226. return nil, relaymodel.WrapperErrorWithMessage(
  227. meta.Mode,
  228. http.StatusRequestTimeout,
  229. "request timeout: "+err.Error(),
  230. )
  231. }
  232. if errors.Is(err, io.EOF) {
  233. return nil, relaymodel.WrapperErrorWithMessage(
  234. meta.Mode,
  235. http.StatusServiceUnavailable,
  236. "request eof: "+err.Error(),
  237. )
  238. }
  239. if errors.Is(err, io.ErrUnexpectedEOF) {
  240. return nil, relaymodel.WrapperErrorWithMessage(
  241. meta.Mode,
  242. http.StatusInternalServerError,
  243. "request unexpected eof: "+err.Error(),
  244. )
  245. }
  246. if strings.Contains(err.Error(), "timeout awaiting response headers") {
  247. return nil, relaymodel.WrapperErrorWithMessage(
  248. meta.Mode,
  249. http.StatusRequestTimeout,
  250. "request timeout: "+err.Error(),
  251. )
  252. }
  253. return nil, relaymodel.WrapperErrorWithMessage(
  254. meta.Mode,
  255. http.StatusInternalServerError,
  256. "request error: "+err.Error(),
  257. )
  258. }
  259. return resp, nil
  260. }
  261. func handleResponse(
  262. a adaptor.Adaptor,
  263. c *gin.Context,
  264. meta *meta.Meta,
  265. store adaptor.Store,
  266. resp *http.Response,
  267. detail *RequestDetail,
  268. ) (model.Usage, adaptor.Error) {
  269. buf := getBuffer()
  270. defer putBuffer(buf)
  271. rw := &responseWriter{
  272. ResponseWriter: c.Writer,
  273. body: buf,
  274. }
  275. rawWriter := c.Writer
  276. defer func() {
  277. c.Writer = rawWriter
  278. detail.FirstByteAt = rw.firstByteAt
  279. }()
  280. c.Writer = rw
  281. usage, relayErr := a.DoResponse(meta, store, c, resp)
  282. if relayErr != nil {
  283. respBody, _ := relayErr.MarshalJSON()
  284. detail.ResponseBody = conv.BytesToString(respBody)
  285. } else {
  286. // copy body buffer
  287. // do not use bytes conv
  288. detail.ResponseBody = rw.body.String()
  289. }
  290. return usage, relayErr
  291. }
  292. func updateUsageMetrics(usage model.Usage, log *log.Entry) {
  293. if usage.TotalTokens == 0 {
  294. usage.TotalTokens = usage.InputTokens + usage.OutputTokens
  295. }
  296. if usage.InputTokens > 0 {
  297. log.Data["t_input"] = usage.InputTokens
  298. }
  299. if usage.ImageInputTokens > 0 {
  300. log.Data["t_image_input"] = usage.ImageInputTokens
  301. }
  302. if usage.AudioInputTokens > 0 {
  303. log.Data["t_audio_input"] = usage.AudioInputTokens
  304. }
  305. if usage.OutputTokens > 0 {
  306. log.Data["t_output"] = usage.OutputTokens
  307. }
  308. if usage.TotalTokens > 0 {
  309. log.Data["t_total"] = usage.TotalTokens
  310. }
  311. if usage.CachedTokens > 0 {
  312. log.Data["t_cached"] = usage.CachedTokens
  313. }
  314. if usage.CacheCreationTokens > 0 {
  315. log.Data["t_cache_creation"] = usage.CacheCreationTokens
  316. }
  317. if usage.ReasoningTokens > 0 {
  318. log.Data["t_reason"] = usage.ReasoningTokens
  319. }
  320. if usage.WebSearchCount > 0 {
  321. log.Data["t_websearch"] = usage.WebSearchCount
  322. }
  323. }