main.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package doubao
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/url"
  6. "strings"
  7. "github.com/gin-gonic/gin"
  8. "github.com/labring/aiproxy/core/model"
  9. "github.com/labring/aiproxy/core/relay/adaptor"
  10. "github.com/labring/aiproxy/core/relay/adaptor/openai"
  11. "github.com/labring/aiproxy/core/relay/meta"
  12. "github.com/labring/aiproxy/core/relay/mode"
  13. "github.com/labring/aiproxy/core/relay/utils"
  14. )
  15. func GetRequestURL(meta *meta.Meta) (adaptor.RequestURL, error) {
  16. u := meta.Channel.BaseURL
  17. switch meta.Mode {
  18. case mode.ChatCompletions, mode.Anthropic:
  19. if strings.HasPrefix(meta.ActualModel, "bot-") {
  20. url, err := url.JoinPath(u, "/api/v3/bots/chat/completions")
  21. if err != nil {
  22. return adaptor.RequestURL{}, err
  23. }
  24. return adaptor.RequestURL{
  25. Method: http.MethodPost,
  26. URL: url,
  27. }, nil
  28. }
  29. url, err := url.JoinPath(u, "/api/v3/chat/completions")
  30. if err != nil {
  31. return adaptor.RequestURL{}, err
  32. }
  33. return adaptor.RequestURL{
  34. Method: http.MethodPost,
  35. URL: url,
  36. }, nil
  37. case mode.Embeddings:
  38. if strings.Contains(meta.ActualModel, "vision") {
  39. url, err := url.JoinPath(u, "/api/v3/embeddings/multimodal")
  40. if err != nil {
  41. return adaptor.RequestURL{}, err
  42. }
  43. return adaptor.RequestURL{
  44. Method: http.MethodPost,
  45. URL: url,
  46. }, nil
  47. }
  48. url, err := url.JoinPath(u, "/api/v3/embeddings")
  49. if err != nil {
  50. return adaptor.RequestURL{}, err
  51. }
  52. return adaptor.RequestURL{
  53. Method: http.MethodPost,
  54. URL: url,
  55. }, nil
  56. default:
  57. return adaptor.RequestURL{}, fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
  58. }
  59. }
  60. type Adaptor struct {
  61. openai.Adaptor
  62. }
  63. const baseURL = "https://ark.cn-beijing.volces.com"
  64. func (a *Adaptor) DefaultBaseURL() string {
  65. return baseURL
  66. }
  67. func (a *Adaptor) SupportMode(m mode.Mode) bool {
  68. return m == mode.ChatCompletions ||
  69. m == mode.Anthropic ||
  70. m == mode.Embeddings
  71. }
  72. func (a *Adaptor) Metadata() adaptor.Metadata {
  73. return adaptor.Metadata{
  74. Readme: "Bot support\nNetwork search metering support",
  75. Models: ModelList,
  76. }
  77. }
  78. func (a *Adaptor) GetRequestURL(
  79. meta *meta.Meta,
  80. _ adaptor.Store,
  81. _ *gin.Context,
  82. ) (adaptor.RequestURL, error) {
  83. return GetRequestURL(meta)
  84. }
  85. func (a *Adaptor) ConvertRequest(
  86. meta *meta.Meta,
  87. store adaptor.Store,
  88. req *http.Request,
  89. ) (adaptor.ConvertResult, error) {
  90. switch meta.Mode {
  91. case mode.Embeddings:
  92. if strings.Contains(meta.ActualModel, "vision") {
  93. return openai.ConvertEmbeddingsRequest(meta, req, false, patchEmbeddingsVisionInput)
  94. }
  95. return openai.ConvertEmbeddingsRequest(meta, req, true)
  96. case mode.ChatCompletions:
  97. return ConvertChatCompletionsRequest(meta, req)
  98. default:
  99. return openai.ConvertRequest(meta, store, req)
  100. }
  101. }
  102. func (a *Adaptor) DoResponse(
  103. meta *meta.Meta,
  104. store adaptor.Store,
  105. c *gin.Context,
  106. resp *http.Response,
  107. ) (usage model.Usage, err adaptor.Error) {
  108. switch meta.Mode {
  109. case mode.ChatCompletions:
  110. websearchCount := int64(0)
  111. if utils.IsStreamResponse(resp) {
  112. usage, err = openai.StreamHandler(meta, c, resp, newHandlerPreHandler(&websearchCount))
  113. } else {
  114. usage, err = openai.Handler(meta, c, resp, newHandlerPreHandler(&websearchCount))
  115. }
  116. usage.WebSearchCount += model.ZeroNullInt64(websearchCount)
  117. case mode.Embeddings:
  118. usage, err = openai.EmbeddingsHandler(
  119. meta,
  120. c,
  121. resp,
  122. embeddingPreHandler,
  123. )
  124. default:
  125. return openai.DoResponse(meta, store, c, resp)
  126. }
  127. return usage, err
  128. }
  129. func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) {
  130. return 0, adaptor.ErrGetBalanceNotImplemented
  131. }