oauth_client.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package model
  2. import (
  3. "encoding/json"
  4. "one-api/common"
  5. "strings"
  6. "time"
  7. "gorm.io/gorm"
  8. )
  9. // OAuthClient OAuth2 客户端模型
  10. type OAuthClient struct {
  11. ID string `json:"id" gorm:"type:varchar(64);primaryKey"`
  12. Secret string `json:"secret" gorm:"type:varchar(128);not null"`
  13. Name string `json:"name" gorm:"type:varchar(255);not null"`
  14. Domain string `json:"domain" gorm:"type:varchar(255)"` // 允许的重定向域名
  15. RedirectURIs string `json:"redirect_uris" gorm:"type:text"` // JSON array of redirect URIs
  16. GrantTypes string `json:"grant_types" gorm:"type:varchar(255);default:'client_credentials'"`
  17. Scopes string `json:"scopes" gorm:"type:varchar(255);default:'api:read'"`
  18. RequirePKCE bool `json:"require_pkce" gorm:"default:true"`
  19. Status int `json:"status" gorm:"type:int;default:1"` // 1: enabled, 2: disabled
  20. CreatedBy int `json:"created_by" gorm:"type:int;not null"` // 创建者用户ID
  21. CreatedTime int64 `json:"created_time" gorm:"bigint"`
  22. LastUsedTime int64 `json:"last_used_time" gorm:"bigint;default:0"`
  23. TokenCount int `json:"token_count" gorm:"type:int;default:0"` // 已签发的token数量
  24. Description string `json:"description" gorm:"type:text"`
  25. ClientType string `json:"client_type" gorm:"type:varchar(32);default:'confidential'"` // confidential, public
  26. DeletedAt gorm.DeletedAt `gorm:"index"`
  27. }
  28. // GetRedirectURIs 获取重定向URI列表
  29. func (c *OAuthClient) GetRedirectURIs() []string {
  30. if c.RedirectURIs == "" {
  31. return []string{}
  32. }
  33. var uris []string
  34. err := json.Unmarshal([]byte(c.RedirectURIs), &uris)
  35. if err != nil {
  36. common.SysLog("failed to unmarshal redirect URIs: " + err.Error())
  37. return []string{}
  38. }
  39. return uris
  40. }
  41. // SetRedirectURIs 设置重定向URI列表
  42. func (c *OAuthClient) SetRedirectURIs(uris []string) {
  43. data, err := json.Marshal(uris)
  44. if err != nil {
  45. common.SysLog("failed to marshal redirect URIs: " + err.Error())
  46. return
  47. }
  48. c.RedirectURIs = string(data)
  49. }
  50. // GetGrantTypes 获取允许的授权类型列表
  51. func (c *OAuthClient) GetGrantTypes() []string {
  52. if c.GrantTypes == "" {
  53. return []string{"client_credentials"}
  54. }
  55. return strings.Split(c.GrantTypes, ",")
  56. }
  57. // SetGrantTypes 设置允许的授权类型列表
  58. func (c *OAuthClient) SetGrantTypes(types []string) {
  59. c.GrantTypes = strings.Join(types, ",")
  60. }
  61. // GetScopes 获取允许的scope列表
  62. func (c *OAuthClient) GetScopes() []string {
  63. if c.Scopes == "" {
  64. return []string{"api:read"}
  65. }
  66. return strings.Split(c.Scopes, ",")
  67. }
  68. // SetScopes 设置允许的scope列表
  69. func (c *OAuthClient) SetScopes(scopes []string) {
  70. c.Scopes = strings.Join(scopes, ",")
  71. }
  72. // ValidateRedirectURI 验证重定向URI是否有效
  73. func (c *OAuthClient) ValidateRedirectURI(uri string) bool {
  74. allowedURIs := c.GetRedirectURIs()
  75. for _, allowedURI := range allowedURIs {
  76. if allowedURI == uri {
  77. return true
  78. }
  79. }
  80. return false
  81. }
  82. // ValidateGrantType 验证授权类型是否被允许
  83. func (c *OAuthClient) ValidateGrantType(grantType string) bool {
  84. allowedTypes := c.GetGrantTypes()
  85. for _, allowedType := range allowedTypes {
  86. if allowedType == grantType {
  87. return true
  88. }
  89. }
  90. return false
  91. }
  92. // ValidateScope 验证scope是否被允许
  93. func (c *OAuthClient) ValidateScope(scope string) bool {
  94. allowedScopes := c.GetScopes()
  95. requestedScopes := strings.Split(scope, " ")
  96. for _, requestedScope := range requestedScopes {
  97. requestedScope = strings.TrimSpace(requestedScope)
  98. if requestedScope == "" {
  99. continue
  100. }
  101. found := false
  102. for _, allowedScope := range allowedScopes {
  103. if allowedScope == requestedScope {
  104. found = true
  105. break
  106. }
  107. }
  108. if !found {
  109. return false
  110. }
  111. }
  112. return true
  113. }
  114. // BeforeCreate GORM hook - 在创建前设置时间
  115. func (c *OAuthClient) BeforeCreate(tx *gorm.DB) (err error) {
  116. c.CreatedTime = time.Now().Unix()
  117. return
  118. }
  119. // UpdateLastUsedTime 更新最后使用时间
  120. func (c *OAuthClient) UpdateLastUsedTime() error {
  121. c.LastUsedTime = time.Now().Unix()
  122. c.TokenCount++
  123. return DB.Model(c).Select("last_used_time", "token_count").Updates(c).Error
  124. }
  125. // GetOAuthClientByID 根据ID获取OAuth客户端
  126. func GetOAuthClientByID(id string) (*OAuthClient, error) {
  127. var client OAuthClient
  128. err := DB.Where("id = ? AND status = ?", id, common.UserStatusEnabled).First(&client).Error
  129. return &client, err
  130. }
  131. // GetAllOAuthClients 获取所有OAuth客户端
  132. func GetAllOAuthClients(startIdx int, num int) ([]*OAuthClient, error) {
  133. var clients []*OAuthClient
  134. err := DB.Order("created_time desc").Limit(num).Offset(startIdx).Find(&clients).Error
  135. return clients, err
  136. }
  137. // SearchOAuthClients 搜索OAuth客户端
  138. func SearchOAuthClients(keyword string) ([]*OAuthClient, error) {
  139. var clients []*OAuthClient
  140. err := DB.Where("name LIKE ? OR id LIKE ? OR description LIKE ?",
  141. "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%").Find(&clients).Error
  142. return clients, err
  143. }
  144. // CreateOAuthClient 创建OAuth客户端
  145. func CreateOAuthClient(client *OAuthClient) error {
  146. return DB.Create(client).Error
  147. }
  148. // UpdateOAuthClient 更新OAuth客户端
  149. func UpdateOAuthClient(client *OAuthClient) error {
  150. return DB.Save(client).Error
  151. }
  152. // DeleteOAuthClient 删除OAuth客户端
  153. func DeleteOAuthClient(id string) error {
  154. return DB.Where("id = ?", id).Delete(&OAuthClient{}).Error
  155. }
  156. // CountOAuthClients 统计OAuth客户端数量
  157. func CountOAuthClients() (int64, error) {
  158. var count int64
  159. err := DB.Model(&OAuthClient{}).Count(&count).Error
  160. return count, err
  161. }