@@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	common.CloseResponseBodyGracefully(resp)
 	if common.DebugEnabled {
@@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	var geminiResponse GeminiChatResponse
 	err = common.Unmarshal(responseBody, &geminiResponse)
 	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
@@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 
 	responseBody, readErr := io.ReadAll(resp.Body)
 	if readErr != nil {
-		return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
 	var geminiResponse GeminiEmbeddingResponse
 	if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
-		return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
 	// convert to openai format response
@@ -991,9 +991,67 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 
 	jsonResponse, jsonErr := common.Marshal(openAIResponse)
 	if jsonErr != nil {
-		return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
 	common.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return usage, nil
 }
+
+func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	responseBody, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+	common.CloseResponseBodyGracefully(resp)
+
+	var geminiResponse GeminiImageResponse
+	if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if len(geminiResponse.Predictions) == 0 {
+		return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	// convert to openai format response
+	openAIResponse := dto.ImageResponse{
+		Created: common.GetTimestamp(),
+		Data:    make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
+	}
+
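+	// Skip predictions blocked by the upstream RAI (Responsible AI) safety
+	// filter; these come back with a non-empty RaiFilteredReason.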
+	for _, prediction := range geminiResponse.Predictions {
+		if prediction.RaiFilteredReason != "" {
+			continue // skip filtered image
+		}
+		openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+			B64Json: prediction.BytesBase64Encoded,
+		})
+	}
+
+	jsonResponse, jsonErr := common.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
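+	// Note: the cookbook figure below is the fixed per-image token cost Gemini
+	// charges for image inputs; this handler applies the same flat rate to
+	// generated images.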
+	// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
+	// each image has a fixed cost of 258 tokens
+	const imageTokens = 258
+	generatedImages := len(openAIResponse.Data)
+
+	usage := &dto.Usage{
+		PromptTokens:     imageTokens * generatedImages, // bill only images that survived filtering
+		CompletionTokens: 0, // image generation yields no completion tokens
+		TotalTokens:      imageTokens * generatedImages,
+	}
+
+	return usage, nil
+}
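
For reference, a minimal sketch of the upstream payload shape GeminiImageHandler decodes. The field names are inferred from the handler above and from Imagen's REST "predictions" responses; the nested type name and JSON tags are assumptions, and the repo's actual definitions may differ.

type GeminiImageResponse struct {
	Predictions []GeminiImagePrediction `json:"predictions"`
}

type GeminiImagePrediction struct {
	BytesBase64Encoded string `json:"bytesBase64Encoded"` // generated image bytes, base64-encoded
	MimeType           string `json:"mimeType,omitempty"`
	RaiFilteredReason  string `json:"raiFilteredReason,omitempty"` // non-empty when the safety filter blocked this image
}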