2
0

jimeng_adapter.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/constant"
  9. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  10. "github.com/gin-gonic/gin"
  11. )
  12. func JimengRequestConvert() func(c *gin.Context) {
  13. return func(c *gin.Context) {
  14. action := c.Query("Action")
  15. if action == "" {
  16. abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
  17. return
  18. }
  19. // Handle Jimeng official API request
  20. var originalReq map[string]interface{}
  21. if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
  22. abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
  23. return
  24. }
  25. model, _ := originalReq["req_key"].(string)
  26. prompt, _ := originalReq["prompt"].(string)
  27. unifiedReq := map[string]interface{}{
  28. "model": model,
  29. "prompt": prompt,
  30. "metadata": originalReq,
  31. }
  32. jsonData, err := json.Marshal(unifiedReq)
  33. if err != nil {
  34. abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
  35. return
  36. }
  37. // Update request body
  38. c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
  39. c.Set(common.KeyRequestBody, jsonData)
  40. if image, ok := originalReq["image"]; !ok || image == "" {
  41. c.Set("action", constant.TaskActionTextGenerate)
  42. }
  43. c.Request.URL.Path = "/v1/video/generations"
  44. if action == "CVSync2AsyncGetResult" {
  45. taskId, ok := originalReq["task_id"].(string)
  46. if !ok || taskId == "" {
  47. abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
  48. return
  49. }
  50. c.Request.URL.Path = "/v1/video/generations/" + taskId
  51. c.Request.Method = http.MethodGet
  52. c.Set("task_id", taskId)
  53. c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
  54. }
  55. c.Next()
  56. }
  57. }