playground.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "one-api/middleware"
  9. "one-api/model"
  10. "one-api/setting"
  11. "one-api/types"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func Playground(c *gin.Context) {
  16. var newAPIError *types.NewAPIError
  17. defer func() {
  18. if newAPIError != nil {
  19. c.JSON(newAPIError.StatusCode, gin.H{
  20. "error": newAPIError.ToOpenAIError(),
  21. })
  22. }
  23. }()
  24. useAccessToken := c.GetBool("use_access_token")
  25. if useAccessToken {
  26. newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
  27. return
  28. }
  29. playgroundRequest := &dto.PlayGroundRequest{}
  30. err := common.UnmarshalBodyReusable(c, playgroundRequest)
  31. if err != nil {
  32. newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
  33. return
  34. }
  35. if playgroundRequest.Model == "" {
  36. newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
  37. return
  38. }
  39. c.Set("original_model", playgroundRequest.Model)
  40. group := playgroundRequest.Group
  41. userGroup := c.GetString("group")
  42. if group == "" {
  43. group = userGroup
  44. } else {
  45. if !setting.GroupInUserUsableGroups(group) && group != userGroup {
  46. newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
  47. return
  48. }
  49. c.Set("group", group)
  50. }
  51. userId := c.GetInt("id")
  52. // Write user context to ensure acceptUnsetRatio is available
  53. userCache, err := model.GetUserCache(userId)
  54. if err != nil {
  55. newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
  56. return
  57. }
  58. userCache.WriteContext(c)
  59. tempToken := &model.Token{
  60. UserId: userId,
  61. Name: fmt.Sprintf("playground-%s", group),
  62. Group: group,
  63. }
  64. _ = middleware.SetupContextForToken(c, tempToken)
  65. _, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
  66. if newAPIError != nil {
  67. return
  68. }
  69. //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
  70. common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
  71. Relay(c)
  72. }