channel.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package model
  2. import (
  3. "gorm.io/gorm"
  4. "one-api/common"
  5. )
// Channel is one upstream channel record (credentials plus routing and
// bookkeeping settings) persisted through GORM.
//
// Several optional settings are pointers; presumably this lets Update, which
// calls DB.Model(...).Updates(channel), tell "not provided" (nil, skipped)
// apart from an explicit zero value — TODO confirm against the API handlers.
//
// NOTE(review): GORM parses tag entries as key:value settings, so the bare
// `bigint` entries below are not a recognized setting and these columns get
// GORM's default int64 type; `type:bigint` was probably intended. Confirm
// before changing, since it affects AutoMigrate.
type Channel struct {
	Id                 int     `json:"id"`
	Type               int     `json:"type" gorm:"default:0"`
	Key                string  `json:"key" gorm:"not null"` // credential; list/get helpers omit it unless selectAll
	OpenAIOrganization *string `json:"openai_organization"`
	TestModel          *string `json:"test_model"`
	Status             int     `json:"status" gorm:"default:1"` // compared against common.ChannelStatus* values
	Name               string  `json:"name" gorm:"index"`
	Weight             *uint   `json:"weight" gorm:"default:0"` // nil is read as 0 by GetWeight
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
	TestTime           int64   `json:"test_time" gorm:"bigint"` // set by UpdateResponseTime
	ResponseTime       int     `json:"response_time"`           // in milliseconds
	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
	Other              string  `json:"other"`
	Balance            float64 `json:"balance"` // in USD
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
	Models             string  `json:"models"` // presumably a delimited model list; SearchChannels matches it with LIKE
	Group              string  `json:"group" gorm:"type:varchar(64);default:'default'"`
	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` // incremented by updateChannelUsedQuota
	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
	//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
	StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
	Priority          *int64  `json:"priority" gorm:"bigint;default:0"` // higher sorts first in GetAllChannels
	AutoBan           *int    `json:"auto_ban" gorm:"default:1"`
}
  31. func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
  32. var channels []*Channel
  33. var err error
  34. order := "priority desc"
  35. if idSort {
  36. order = "id desc"
  37. }
  38. if selectAll {
  39. err = DB.Order(order).Find(&channels).Error
  40. } else {
  41. err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
  42. }
  43. return channels, err
  44. }
  45. func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
  46. var channels []*Channel
  47. keyCol := "`key`"
  48. groupCol := "`group`"
  49. modelsCol := "`models`"
  50. // 如果是 PostgreSQL,使用双引号
  51. if common.UsingPostgreSQL {
  52. keyCol = `"key"`
  53. groupCol = `"group"`
  54. modelsCol = `"models"`
  55. }
  56. // 构造基础查询
  57. baseQuery := DB.Model(&Channel{}).Omit(keyCol)
  58. // 构造WHERE子句
  59. var whereClause string
  60. var args []interface{}
  61. if group != "" {
  62. whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
  63. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
  64. } else {
  65. whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
  66. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
  67. }
  68. // 执行查询
  69. err := baseQuery.Where(whereClause, args...).Find(&channels).Error
  70. if err != nil {
  71. return nil, err
  72. }
  73. return channels, nil
  74. }
  75. func GetChannelById(id int, selectAll bool) (*Channel, error) {
  76. channel := Channel{Id: id}
  77. var err error = nil
  78. if selectAll {
  79. err = DB.First(&channel, "id = ?", id).Error
  80. } else {
  81. err = DB.Omit("key").First(&channel, "id = ?", id).Error
  82. }
  83. return &channel, err
  84. }
  85. func BatchInsertChannels(channels []Channel) error {
  86. var err error
  87. err = DB.Create(&channels).Error
  88. if err != nil {
  89. return err
  90. }
  91. for _, channel_ := range channels {
  92. err = channel_.AddAbilities()
  93. if err != nil {
  94. return err
  95. }
  96. }
  97. return nil
  98. }
  99. func BatchDeleteChannels(ids []int) error {
  100. //使用事务 删除channel表和channel_ability表
  101. tx := DB.Begin()
  102. err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
  103. if err != nil {
  104. // 回滚事务
  105. tx.Rollback()
  106. return err
  107. }
  108. err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
  109. if err != nil {
  110. // 回滚事务
  111. tx.Rollback()
  112. return err
  113. }
  114. // 提交事务
  115. tx.Commit()
  116. return err
  117. }
  118. func (channel *Channel) GetPriority() int64 {
  119. if channel.Priority == nil {
  120. return 0
  121. }
  122. return *channel.Priority
  123. }
  124. func (channel *Channel) GetWeight() int {
  125. if channel.Weight == nil {
  126. return 0
  127. }
  128. return int(*channel.Weight)
  129. }
  130. func (channel *Channel) GetBaseURL() string {
  131. if channel.BaseURL == nil {
  132. return ""
  133. }
  134. return *channel.BaseURL
  135. }
  136. func (channel *Channel) GetModelMapping() string {
  137. if channel.ModelMapping == nil {
  138. return ""
  139. }
  140. return *channel.ModelMapping
  141. }
  142. func (channel *Channel) GetStatusCodeMapping() string {
  143. if channel.StatusCodeMapping == nil {
  144. return ""
  145. }
  146. return *channel.StatusCodeMapping
  147. }
  148. func (channel *Channel) Insert() error {
  149. var err error
  150. err = DB.Create(channel).Error
  151. if err != nil {
  152. return err
  153. }
  154. err = channel.AddAbilities()
  155. return err
  156. }
  157. func (channel *Channel) Update() error {
  158. var err error
  159. err = DB.Model(channel).Updates(channel).Error
  160. if err != nil {
  161. return err
  162. }
  163. DB.Model(channel).First(channel, "id = ?", channel.Id)
  164. err = channel.UpdateAbilities()
  165. return err
  166. }
  167. func (channel *Channel) UpdateResponseTime(responseTime int64) {
  168. err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
  169. TestTime: common.GetTimestamp(),
  170. ResponseTime: int(responseTime),
  171. }).Error
  172. if err != nil {
  173. common.SysError("failed to update response time: " + err.Error())
  174. }
  175. }
  176. func (channel *Channel) UpdateBalance(balance float64) {
  177. err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
  178. BalanceUpdatedTime: common.GetTimestamp(),
  179. Balance: balance,
  180. }).Error
  181. if err != nil {
  182. common.SysError("failed to update balance: " + err.Error())
  183. }
  184. }
  185. func (channel *Channel) Delete() error {
  186. var err error
  187. err = DB.Delete(channel).Error
  188. if err != nil {
  189. return err
  190. }
  191. err = channel.DeleteAbilities()
  192. return err
  193. }
  194. func UpdateChannelStatusById(id int, status int) {
  195. err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
  196. if err != nil {
  197. common.SysError("failed to update ability status: " + err.Error())
  198. }
  199. err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
  200. if err != nil {
  201. common.SysError("failed to update channel status: " + err.Error())
  202. }
  203. }
  204. func UpdateChannelUsedQuota(id int, quota int) {
  205. if common.BatchUpdateEnabled {
  206. addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
  207. return
  208. }
  209. updateChannelUsedQuota(id, quota)
  210. }
  211. func updateChannelUsedQuota(id int, quota int) {
  212. err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
  213. if err != nil {
  214. common.SysError("failed to update channel used quota: " + err.Error())
  215. }
  216. }
  217. func DeleteChannelByStatus(status int64) (int64, error) {
  218. result := DB.Where("status = ?", status).Delete(&Channel{})
  219. return result.RowsAffected, result.Error
  220. }
  221. func DeleteDisabledChannel() (int64, error) {
  222. result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
  223. return result.RowsAffected, result.Error
  224. }