playground.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/dto"
  9. "one-api/middleware"
  10. "one-api/model"
  11. "one-api/service"
  12. "one-api/setting"
  13. "time"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func Playground(c *gin.Context) {
  17. var openaiErr *dto.OpenAIErrorWithStatusCode
  18. defer func() {
  19. if openaiErr != nil {
  20. c.JSON(openaiErr.StatusCode, gin.H{
  21. "error": openaiErr.Error,
  22. })
  23. }
  24. }()
  25. useAccessToken := c.GetBool("use_access_token")
  26. if useAccessToken {
  27. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
  28. return
  29. }
  30. playgroundRequest := &dto.PlayGroundRequest{}
  31. err := common.UnmarshalBodyReusable(c, playgroundRequest)
  32. if err != nil {
  33. openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
  34. return
  35. }
  36. if playgroundRequest.Model == "" {
  37. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
  38. return
  39. }
  40. c.Set("original_model", playgroundRequest.Model)
  41. group := playgroundRequest.Group
  42. userGroup := c.GetString("group")
  43. if group == "" {
  44. group = userGroup
  45. } else {
  46. if !setting.GroupInUserUsableGroups(group) && group != userGroup {
  47. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
  48. return
  49. }
  50. c.Set("group", group)
  51. }
  52. c.Set("token_name", "playground-"+group)
  53. channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
  54. if err != nil {
  55. message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
  56. openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
  57. return
  58. }
  59. middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
  60. common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
  61. // Write user context to ensure acceptUnsetRatio is available
  62. userId := c.GetInt("id")
  63. userCache, err := model.GetUserCache(userId)
  64. if err != nil {
  65. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
  66. return
  67. }
  68. userCache.WriteContext(c)
  69. Relay(c)
  70. }