auth.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package middleware
import (
	"crypto/subtle"
	"fmt"
	"maps"
	"net/http"
	"slices"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/common/config"
	"github.com/labring/aiproxy/core/common/network"
	"github.com/labring/aiproxy/core/model"
	"github.com/labring/aiproxy/core/relay/meta"
	"github.com/labring/aiproxy/core/relay/mode"
	"github.com/sirupsen/logrus"
)
// APIResponse is the JSON envelope written by SuccessResponse and
// ErrorResponse.
//
// encoding/json emits struct fields in declaration order, so the field order
// below is also the key order of the encoded response.
type APIResponse struct {
	// Data is the payload of a successful response; omitted when nil.
	Data any `json:"data,omitempty"`
	// Message carries the error text of a failed response; omitted when empty.
	Message string `json:"message,omitempty"`
	// Success is always encoded, including when false.
	Success bool `json:"success"`
}
  22. func SuccessResponse(c *gin.Context, data any) {
  23. c.JSON(http.StatusOK, &APIResponse{
  24. Success: true,
  25. Data: data,
  26. })
  27. }
  28. func ErrorResponse(c *gin.Context, code int, message string) {
  29. c.JSON(code, &APIResponse{
  30. Success: false,
  31. Message: message,
  32. })
  33. }
  34. func AdminAuth(c *gin.Context) {
  35. if config.AdminKey == "" {
  36. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, admin key is not set")
  37. c.Abort()
  38. return
  39. }
  40. accessToken := c.Request.Header.Get("Authorization")
  41. if accessToken == "" {
  42. accessToken = c.Query("key")
  43. }
  44. accessToken = strings.TrimPrefix(accessToken, "Bearer ")
  45. accessToken = strings.TrimPrefix(accessToken, "sk-")
  46. if accessToken != config.AdminKey {
  47. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, no access token provided")
  48. c.Abort()
  49. return
  50. }
  51. c.Set(Token, &model.TokenCache{
  52. Key: config.AdminKey,
  53. })
  54. group := c.Param("group")
  55. if group != "" {
  56. log := common.GetLogger(c)
  57. log.Data["gid"] = group
  58. }
  59. c.Next()
  60. }
  61. func TokenAuth(c *gin.Context) {
  62. log := common.GetLogger(c)
  63. key := c.Request.Header.Get("Authorization")
  64. if key == "" {
  65. key = c.Request.Header.Get("X-Api-Key")
  66. }
  67. key = strings.TrimPrefix(
  68. strings.TrimPrefix(key, "Bearer "),
  69. "sk-",
  70. )
  71. var token model.TokenCache
  72. var useInternalToken bool
  73. if config.AdminKey != "" && config.AdminKey == key ||
  74. config.InternalToken != "" && config.InternalToken == key {
  75. token = model.TokenCache{
  76. Key: key,
  77. }
  78. useInternalToken = true
  79. } else {
  80. tokenCache, err := model.ValidateAndGetToken(key)
  81. if err != nil {
  82. AbortLogWithMessage(c, http.StatusUnauthorized, err.Error(), "invalid_token")
  83. return
  84. }
  85. token = *tokenCache
  86. }
  87. SetLogTokenFields(log.Data, token, useInternalToken)
  88. if len(token.Subnets) > 0 {
  89. if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
  90. AbortLogWithMessage(c, http.StatusInternalServerError, err.Error())
  91. return
  92. } else if !ok {
  93. AbortLogWithMessage(c, http.StatusForbidden,
  94. fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
  95. token.Name,
  96. token.ID,
  97. token.Subnets,
  98. c.ClientIP(),
  99. ),
  100. )
  101. return
  102. }
  103. }
  104. modelCaches := model.LoadModelCaches()
  105. var group model.GroupCache
  106. if useInternalToken {
  107. group = model.GroupCache{
  108. Status: model.GroupStatusInternal,
  109. AvailableSets: slices.Collect(maps.Keys(modelCaches.EnabledModelsBySet)),
  110. }
  111. } else {
  112. groupCache, err := model.CacheGetGroup(token.Group)
  113. if err != nil {
  114. AbortLogWithMessage(c, http.StatusInternalServerError, fmt.Sprintf("failed to get group: %v", err))
  115. return
  116. }
  117. group = *groupCache
  118. }
  119. SetLogGroupFields(log.Data, group)
  120. if group.Status != model.GroupStatusEnabled && group.Status != model.GroupStatusInternal {
  121. AbortLogWithMessage(c, http.StatusForbidden, "group is disabled")
  122. return
  123. }
  124. token.SetAvailableSets(group.GetAvailableSets())
  125. token.SetModelsBySet(modelCaches.EnabledModelsBySet)
  126. c.Set(Group, group)
  127. c.Set(Token, token)
  128. c.Set(ModelCaches, modelCaches)
  129. c.Next()
  130. }
  131. func GetGroup(c *gin.Context) model.GroupCache {
  132. v, ok := c.MustGet(Group).(model.GroupCache)
  133. if !ok {
  134. panic(fmt.Sprintf("group cache type error: %T, %v", v, v))
  135. }
  136. return v
  137. }
  138. func GetToken(c *gin.Context) model.TokenCache {
  139. v, ok := c.MustGet(Token).(model.TokenCache)
  140. if !ok {
  141. panic(fmt.Sprintf("token cache type error: %T, %v", v, v))
  142. }
  143. return v
  144. }
  145. func GetModelCaches(c *gin.Context) *model.ModelCaches {
  146. v, ok := c.MustGet(ModelCaches).(*model.ModelCaches)
  147. if !ok {
  148. panic(fmt.Sprintf("model caches type error: %T, %v", v, v))
  149. }
  150. return v
  151. }
  152. func SetLogFieldsFromMeta(m *meta.Meta, fields logrus.Fields) {
  153. SetLogRequestIDField(fields, m.RequestID)
  154. SetLogModeField(fields, m.Mode)
  155. SetLogModelFields(fields, m.OriginModel)
  156. SetLogActualModelFields(fields, m.ActualModel)
  157. SetLogGroupFields(fields, m.Group)
  158. SetLogTokenFields(fields, m.Token, false)
  159. SetLogChannelFields(fields, m.Channel)
  160. }
  161. func SetLogModeField(fields logrus.Fields, mode mode.Mode) {
  162. fields["mode"] = mode.String()
  163. }
  164. func SetLogActualModelFields(fields logrus.Fields, actualModel string) {
  165. fields["actmodel"] = actualModel
  166. }
  167. func SetLogModelFields(fields logrus.Fields, model string) {
  168. fields["model"] = model
  169. }
  170. func SetLogChannelFields(fields logrus.Fields, channel meta.ChannelMeta) {
  171. if channel.ID > 0 {
  172. fields["chid"] = channel.ID
  173. }
  174. if channel.Name != "" {
  175. fields["chname"] = channel.Name
  176. }
  177. if channel.Type > 0 {
  178. fields["chtype"] = int(channel.Type)
  179. fields["chtype_name"] = channel.Type.String()
  180. }
  181. }
  182. func SetLogRequestIDField(fields logrus.Fields, requestID string) {
  183. fields["reqid"] = requestID
  184. }
  185. func SetLogGroupFields(fields logrus.Fields, group model.GroupCache) {
  186. if group.ID != "" {
  187. fields["gid"] = group.ID
  188. }
  189. }
  190. func SetLogTokenFields(fields logrus.Fields, token model.TokenCache, internal bool) {
  191. if token.ID > 0 {
  192. fields["kid"] = token.ID
  193. }
  194. if token.Name != "" {
  195. fields["kname"] = token.Name
  196. }
  197. if token.Key != "" {
  198. fields["key"] = maskTokenKey(token.Key)
  199. }
  200. if internal {
  201. fields["internal"] = "true"
  202. }
  203. }
  204. func maskTokenKey(key string) string {
  205. if len(key) <= 8 {
  206. return "*****"
  207. }
  208. return key[:4] + "*****" + key[len(key)-4:]
  209. }