gemini.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. package dto
  2. import (
  3. "encoding/json"
  4. "github.com/gin-gonic/gin"
  5. "one-api/common"
  6. "one-api/logger"
  7. "one-api/types"
  8. "strings"
  9. )
  10. type GeminiChatRequest struct {
  11. Contents []GeminiChatContent `json:"contents"`
  12. SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
  13. GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
  14. Tools json.RawMessage `json:"tools,omitempty"`
  15. SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
  16. }
  17. func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
  18. var files []*types.FileMeta = make([]*types.FileMeta, 0)
  19. var maxTokens int
  20. if r.GenerationConfig.MaxOutputTokens > 0 {
  21. maxTokens = int(r.GenerationConfig.MaxOutputTokens)
  22. }
  23. var inputTexts []string
  24. for _, content := range r.Contents {
  25. for _, part := range content.Parts {
  26. if part.Text != "" {
  27. inputTexts = append(inputTexts, part.Text)
  28. }
  29. if part.InlineData != nil && part.InlineData.Data != "" {
  30. if strings.HasPrefix(part.InlineData.MimeType, "image/") {
  31. files = append(files, &types.FileMeta{
  32. FileType: types.FileTypeImage,
  33. OriginData: part.InlineData.Data,
  34. })
  35. } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
  36. files = append(files, &types.FileMeta{
  37. FileType: types.FileTypeAudio,
  38. OriginData: part.InlineData.Data,
  39. })
  40. } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
  41. files = append(files, &types.FileMeta{
  42. FileType: types.FileTypeVideo,
  43. OriginData: part.InlineData.Data,
  44. })
  45. } else {
  46. files = append(files, &types.FileMeta{
  47. FileType: types.FileTypeFile,
  48. OriginData: part.InlineData.Data,
  49. })
  50. }
  51. }
  52. }
  53. }
  54. inputText := strings.Join(inputTexts, "\n")
  55. return &types.TokenCountMeta{
  56. CombineText: inputText,
  57. Files: files,
  58. MaxTokens: maxTokens,
  59. }
  60. }
  61. func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
  62. if c.Query("alt") == "sse" {
  63. return true
  64. }
  65. return false
  66. }
  67. func (r *GeminiChatRequest) SetModelName(modelName string) {
  68. // GeminiChatRequest does not have a model field, so this method does nothing.
  69. }
  70. func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
  71. var tools []GeminiChatTool
  72. if strings.HasSuffix(string(r.Tools), "[") {
  73. // is array
  74. if err := common.Unmarshal(r.Tools, &tools); err != nil {
  75. logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
  76. return nil
  77. }
  78. } else if strings.HasPrefix(string(r.Tools), "{") {
  79. // is object
  80. singleTool := GeminiChatTool{}
  81. if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
  82. logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
  83. return nil
  84. }
  85. tools = []GeminiChatTool{singleTool}
  86. }
  87. return tools
  88. }
  89. func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
  90. if len(tools) == 0 {
  91. r.Tools = json.RawMessage("[]")
  92. return
  93. }
  94. // Marshal the tools to JSON
  95. data, err := common.Marshal(tools)
  96. if err != nil {
  97. logger.LogError(nil, "error_marshalling_tools: "+err.Error())
  98. return
  99. }
  100. r.Tools = data
  101. }
  // GeminiThinkingConfig holds the options encoded under "thinkingConfig".
  type GeminiThinkingConfig struct {
  	IncludeThoughts bool `json:"includeThoughts,omitempty"`

  	// ThinkingBudget is omitted from the JSON only when nil; being a
  	// pointer, an explicit budget of 0 is still encoded.
  	ThinkingBudget *int `json:"thinkingBudget,omitempty"`
  }

  // SetThinkingBudget sets ThinkingBudget to a pointer holding budget.
  func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
  	value := budget
  	c.ThinkingBudget = &value
  }
  // GeminiInlineData is a content payload embedded directly in a part,
  // identified by its MIME type.
  type GeminiInlineData struct {
  	MimeType string `json:"mimeType"`
  	// Data is the payload as a string (presumably base64-encoded; confirm
  	// against the producers of this value).
  	Data string `json:"data"`
  }
  113. // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
  114. func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
  115. type Alias GeminiInlineData // Use type alias to avoid recursion
  116. var aux struct {
  117. Alias
  118. MimeTypeSnake string `json:"mime_type"`
  119. }
  120. if err := common.Unmarshal(data, &aux); err != nil {
  121. return err
  122. }
  123. *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
  124. // Prioritize snake_case if present
  125. if aux.MimeTypeSnake != "" {
  126. g.MimeType = aux.MimeTypeSnake
  127. } else if aux.MimeType != "" { // Fallback to camelCase from Alias
  128. g.MimeType = aux.MimeType
  129. }
  130. // g.Data would be populated by aux.Alias.Data
  131. return nil
  132. }
  // FunctionCall describes a function invocation: the function's name and
  // its arguments, encoded under "name" and "args".
  type FunctionCall struct {
  	FunctionName string `json:"name"`
  	Arguments    any    `json:"args"`
  }
  // GeminiFunctionResponse carries the result of a function call: the
  // function's name and a free-form response object.
  type GeminiFunctionResponse struct {
  	Name     string         `json:"name"`
  	Response map[string]any `json:"response"`
  }
  // GeminiPartExecutableCode is a part holding source code together with the
  // language it is written in. Empty fields are omitted from the JSON.
  type GeminiPartExecutableCode struct {
  	Language string `json:"language,omitempty"`
  	Code     string `json:"code,omitempty"`
  }
  // GeminiPartCodeExecutionResult is a part holding the outcome and output of
  // a code execution. Empty fields are omitted from the JSON.
  type GeminiPartCodeExecutionResult struct {
  	Outcome string `json:"outcome,omitempty"`
  	Output  string `json:"output,omitempty"`
  }
  // GeminiFileData references a file by URI together with its MIME type.
  // Empty fields are omitted from the JSON.
  type GeminiFileData struct {
  	MimeType string `json:"mimeType,omitempty"`
  	FileUri  string `json:"fileUri,omitempty"`
  }
  153. type GeminiPart struct {
  154. Text string `json:"text,omitempty"`
  155. Thought bool `json:"thought,omitempty"`
  156. InlineData *GeminiInlineData `json:"inlineData,omitempty"`
  157. FunctionCall *FunctionCall `json:"functionCall,omitempty"`
  158. FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
  159. FileData *GeminiFileData `json:"fileData,omitempty"`
  160. ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
  161. CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
  162. }
  163. // UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
  164. func (p *GeminiPart) UnmarshalJSON(data []byte) error {
  165. // Alias to avoid recursion during unmarshalling
  166. type Alias GeminiPart
  167. var aux struct {
  168. Alias
  169. InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
  170. }
  171. if err := common.Unmarshal(data, &aux); err != nil {
  172. return err
  173. }
  174. // Assign fields from alias
  175. *p = GeminiPart(aux.Alias)
  176. // Prioritize snake_case for InlineData if present
  177. if aux.InlineDataSnake != nil {
  178. p.InlineData = aux.InlineDataSnake
  179. } else if aux.InlineData != nil { // Fallback to camelCase from Alias
  180. p.InlineData = aux.InlineData
  181. }
  182. // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
  183. return nil
  184. }
  185. type GeminiChatContent struct {
  186. Role string `json:"role,omitempty"`
  187. Parts []GeminiPart `json:"parts"`
  188. }
  // GeminiChatSafetySettings pairs a safety category with the threshold to
  // apply to it.
  type GeminiChatSafetySettings struct {
  	Category  string `json:"category"`
  	Threshold string `json:"threshold"`
  }
  // GeminiChatTool is one entry of the request's "tools" list. Each field is
  // free-form and omitted from the JSON when nil.
  type GeminiChatTool struct {
  	GoogleSearch          any `json:"googleSearch,omitempty"`
  	GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
  	CodeExecution         any `json:"codeExecution,omitempty"`
  	FunctionDeclarations  any `json:"functionDeclarations,omitempty"`
  }
  199. type GeminiChatGenerationConfig struct {
  200. Temperature *float64 `json:"temperature,omitempty"`
  201. TopP float64 `json:"topP,omitempty"`
  202. TopK float64 `json:"topK,omitempty"`
  203. MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
  204. CandidateCount int `json:"candidateCount,omitempty"`
  205. StopSequences []string `json:"stopSequences,omitempty"`
  206. ResponseMimeType string `json:"responseMimeType,omitempty"`
  207. ResponseSchema any `json:"responseSchema,omitempty"`
  208. Seed int64 `json:"seed,omitempty"`
  209. ResponseModalities []string `json:"responseModalities,omitempty"`
  210. ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
  211. SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
  212. }
  213. type GeminiChatCandidate struct {
  214. Content GeminiChatContent `json:"content"`
  215. FinishReason *string `json:"finishReason"`
  216. Index int64 `json:"index"`
  217. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  218. }
  // GeminiChatSafetyRating pairs a safety category with a probability label.
  type GeminiChatSafetyRating struct {
  	Category    string `json:"category"`
  	Probability string `json:"probability"`
  }
  223. type GeminiChatPromptFeedback struct {
  224. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  225. }
  226. type GeminiChatResponse struct {
  227. Candidates []GeminiChatCandidate `json:"candidates"`
  228. PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
  229. UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
  230. }
  231. type GeminiUsageMetadata struct {
  232. PromptTokenCount int `json:"promptTokenCount"`
  233. CandidatesTokenCount int `json:"candidatesTokenCount"`
  234. TotalTokenCount int `json:"totalTokenCount"`
  235. ThoughtsTokenCount int `json:"thoughtsTokenCount"`
  236. PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
  237. }
  // GeminiPromptTokensDetails is the token count attributed to one modality.
  type GeminiPromptTokensDetails struct {
  	Modality   string `json:"modality"`
  	TokenCount int    `json:"tokenCount"`
  }
  242. // Imagen related structs
  243. type GeminiImageRequest struct {
  244. Instances []GeminiImageInstance `json:"instances"`
  245. Parameters GeminiImageParameters `json:"parameters"`
  246. }
  // GeminiImageInstance is one image generation input, consisting of a prompt.
  type GeminiImageInstance struct {
  	Prompt string `json:"prompt"`
  }
  // GeminiImageParameters holds the optional image generation parameters.
  // Zero-valued fields are omitted from the JSON.
  type GeminiImageParameters struct {
  	SampleCount      int    `json:"sampleCount,omitempty"`
  	AspectRatio      string `json:"aspectRatio,omitempty"`
  	PersonGeneration string `json:"personGeneration,omitempty"`
  }
  255. type GeminiImageResponse struct {
  256. Predictions []GeminiImagePrediction `json:"predictions"`
  257. }
  // GeminiImagePrediction is one generated image: its MIME type and the image
  // bytes as a base64 string, plus optional filtering and safety information.
  type GeminiImagePrediction struct {
  	MimeType           string `json:"mimeType"`
  	BytesBase64Encoded string `json:"bytesBase64Encoded"`
  	RaiFilteredReason  string `json:"raiFilteredReason,omitempty"`
  	SafetyAttributes   any    `json:"safetyAttributes,omitempty"`
  }
  264. // Embedding related structs
  265. type GeminiEmbeddingRequest struct {
  266. Model string `json:"model,omitempty"`
  267. Content GeminiChatContent `json:"content"`
  268. TaskType string `json:"taskType,omitempty"`
  269. Title string `json:"title,omitempty"`
  270. OutputDimensionality int `json:"outputDimensionality,omitempty"`
  271. }
  272. func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
  273. // Gemini embedding requests are not streamed
  274. return false
  275. }
  276. func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  277. var inputTexts []string
  278. for _, part := range r.Content.Parts {
  279. if part.Text != "" {
  280. inputTexts = append(inputTexts, part.Text)
  281. }
  282. }
  283. inputText := strings.Join(inputTexts, "\n")
  284. return &types.TokenCountMeta{
  285. CombineText: inputText,
  286. }
  287. }
  288. func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
  289. if modelName != "" {
  290. r.Model = modelName
  291. }
  292. }
  293. type GeminiBatchEmbeddingRequest struct {
  294. Requests []*GeminiEmbeddingRequest `json:"requests"`
  295. }
  296. func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
  297. // Gemini batch embedding requests are not streamed
  298. return false
  299. }
  300. func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  301. var inputTexts []string
  302. for _, request := range r.Requests {
  303. meta := request.GetTokenCountMeta()
  304. if meta != nil && meta.CombineText != "" {
  305. inputTexts = append(inputTexts, meta.CombineText)
  306. }
  307. }
  308. inputText := strings.Join(inputTexts, "\n")
  309. return &types.TokenCountMeta{
  310. CombineText: inputText,
  311. }
  312. }
  313. func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
  314. if modelName != "" {
  315. for _, req := range r.Requests {
  316. req.SetModelName(modelName)
  317. }
  318. }
  319. }
  320. type GeminiEmbeddingResponse struct {
  321. Embedding ContentEmbedding `json:"embedding"`
  322. }
  323. type GeminiBatchEmbeddingResponse struct {
  324. Embeddings []*ContentEmbedding `json:"embeddings"`
  325. }
  // ContentEmbedding is an embedding vector, encoded as a list of numbers
  // under the key "values".
  type ContentEmbedding struct {
  	Values []float64 `json:"values"`
  }