relay_utils.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package common
  2. import (
  3. "fmt"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "strings"
  9. "github.com/gin-gonic/gin"
  10. )
  // Capability interfaces that ValidateTaskRequestWithImage type-asserts
// against, so it can inspect a request object without knowing its concrete type.
type (
	// HasPrompt is implemented by request objects that expose a text prompt.
	HasPrompt interface {
		GetPrompt() string
	}

	// HasImage is implemented by request objects that can report whether
	// they include an image.
	HasImage interface {
		HasImage() bool
	}
)
  // GetFullRequestURL joins baseURL and requestURL into the upstream URL.
//
// When baseURL points at the Cloudflare AI Gateway (it starts with
// "https://gateway.ai.cloudflare.com"), a provider-specific prefix is first
// removed from requestURL:
//   - OpenAI channels drop a leading "/v1";
//   - Azure channels drop a leading "/openai/deployments".
//
// Every other combination is a plain concatenation of the two strings.
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	path := requestURL
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case constant.ChannelTypeOpenAI:
			path = strings.TrimPrefix(requestURL, "/v1")
		case constant.ChannelTypeAzure:
			path = strings.TrimPrefix(requestURL, "/openai/deployments")
		}
	}
	return baseURL + path
}
  29. func GetAPIVersion(c *gin.Context) string {
  30. query := c.Request.URL.Query()
  31. apiVersion := query.Get("api-version")
  32. if apiVersion == "" {
  33. apiVersion = c.GetString("api_version")
  34. }
  35. return apiVersion
  36. }
  37. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  38. return &dto.TaskError{
  39. Code: code,
  40. Message: err.Error(),
  41. StatusCode: statusCode,
  42. LocalError: localError,
  43. Error: err,
  44. }
  45. }
  46. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
  47. info.Action = action
  48. c.Set("task_request", requestObj)
  49. }
  // validatePrompt checks that prompt has visible content.
//
// It returns nil when prompt contains at least one non-whitespace character.
// Otherwise it returns an "invalid_request" TaskError with HTTP 400 and
// LocalError set.
func validatePrompt(prompt string) *dto.TaskError {
	if trimmed := strings.TrimSpace(prompt); trimmed != "" {
		return nil
	}
	return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
  // ValidateBasicTaskRequest decodes the request body into a TaskSubmitReq
// using common.UnmarshalBodyReusable, validates it, and stores it for the
// given action.
//
// Returns:
//   - an "invalid_request" / 400 local TaskError if decoding fails;
//   - the error from validatePrompt if the prompt is blank;
//   - nil on success, after the request (as a TaskSubmitReq value) and action
//     have been stored via storeTaskRequest.
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var payload TaskSubmitReq
	if err := common.UnmarshalBodyReusable(c, &payload); err != nil {
		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
	}
	if taskErr := validatePrompt(payload.Prompt); taskErr != nil {
		return taskErr
	}

	// Single-image compatibility: when no Images list was sent but the single
	// Image field is non-blank, promote it to a one-element Images list.
	hasSingleImage := strings.TrimSpace(payload.Image) != ""
	if hasSingleImage && len(payload.Images) == 0 {
		payload.Images = []string{payload.Image}
	}

	storeTaskRequest(c, info, action, payload)
	return nil
}
  // ValidateTaskRequestWithImage validates an already-decoded request object,
// picks its task action, and stores both via storeTaskRequest.
//
// requestObj must implement HasPrompt. If it does not, an "invalid_request" /
// 400 local TaskError is returned. A blank prompt is rejected by
// validatePrompt.
//
// The action is constant.TaskActionGenerate when requestObj implements
// HasImage and its HasImage() reports true. Otherwise it is
// constant.TaskActionTextGenerate.
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
	promptReq, isPromptReq := requestObj.(HasPrompt)
	if !isPromptReq {
		return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
	}
	if taskErr := validatePrompt(promptReq.GetPrompt()); taskErr != nil {
		return taskErr
	}

	imageReq, isImageReq := requestObj.(HasImage)
	withImage := isImageReq && imageReq.HasImage()

	action := constant.TaskActionTextGenerate
	if withImage {
		action = constant.TaskActionGenerate
	}

	storeTaskRequest(c, info, action, requestObj)
	return nil
}
  // ValidateTaskRequestWithImageBinding decodes the JSON request body into a
// TaskSubmitReq with gin's ShouldBindJSON. It then delegates to
// ValidateTaskRequestWithImage, which checks the prompt, chooses the task
// action and stores the request.
//
// A decoding failure is reported as an "invalid_request_body" / 400 TaskError
// with LocalError=false.
//
// For consistency with ValidateBasicTaskRequest, a request that sets only the
// single Image field (no Images list) is normalised so Images holds that one
// image. Without this, the two entry points stored differently shaped
// requests.
//
// NOTE(review): the normalisation may also change the chosen action if
// TaskSubmitReq.HasImage() inspects Images; that implementation is not
// visible here — confirm.
// NOTE(review): unlike common.UnmarshalBodyReusable, ShouldBindJSON reads
// c.Request.Body directly. Confirm nothing downstream needs to re-read the raw
// body.
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var req TaskSubmitReq
	if err := c.ShouldBindJSON(&req); err != nil {
		return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
	}
	// Single-image compatibility, mirroring ValidateBasicTaskRequest.
	if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
		req.Images = []string{req.Image}
	}
	return ValidateTaskRequestWithImage(c, info, req)
}