|
|
@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
|
|
return token
|
|
|
}
|
|
|
|
|
|
-func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
|
+func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
|
|
|
+ if embeddingRequest.Input == nil {
|
|
|
+ return fmt.Errorf("input is empty")
|
|
|
+ }
|
|
|
+ if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
|
|
+ embeddingRequest.Model = "omni-moderation-latest"
|
|
|
+ }
|
|
|
+ if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
|
|
+ embeddingRequest.Model = c.Param("model")
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
|
relayInfo := relaycommon.GenRelayInfo(c)
|
|
|
|
|
|
var embeddingRequest *dto.EmbeddingRequest
|
|
|
@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
|
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
|
|
}
|
|
|
- if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
|
|
- embeddingRequest.Model = "m3e-base"
|
|
|
- }
|
|
|
- if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
|
|
- embeddingRequest.Model = c.Param("model")
|
|
|
- }
|
|
|
- if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
|
|
|
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
|
|
|
+
|
|
|
+ err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
|
|
}
|
|
|
+
|
|
|
// map model name
|
|
|
modelMapping := c.GetString("model_mapping")
|
|
|
//isModelMapped := false
|
|
|
@@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|
|
}
|
|
|
adaptor.Init(relayInfo)
|
|
|
|
|
|
- convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
|
|
|
-
|
|
|
+ convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
|
|
+
|
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
|
|
|
}
|
|
|
requestBody := bytes.NewBuffer(jsonData)
|
|
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
|
|
- resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
|
|
|
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
}
|