main.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package gemini
  2. import (
  3. "bufio"
  4. "bytes"
  5. "net/http"
  6. "strconv"
  7. "github.com/bytedance/sonic"
  8. "github.com/bytedance/sonic/ast"
  9. "github.com/gin-gonic/gin"
  10. "github.com/labring/aiproxy/core/common"
  11. "github.com/labring/aiproxy/core/model"
  12. "github.com/labring/aiproxy/core/relay/adaptor"
  13. "github.com/labring/aiproxy/core/relay/meta"
  14. relaymodel "github.com/labring/aiproxy/core/relay/model"
  15. "github.com/labring/aiproxy/core/relay/render"
  16. "github.com/labring/aiproxy/core/relay/utils"
  17. )
  // Gemini API overview: https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn

// ThoughtSignatureDummySkipValidator is a placeholder "thoughtSignature"
// value. It is sent in place of a real thought signature so that Gemini skips
// signature validation when the actual signature is unavailable.
// See https://ai.google.dev/gemini-api/docs/thought-signatures#faqs.
const ThoughtSignatureDummySkipValidator = "skip_thought_signature_validator"

// ThoughtSignatureDummyContextEng is an alternative placeholder
// "thoughtSignature" value that serves the same purpose as
// ThoughtSignatureDummySkipValidator.
const ThoughtSignatureDummyContextEng = "context_engineering_is_the_way_to_go"
  25. func CleanFunctionResponseID(node *ast.Node) error {
  26. contents := node.Get("contents")
  27. if !contents.Exists() {
  28. return nil
  29. }
  30. return contents.ForEach(func(_ ast.Sequence, content *ast.Node) bool {
  31. parts := content.Get("parts")
  32. if !parts.Exists() {
  33. return true
  34. }
  35. _ = parts.ForEach(func(_ ast.Sequence, part *ast.Node) bool {
  36. functionResponse := part.Get("functionResponse")
  37. if functionResponse.Exists() {
  38. _, _ = functionResponse.Unset("id")
  39. }
  40. return true
  41. })
  42. return true
  43. })
  44. }
  45. func ensureThoughtSignature(node *ast.Node) error {
  46. contents := node.Get("contents")
  47. if !contents.Exists() {
  48. return nil
  49. }
  50. return contents.ForEach(func(_ ast.Sequence, content *ast.Node) bool {
  51. parts := content.Get("parts")
  52. if !parts.Exists() {
  53. return true
  54. }
  55. _ = parts.ForEach(func(_ ast.Sequence, part *ast.Node) bool {
  56. functionCall := part.Get("functionCall")
  57. if !functionCall.Exists() {
  58. return true
  59. }
  60. thoughtSignature := part.Get("thoughtSignature")
  61. if !thoughtSignature.Exists() {
  62. _, _ = part.Set(
  63. "thoughtSignature",
  64. ast.NewString(ThoughtSignatureDummySkipValidator),
  65. )
  66. } else {
  67. val, _ := thoughtSignature.String()
  68. if val == "" {
  69. _, _ = part.Set("thoughtSignature", ast.NewString(ThoughtSignatureDummySkipValidator))
  70. }
  71. }
  72. return true
  73. })
  74. return true
  75. })
  76. }
  77. func ensureRole(node *ast.Node) error {
  78. contents := node.Get("contents")
  79. if !contents.Exists() {
  80. return nil
  81. }
  82. return contents.ForEach(func(_ ast.Sequence, content *ast.Node) bool {
  83. role := content.Get("role")
  84. if !role.Exists() {
  85. _, _ = content.Set("role", ast.NewString("user"))
  86. } else {
  87. val, _ := role.String()
  88. if val == "" {
  89. _, _ = content.Set("role", ast.NewString("user"))
  90. }
  91. }
  92. return true
  93. })
  94. }
  95. func NativeConvertRequest(
  96. meta *meta.Meta,
  97. req *http.Request,
  98. callback ...func(node *ast.Node) error,
  99. ) (adaptor.ConvertResult, error) {
  100. node, err := common.UnmarshalRequest2NodeReusable(req)
  101. if err != nil {
  102. return adaptor.ConvertResult{}, err
  103. }
  104. err = ensureThoughtSignature(&node)
  105. if err != nil {
  106. return adaptor.ConvertResult{}, err
  107. }
  108. err = ensureRole(&node)
  109. if err != nil {
  110. return adaptor.ConvertResult{}, err
  111. }
  112. for _, callback := range callback {
  113. if callback == nil {
  114. continue
  115. }
  116. err = callback(&node)
  117. if err != nil {
  118. return adaptor.ConvertResult{}, err
  119. }
  120. }
  121. body, err := node.MarshalJSON()
  122. if err != nil {
  123. return adaptor.ConvertResult{}, err
  124. }
  125. return adaptor.ConvertResult{
  126. Header: http.Header{
  127. "Content-Type": {"application/json"},
  128. "Content-Length": {strconv.Itoa(len(body))},
  129. },
  130. Body: bytes.NewReader(body),
  131. }, nil
  132. }
  133. // NativeHandler handles non-streaming responses in native Gemini format (passthrough)
  134. func NativeHandler(
  135. meta *meta.Meta,
  136. c *gin.Context,
  137. resp *http.Response,
  138. ) (model.Usage, adaptor.Error) {
  139. if resp.StatusCode != http.StatusOK {
  140. return model.Usage{}, ErrorHandler(resp)
  141. }
  142. defer resp.Body.Close()
  143. var geminiResponse relaymodel.GeminiChatResponse
  144. if err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&geminiResponse); err != nil {
  145. return model.Usage{}, relaymodel.WrapperOpenAIError(
  146. err,
  147. "unmarshal_response_body_failed",
  148. http.StatusInternalServerError,
  149. )
  150. }
  151. // Calculate usage
  152. usage := model.Usage{}
  153. if geminiResponse.UsageMetadata != nil {
  154. usage = geminiResponse.UsageMetadata.ToUsage().ToModelUsage()
  155. }
  156. // Pass through the response as-is
  157. jsonResponse, err := sonic.Marshal(geminiResponse)
  158. if err != nil {
  159. return usage, relaymodel.WrapperOpenAIError(
  160. err,
  161. "marshal_response_body_failed",
  162. http.StatusInternalServerError,
  163. )
  164. }
  165. c.Writer.Header().Set("Content-Type", "application/json")
  166. c.Writer.Header().Set("Content-Length", strconv.Itoa(len(jsonResponse)))
  167. _, _ = c.Writer.Write(jsonResponse)
  168. return usage, nil
  169. }
  170. // NativeStreamHandler handles streaming responses in native Gemini format (passthrough)
  171. func NativeStreamHandler(
  172. meta *meta.Meta,
  173. c *gin.Context,
  174. resp *http.Response,
  175. ) (model.Usage, adaptor.Error) {
  176. if resp.StatusCode != http.StatusOK {
  177. return model.Usage{}, ErrorHandler(resp)
  178. }
  179. defer resp.Body.Close()
  180. log := common.GetLogger(c)
  181. scanner := bufio.NewScanner(resp.Body)
  182. buf := utils.GetScannerBuffer()
  183. defer utils.PutScannerBuffer(buf)
  184. scanner.Buffer(*buf, cap(*buf))
  185. usage := model.Usage{}
  186. for scanner.Scan() {
  187. data := scanner.Bytes()
  188. if !render.IsValidSSEData(data) {
  189. continue
  190. }
  191. data = render.ExtractSSEData(data)
  192. // Parse to extract usage metadata
  193. var geminiResp relaymodel.GeminiChatResponse
  194. if err := sonic.Unmarshal(data, &geminiResp); err == nil {
  195. if geminiResp.UsageMetadata != nil {
  196. usage = geminiResp.UsageMetadata.ToUsage().ToModelUsage()
  197. }
  198. }
  199. // Pass through the data as-is
  200. render.GeminiBytesData(c, data)
  201. }
  202. if err := scanner.Err(); err != nil {
  203. log.Error("error reading stream: " + err.Error())
  204. }
  205. return usage, nil
  206. }