playground.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. "one-api/middleware"
  11. "one-api/model"
  12. "one-api/service"
  13. "one-api/setting"
  14. "time"
  15. )
  16. func Playground(c *gin.Context) {
  17. var openaiErr *dto.OpenAIErrorWithStatusCode
  18. defer func() {
  19. if openaiErr != nil {
  20. c.JSON(openaiErr.StatusCode, gin.H{
  21. "error": openaiErr.Error,
  22. })
  23. }
  24. }()
  25. useAccessToken := c.GetBool("use_access_token")
  26. if useAccessToken {
  27. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
  28. return
  29. }
  30. playgroundRequest := &dto.PlayGroundRequest{}
  31. err := common.UnmarshalBodyReusable(c, playgroundRequest)
  32. if err != nil {
  33. openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
  34. return
  35. }
  36. if playgroundRequest.Model == "" {
  37. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
  38. return
  39. }
  40. c.Set("original_model", playgroundRequest.Model)
  41. group := playgroundRequest.Group
  42. userGroup := c.GetString("group")
  43. if group == "" {
  44. group = userGroup
  45. } else {
  46. if !setting.GroupInUserUsableGroups(group) && group != userGroup {
  47. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
  48. return
  49. }
  50. c.Set("group", group)
  51. }
  52. c.Set("token_name", "playground-"+group)
  53. channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
  54. if err != nil {
  55. message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
  56. openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
  57. return
  58. }
  59. middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
  60. c.Set(constant.ContextKeyRequestStartTime, time.Now())
  61. Relay(c)
  62. }