gemini.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. package dto
  2. import (
  3. "encoding/json"
  4. "strings"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/logger"
  7. "github.com/QuantumNous/new-api/types"
  8. "github.com/gin-gonic/gin"
  9. )
  10. type GeminiChatRequest struct {
  11. Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
  12. Contents []GeminiChatContent `json:"contents"`
  13. SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
  14. GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
  15. Tools json.RawMessage `json:"tools,omitempty"`
  16. ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
  17. SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
  18. CachedContent string `json:"cachedContent,omitempty"`
  19. }
  20. // UnmarshalJSON allows GeminiChatRequest to accept both snake_case and camelCase fields.
  21. func (r *GeminiChatRequest) UnmarshalJSON(data []byte) error {
  22. type Alias GeminiChatRequest
  23. var aux struct {
  24. Alias
  25. SystemInstructionSnake *GeminiChatContent `json:"system_instruction,omitempty"`
  26. }
  27. if err := common.Unmarshal(data, &aux); err != nil {
  28. return err
  29. }
  30. *r = GeminiChatRequest(aux.Alias)
  31. if aux.SystemInstructionSnake != nil {
  32. r.SystemInstructions = aux.SystemInstructionSnake
  33. }
  34. return nil
  35. }
  36. type ToolConfig struct {
  37. FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
  38. RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
  39. }
  40. type FunctionCallingConfig struct {
  41. Mode FunctionCallingConfigMode `json:"mode,omitempty"`
  42. AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
  43. }
  44. type FunctionCallingConfigMode string
  45. type RetrievalConfig struct {
  46. LatLng *LatLng `json:"latLng,omitempty"`
  47. LanguageCode string `json:"languageCode,omitempty"`
  48. }
  49. type LatLng struct {
  50. Latitude *float64 `json:"latitude,omitempty"`
  51. Longitude *float64 `json:"longitude,omitempty"`
  52. }
  53. func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
  54. var files []*types.FileMeta = make([]*types.FileMeta, 0)
  55. var maxTokens int
  56. if r.GenerationConfig.MaxOutputTokens > 0 {
  57. maxTokens = int(r.GenerationConfig.MaxOutputTokens)
  58. }
  59. var inputTexts []string
  60. for _, content := range r.Contents {
  61. for _, part := range content.Parts {
  62. if part.Text != "" {
  63. inputTexts = append(inputTexts, part.Text)
  64. }
  65. if part.InlineData != nil && part.InlineData.Data != "" {
  66. if strings.HasPrefix(part.InlineData.MimeType, "image/") {
  67. files = append(files, &types.FileMeta{
  68. FileType: types.FileTypeImage,
  69. OriginData: part.InlineData.Data,
  70. })
  71. } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
  72. files = append(files, &types.FileMeta{
  73. FileType: types.FileTypeAudio,
  74. OriginData: part.InlineData.Data,
  75. })
  76. } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
  77. files = append(files, &types.FileMeta{
  78. FileType: types.FileTypeVideo,
  79. OriginData: part.InlineData.Data,
  80. })
  81. } else {
  82. files = append(files, &types.FileMeta{
  83. FileType: types.FileTypeFile,
  84. OriginData: part.InlineData.Data,
  85. })
  86. }
  87. }
  88. }
  89. }
  90. inputText := strings.Join(inputTexts, "\n")
  91. return &types.TokenCountMeta{
  92. CombineText: inputText,
  93. Files: files,
  94. MaxTokens: maxTokens,
  95. }
  96. }
  97. func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
  98. if c.Query("alt") == "sse" {
  99. return true
  100. }
  101. return false
  102. }
  103. func (r *GeminiChatRequest) SetModelName(modelName string) {
  104. // GeminiChatRequest does not have a model field, so this method does nothing.
  105. }
  106. func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
  107. var tools []GeminiChatTool
  108. if strings.HasSuffix(string(r.Tools), "[") {
  109. // is array
  110. if err := common.Unmarshal(r.Tools, &tools); err != nil {
  111. logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
  112. return nil
  113. }
  114. } else if strings.HasPrefix(string(r.Tools), "{") {
  115. // is object
  116. singleTool := GeminiChatTool{}
  117. if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
  118. logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
  119. return nil
  120. }
  121. tools = []GeminiChatTool{singleTool}
  122. }
  123. return tools
  124. }
  125. func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
  126. if len(tools) == 0 {
  127. r.Tools = json.RawMessage("[]")
  128. return
  129. }
  130. // Marshal the tools to JSON
  131. data, err := common.Marshal(tools)
  132. if err != nil {
  133. logger.LogError(nil, "error_marshalling_tools: "+err.Error())
  134. return
  135. }
  136. r.Tools = data
  137. }
  138. type GeminiThinkingConfig struct {
  139. IncludeThoughts bool `json:"includeThoughts,omitempty"`
  140. ThinkingBudget *int `json:"thinkingBudget,omitempty"`
  141. // TODO Conflict with thinkingbudget.
  142. ThinkingLevel string `json:"thinkingLevel,omitempty"`
  143. }
  144. // UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
  145. func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
  146. type Alias GeminiThinkingConfig
  147. var aux struct {
  148. Alias
  149. IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
  150. ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
  151. ThinkingLevelSnake string `json:"thinking_level,omitempty"`
  152. }
  153. if err := common.Unmarshal(data, &aux); err != nil {
  154. return err
  155. }
  156. *c = GeminiThinkingConfig(aux.Alias)
  157. if aux.IncludeThoughtsSnake != nil {
  158. c.IncludeThoughts = *aux.IncludeThoughtsSnake
  159. }
  160. if aux.ThinkingBudgetSnake != nil {
  161. c.ThinkingBudget = aux.ThinkingBudgetSnake
  162. }
  163. if aux.ThinkingLevelSnake != "" {
  164. c.ThinkingLevel = aux.ThinkingLevelSnake
  165. }
  166. return nil
  167. }
  168. func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
  169. c.ThinkingBudget = &budget
  170. }
  171. type GeminiInlineData struct {
  172. MimeType string `json:"mimeType"`
  173. Data string `json:"data"`
  174. }
  175. // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
  176. func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
  177. type Alias GeminiInlineData // Use type alias to avoid recursion
  178. var aux struct {
  179. Alias
  180. MimeTypeSnake string `json:"mime_type"`
  181. }
  182. if err := common.Unmarshal(data, &aux); err != nil {
  183. return err
  184. }
  185. *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
  186. // Prioritize snake_case if present
  187. if aux.MimeTypeSnake != "" {
  188. g.MimeType = aux.MimeTypeSnake
  189. } else if aux.MimeType != "" { // Fallback to camelCase from Alias
  190. g.MimeType = aux.MimeType
  191. }
  192. // g.Data would be populated by aux.Alias.Data
  193. return nil
  194. }
  195. type FunctionCall struct {
  196. FunctionName string `json:"name"`
  197. Arguments any `json:"args"`
  198. }
  199. type GeminiFunctionResponse struct {
  200. Name string `json:"name"`
  201. Response map[string]interface{} `json:"response"`
  202. WillContinue json.RawMessage `json:"willContinue,omitempty"`
  203. Scheduling json.RawMessage `json:"scheduling,omitempty"`
  204. Parts json.RawMessage `json:"parts,omitempty"`
  205. ID json.RawMessage `json:"id,omitempty"`
  206. }
  207. type GeminiPartExecutableCode struct {
  208. Language string `json:"language,omitempty"`
  209. Code string `json:"code,omitempty"`
  210. }
  211. type GeminiPartCodeExecutionResult struct {
  212. Outcome string `json:"outcome,omitempty"`
  213. Output string `json:"output,omitempty"`
  214. }
  215. type GeminiFileData struct {
  216. MimeType string `json:"mimeType,omitempty"`
  217. FileUri string `json:"fileUri,omitempty"`
  218. }
  219. type GeminiPart struct {
  220. Text string `json:"text,omitempty"`
  221. Thought bool `json:"thought,omitempty"`
  222. InlineData *GeminiInlineData `json:"inlineData,omitempty"`
  223. FunctionCall *FunctionCall `json:"functionCall,omitempty"`
  224. ThoughtSignature json.RawMessage `json:"thoughtSignature,omitempty"`
  225. FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
  226. // Optional. Media resolution for the input media.
  227. MediaResolution json.RawMessage `json:"mediaResolution,omitempty"`
  228. VideoMetadata json.RawMessage `json:"videoMetadata,omitempty"`
  229. FileData *GeminiFileData `json:"fileData,omitempty"`
  230. ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
  231. CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
  232. }
  233. // UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
  234. func (p *GeminiPart) UnmarshalJSON(data []byte) error {
  235. // Alias to avoid recursion during unmarshalling
  236. type Alias GeminiPart
  237. var aux struct {
  238. Alias
  239. InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
  240. }
  241. if err := common.Unmarshal(data, &aux); err != nil {
  242. return err
  243. }
  244. // Assign fields from alias
  245. *p = GeminiPart(aux.Alias)
  246. // Prioritize snake_case for InlineData if present
  247. if aux.InlineDataSnake != nil {
  248. p.InlineData = aux.InlineDataSnake
  249. } else if aux.InlineData != nil { // Fallback to camelCase from Alias
  250. p.InlineData = aux.InlineData
  251. }
  252. // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
  253. return nil
  254. }
  255. type GeminiChatContent struct {
  256. Role string `json:"role,omitempty"`
  257. Parts []GeminiPart `json:"parts"`
  258. }
  259. type GeminiChatSafetySettings struct {
  260. Category string `json:"category"`
  261. Threshold string `json:"threshold"`
  262. }
  263. type GeminiChatTool struct {
  264. GoogleSearch any `json:"googleSearch,omitempty"`
  265. GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
  266. CodeExecution any `json:"codeExecution,omitempty"`
  267. FunctionDeclarations any `json:"functionDeclarations,omitempty"`
  268. URLContext any `json:"urlContext,omitempty"`
  269. }
  270. type GeminiChatGenerationConfig struct {
  271. Temperature *float64 `json:"temperature,omitempty"`
  272. TopP float64 `json:"topP,omitempty"`
  273. TopK float64 `json:"topK,omitempty"`
  274. MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
  275. CandidateCount int `json:"candidateCount,omitempty"`
  276. StopSequences []string `json:"stopSequences,omitempty"`
  277. ResponseMimeType string `json:"responseMimeType,omitempty"`
  278. ResponseSchema any `json:"responseSchema,omitempty"`
  279. ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
  280. PresencePenalty *float32 `json:"presencePenalty,omitempty"`
  281. FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
  282. ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
  283. Logprobs *int32 `json:"logprobs,omitempty"`
  284. MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
  285. Seed int64 `json:"seed,omitempty"`
  286. ResponseModalities []string `json:"responseModalities,omitempty"`
  287. ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
  288. SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
  289. ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
  290. }
  291. type MediaResolution string
  292. type GeminiChatCandidate struct {
  293. Content GeminiChatContent `json:"content"`
  294. FinishReason *string `json:"finishReason"`
  295. Index int64 `json:"index"`
  296. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  297. }
  298. type GeminiChatSafetyRating struct {
  299. Category string `json:"category"`
  300. Probability string `json:"probability"`
  301. }
  302. type GeminiChatPromptFeedback struct {
  303. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  304. BlockReason *string `json:"blockReason,omitempty"`
  305. }
  306. type GeminiChatResponse struct {
  307. Candidates []GeminiChatCandidate `json:"candidates"`
  308. PromptFeedback *GeminiChatPromptFeedback `json:"promptFeedback,omitempty"`
  309. UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
  310. }
  311. type GeminiUsageMetadata struct {
  312. PromptTokenCount int `json:"promptTokenCount"`
  313. CandidatesTokenCount int `json:"candidatesTokenCount"`
  314. TotalTokenCount int `json:"totalTokenCount"`
  315. ThoughtsTokenCount int `json:"thoughtsTokenCount"`
  316. PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
  317. }
  318. type GeminiPromptTokensDetails struct {
  319. Modality string `json:"modality"`
  320. TokenCount int `json:"tokenCount"`
  321. }
  322. // Imagen related structs
  323. type GeminiImageRequest struct {
  324. Instances []GeminiImageInstance `json:"instances"`
  325. Parameters GeminiImageParameters `json:"parameters"`
  326. }
  327. type GeminiImageInstance struct {
  328. Prompt string `json:"prompt"`
  329. }
  330. type GeminiImageParameters struct {
  331. SampleCount int `json:"sampleCount,omitempty"`
  332. AspectRatio string `json:"aspectRatio,omitempty"`
  333. PersonGeneration string `json:"personGeneration,omitempty"`
  334. ImageSize string `json:"imageSize,omitempty"`
  335. }
  336. type GeminiImageResponse struct {
  337. Predictions []GeminiImagePrediction `json:"predictions"`
  338. }
  339. type GeminiImagePrediction struct {
  340. MimeType string `json:"mimeType"`
  341. BytesBase64Encoded string `json:"bytesBase64Encoded"`
  342. RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
  343. SafetyAttributes any `json:"safetyAttributes,omitempty"`
  344. }
  345. // Embedding related structs
  346. type GeminiEmbeddingRequest struct {
  347. Model string `json:"model,omitempty"`
  348. Content GeminiChatContent `json:"content"`
  349. TaskType string `json:"taskType,omitempty"`
  350. Title string `json:"title,omitempty"`
  351. OutputDimensionality int `json:"outputDimensionality,omitempty"`
  352. }
  353. func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
  354. // Gemini embedding requests are not streamed
  355. return false
  356. }
  357. func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  358. var inputTexts []string
  359. for _, part := range r.Content.Parts {
  360. if part.Text != "" {
  361. inputTexts = append(inputTexts, part.Text)
  362. }
  363. }
  364. inputText := strings.Join(inputTexts, "\n")
  365. return &types.TokenCountMeta{
  366. CombineText: inputText,
  367. }
  368. }
  369. func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
  370. if modelName != "" {
  371. r.Model = modelName
  372. }
  373. }
  374. type GeminiBatchEmbeddingRequest struct {
  375. Requests []*GeminiEmbeddingRequest `json:"requests"`
  376. }
  377. func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
  378. // Gemini batch embedding requests are not streamed
  379. return false
  380. }
  381. func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  382. var inputTexts []string
  383. for _, request := range r.Requests {
  384. meta := request.GetTokenCountMeta()
  385. if meta != nil && meta.CombineText != "" {
  386. inputTexts = append(inputTexts, meta.CombineText)
  387. }
  388. }
  389. inputText := strings.Join(inputTexts, "\n")
  390. return &types.TokenCountMeta{
  391. CombineText: inputText,
  392. }
  393. }
  394. func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
  395. if modelName != "" {
  396. for _, req := range r.Requests {
  397. req.SetModelName(modelName)
  398. }
  399. }
  400. }
  401. type GeminiEmbeddingResponse struct {
  402. Embedding ContentEmbedding `json:"embedding"`
  403. }
  404. type GeminiBatchEmbeddingResponse struct {
  405. Embeddings []*ContentEmbedding `json:"embeddings"`
  406. }
  407. type ContentEmbedding struct {
  408. Values []float64 `json:"values"`
  409. }