| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- package controller
- import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strings"
- "sync"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/conv"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/relay/adaptor"
- "github.com/labring/aiproxy/core/relay/meta"
- "github.com/labring/aiproxy/core/relay/mode"
- relaymodel "github.com/labring/aiproxy/core/relay/model"
- log "github.com/sirupsen/logrus"
- )
// maxBufferSize is the upper bound, in bytes, on how much of a response is
// captured per request and the capacity of each pooled capture buffer
// (512 KiB, i.e. 0.5 MB).
const maxBufferSize = 1 << 19
- type responseWriter struct {
- gin.ResponseWriter
- body *bytes.Buffer
- firstByteAt time.Time
- }
- func (rw *responseWriter) Write(b []byte) (int, error) {
- if rw.firstByteAt.IsZero() {
- rw.firstByteAt = time.Now()
- }
- if rw.body.Len()+len(b) <= maxBufferSize {
- rw.body.Write(b)
- } else {
- rw.body.Write(b[:maxBufferSize-rw.body.Len()])
- }
- return rw.ResponseWriter.Write(b)
- }
- func (rw *responseWriter) WriteString(s string) (int, error) {
- return rw.Write(conv.StringToBytes(s))
- }
- var bufferPool = sync.Pool{
- New: func() any {
- return bytes.NewBuffer(make([]byte, 0, maxBufferSize))
- },
- }
- func getBuffer() *bytes.Buffer {
- v, ok := bufferPool.Get().(*bytes.Buffer)
- if !ok {
- panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
- }
- return v
- }
- func putBuffer(buf *bytes.Buffer) {
- buf.Reset()
- if buf.Cap() > maxBufferSize {
- return
- }
- bufferPool.Put(buf)
- }
// RequestDetail carries the request and response bodies recorded while
// relaying a request, plus the time the first response byte was written.
type RequestDetail struct {
	// RequestBody is the stored request body; ResponseBody is the captured
	// response body or, on failure, the JSON-encoded relay error.
	RequestBody, ResponseBody string
	// FirstByteAt is the time of the first write to the client; zero when
	// nothing was written.
	FirstByteAt time.Time
}
- func DoHelper(
- a adaptor.Adaptor,
- c *gin.Context,
- meta *meta.Meta,
- store adaptor.Store,
- ) (
- model.Usage,
- *RequestDetail,
- adaptor.Error,
- ) {
- detail := RequestDetail{}
- if err := storeRequestBody(meta, c, &detail); err != nil {
- return model.Usage{}, nil, err
- }
- // donot use c.Request.Context() because it will be canceled by the client
- ctx := context.Background()
- resp, err := prepareAndDoRequest(ctx, a, c, meta, store)
- if err != nil {
- return model.Usage{}, &detail, err
- }
- if resp == nil {
- relayErr := relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusInternalServerError,
- "response is nil",
- )
- respBody, _ := relayErr.MarshalJSON()
- detail.ResponseBody = conv.BytesToString(respBody)
- return model.Usage{}, &detail, relayErr
- }
- if resp.Body != nil {
- defer resp.Body.Close()
- }
- usage, relayErr := handleResponse(a, c, meta, store, resp, &detail)
- if relayErr != nil {
- return model.Usage{}, &detail, relayErr
- }
- log := common.GetLogger(c)
- updateUsageMetrics(usage, log)
- if !detail.FirstByteAt.IsZero() {
- ttfb := detail.FirstByteAt.Sub(meta.RequestAt)
- log.Data["ttfb"] = common.TruncateDuration(ttfb).String()
- }
- return usage, &detail, nil
- }
- func storeRequestBody(meta *meta.Meta, c *gin.Context, detail *RequestDetail) adaptor.Error {
- switch {
- case meta.Mode == mode.AudioTranscription,
- meta.Mode == mode.AudioTranslation,
- meta.Mode == mode.ImagesEdits:
- return nil
- case !common.IsJSONContentType(c.GetHeader("Content-Type")):
- return nil
- default:
- reqBody, err := common.GetRequestBodyReusable(c.Request)
- if err != nil {
- return relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusBadRequest,
- "get request body failed: "+err.Error(),
- )
- }
- detail.RequestBody = conv.BytesToString(reqBody)
- return nil
- }
- }
- func prepareAndDoRequest(
- ctx context.Context,
- a adaptor.Adaptor,
- c *gin.Context,
- meta *meta.Meta,
- store adaptor.Store,
- ) (*http.Response, adaptor.Error) {
- log := common.GetLogger(c)
- convertResult, err := a.ConvertRequest(meta, store, c.Request)
- if err != nil {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusBadRequest,
- "convert request failed: "+err.Error(),
- )
- }
- if closer, ok := convertResult.Body.(io.Closer); ok {
- defer closer.Close()
- }
- if meta.Channel.BaseURL == "" {
- meta.Channel.BaseURL = a.DefaultBaseURL()
- }
- fullRequestURL, err := a.GetRequestURL(meta, store)
- if err != nil {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusBadRequest,
- "get request url failed: "+err.Error(),
- )
- }
- log.Debugf("request url: %s %s", fullRequestURL.Method, fullRequestURL.URL)
- req, err := http.NewRequestWithContext(
- ctx,
- fullRequestURL.Method,
- fullRequestURL.URL,
- convertResult.Body,
- )
- if err != nil {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusBadRequest,
- "new request failed: "+err.Error(),
- )
- }
- if err := setupRequestHeader(a, c, meta, store, req, convertResult.Header); err != nil {
- return nil, err
- }
- return doRequest(a, c, meta, store, req)
- }
- func setupRequestHeader(
- a adaptor.Adaptor,
- c *gin.Context,
- meta *meta.Meta,
- store adaptor.Store,
- req *http.Request,
- header http.Header,
- ) adaptor.Error {
- for key, value := range header {
- req.Header[key] = value
- }
- if err := a.SetupRequestHeader(meta, store, c, req); err != nil {
- return relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusInternalServerError,
- "setup request header failed: "+err.Error(),
- )
- }
- return nil
- }
- func doRequest(
- a adaptor.Adaptor,
- c *gin.Context,
- meta *meta.Meta,
- store adaptor.Store,
- req *http.Request,
- ) (*http.Response, adaptor.Error) {
- resp, err := a.DoRequest(meta, store, c, req)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusBadRequest,
- "request canceled by client: "+err.Error(),
- )
- }
- if errors.Is(err, context.DeadlineExceeded) {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusRequestTimeout,
- "request timeout: "+err.Error(),
- )
- }
- if errors.Is(err, io.EOF) {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusServiceUnavailable,
- "request eof: "+err.Error(),
- )
- }
- if errors.Is(err, io.ErrUnexpectedEOF) {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusInternalServerError,
- "request unexpected eof: "+err.Error(),
- )
- }
- if strings.Contains(err.Error(), "timeout awaiting response headers") {
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusRequestTimeout,
- "request timeout: "+err.Error(),
- )
- }
- return nil, relaymodel.WrapperErrorWithMessage(
- meta.Mode,
- http.StatusInternalServerError,
- "request error: "+err.Error(),
- )
- }
- return resp, nil
- }
- func handleResponse(
- a adaptor.Adaptor,
- c *gin.Context,
- meta *meta.Meta,
- store adaptor.Store,
- resp *http.Response,
- detail *RequestDetail,
- ) (model.Usage, adaptor.Error) {
- buf := getBuffer()
- defer putBuffer(buf)
- rw := &responseWriter{
- ResponseWriter: c.Writer,
- body: buf,
- }
- rawWriter := c.Writer
- defer func() {
- c.Writer = rawWriter
- detail.FirstByteAt = rw.firstByteAt
- }()
- c.Writer = rw
- usage, relayErr := a.DoResponse(meta, store, c, resp)
- if relayErr != nil {
- respBody, _ := relayErr.MarshalJSON()
- detail.ResponseBody = conv.BytesToString(respBody)
- } else {
- // copy body buffer
- // do not use bytes conv
- detail.ResponseBody = rw.body.String()
- }
- return usage, relayErr
- }
- func updateUsageMetrics(usage model.Usage, log *log.Entry) {
- if usage.TotalTokens == 0 {
- usage.TotalTokens = usage.InputTokens + usage.OutputTokens
- }
- if usage.InputTokens > 0 {
- log.Data["t_input"] = usage.InputTokens
- }
- if usage.ImageInputTokens > 0 {
- log.Data["t_image_input"] = usage.ImageInputTokens
- }
- if usage.AudioInputTokens > 0 {
- log.Data["t_audio_input"] = usage.AudioInputTokens
- }
- if usage.OutputTokens > 0 {
- log.Data["t_output"] = usage.OutputTokens
- }
- if usage.TotalTokens > 0 {
- log.Data["t_total"] = usage.TotalTokens
- }
- if usage.CachedTokens > 0 {
- log.Data["t_cached"] = usage.CachedTokens
- }
- if usage.CacheCreationTokens > 0 {
- log.Data["t_cache_creation"] = usage.CacheCreationTokens
- }
- if usage.ReasoningTokens > 0 {
- log.Data["t_reason"] = usage.ReasoningTokens
- }
- if usage.WebSearchCount > 0 {
- log.Data["t_websearch"] = usage.WebSearchCount
- }
- }
|