pricing.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package controller
  2. import (
  3. "github.com/QuantumNous/new-api/common"
  4. "github.com/QuantumNous/new-api/model"
  5. "github.com/QuantumNous/new-api/service"
  6. "github.com/QuantumNous/new-api/setting/ratio_setting"
  7. "github.com/gin-gonic/gin"
  8. )
  9. func filterPricingByUsableGroups(pricing []model.Pricing, usableGroup map[string]string) []model.Pricing {
  10. if len(pricing) == 0 {
  11. return pricing
  12. }
  13. if len(usableGroup) == 0 {
  14. return []model.Pricing{}
  15. }
  16. filtered := make([]model.Pricing, 0, len(pricing))
  17. for _, item := range pricing {
  18. if common.StringsContains(item.EnableGroup, "all") {
  19. filtered = append(filtered, item)
  20. continue
  21. }
  22. for _, group := range item.EnableGroup {
  23. if _, ok := usableGroup[group]; ok {
  24. filtered = append(filtered, item)
  25. break
  26. }
  27. }
  28. }
  29. return filtered
  30. }
  31. func GetPricing(c *gin.Context) {
  32. pricing := model.GetPricing()
  33. userId, exists := c.Get("id")
  34. usableGroup := map[string]string{}
  35. groupRatio := map[string]float64{}
  36. for s, f := range ratio_setting.GetGroupRatioCopy() {
  37. groupRatio[s] = f
  38. }
  39. var group string
  40. if exists {
  41. user, err := model.GetUserCache(userId.(int))
  42. if err == nil {
  43. group = user.Group
  44. for g := range groupRatio {
  45. ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
  46. if ok {
  47. groupRatio[g] = ratio
  48. }
  49. }
  50. }
  51. }
  52. usableGroup = service.GetUserUsableGroups(group)
  53. pricing = filterPricingByUsableGroups(pricing, usableGroup)
  54. // check groupRatio contains usableGroup
  55. for group := range ratio_setting.GetGroupRatioCopy() {
  56. if _, ok := usableGroup[group]; !ok {
  57. delete(groupRatio, group)
  58. }
  59. }
  60. c.JSON(200, gin.H{
  61. "success": true,
  62. "data": pricing,
  63. "vendors": model.GetVendors(),
  64. "group_ratio": groupRatio,
  65. "usable_group": usableGroup,
  66. "supported_endpoint": model.GetSupportedEndpointMap(),
  67. "auto_groups": service.GetUserAutoGroup(group),
  68. "pricing_version": "a42d372ccf0b5dd13ecf71203521f9d2",
  69. })
  70. }
  71. func ResetModelRatio(c *gin.Context) {
  72. defaultStr := ratio_setting.DefaultModelRatio2JSONString()
  73. err := model.UpdateOption("ModelRatio", defaultStr)
  74. if err != nil {
  75. c.JSON(200, gin.H{
  76. "success": false,
  77. "message": err.Error(),
  78. })
  79. return
  80. }
  81. err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
  82. if err != nil {
  83. c.JSON(200, gin.H{
  84. "success": false,
  85. "message": err.Error(),
  86. })
  87. return
  88. }
  89. c.JSON(200, gin.H{
  90. "success": true,
  91. "message": "重置模型倍率成功",
  92. })
  93. }