auth.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package middleware
import (
	"crypto/subtle"
	"fmt"
	"maps"
	"net/http"
	"slices"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/common/config"
	"github.com/labring/aiproxy/core/common/network"
	"github.com/labring/aiproxy/core/model"
	"github.com/labring/aiproxy/core/relay/meta"
	"github.com/labring/aiproxy/core/relay/mode"
	"github.com/sirupsen/logrus"
)
  17. type APIResponse struct {
  18. Data any `json:"data,omitempty"`
  19. Message string `json:"message,omitempty"`
  20. Success bool `json:"success"`
  21. }
  22. func SuccessResponse(c *gin.Context, data any) {
  23. c.JSON(http.StatusOK, &APIResponse{
  24. Success: true,
  25. Data: data,
  26. })
  27. }
  28. func ErrorResponse(c *gin.Context, code int, message string) {
  29. c.JSON(code, &APIResponse{
  30. Success: false,
  31. Message: message,
  32. })
  33. }
  34. func AdminAuth(c *gin.Context) {
  35. if config.AdminKey == "" {
  36. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, admin key is not set")
  37. c.Abort()
  38. return
  39. }
  40. accessToken := c.Request.Header.Get("Authorization")
  41. if accessToken == "" {
  42. accessToken = c.Query("key")
  43. }
  44. accessToken = strings.TrimPrefix(accessToken, "Bearer ")
  45. accessToken = strings.TrimPrefix(accessToken, "sk-")
  46. if accessToken != config.AdminKey {
  47. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, no access token provided")
  48. c.Abort()
  49. return
  50. }
  51. c.Set(Token, &model.TokenCache{
  52. Key: config.AdminKey,
  53. })
  54. group := c.Param("group")
  55. if group != "" {
  56. log := common.GetLogger(c)
  57. log.Data["gid"] = group
  58. }
  59. c.Next()
  60. }
  61. func TokenAuth(c *gin.Context) {
  62. log := common.GetLogger(c)
  63. key := c.Request.Header.Get("Authorization")
  64. if key == "" {
  65. key = c.Request.Header.Get("X-Api-Key")
  66. }
  67. key = strings.TrimPrefix(
  68. strings.TrimPrefix(key, "Bearer "),
  69. "sk-",
  70. )
  71. var (
  72. token model.TokenCache
  73. useInternalToken bool
  74. )
  75. if config.AdminKey != "" && config.AdminKey == key ||
  76. config.InternalToken != "" && config.InternalToken == key {
  77. token = model.TokenCache{
  78. Key: key,
  79. }
  80. useInternalToken = true
  81. } else {
  82. tokenCache, err := model.GetAndValidateToken(key)
  83. if err != nil {
  84. AbortLogWithMessage(c, http.StatusUnauthorized, err.Error())
  85. return
  86. }
  87. token = *tokenCache
  88. }
  89. SetLogTokenFields(log.Data, token, useInternalToken)
  90. if len(token.Subnets) > 0 {
  91. if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
  92. AbortLogWithMessage(c, http.StatusInternalServerError, err.Error())
  93. return
  94. } else if !ok {
  95. AbortLogWithMessage(c, http.StatusForbidden,
  96. fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
  97. token.Name,
  98. token.ID,
  99. token.Subnets,
  100. c.ClientIP(),
  101. ),
  102. )
  103. return
  104. }
  105. }
  106. modelCaches := model.LoadModelCaches()
  107. var group model.GroupCache
  108. if useInternalToken {
  109. group = model.GroupCache{
  110. Status: model.GroupStatusInternal,
  111. AvailableSets: slices.Collect(maps.Keys(modelCaches.EnabledModelsBySet)),
  112. }
  113. } else {
  114. groupCache, err := model.CacheGetGroup(token.Group)
  115. if err != nil {
  116. AbortLogWithMessage(c, http.StatusInternalServerError, fmt.Sprintf("failed to get group: %v", err))
  117. return
  118. }
  119. group = *groupCache
  120. }
  121. SetLogGroupFields(log.Data, group)
  122. if group.Status != model.GroupStatusEnabled && group.Status != model.GroupStatusInternal {
  123. AbortLogWithMessage(c, http.StatusForbidden, "group is disabled")
  124. return
  125. }
  126. token.SetAvailableSets(group.GetAvailableSets())
  127. token.SetModelsBySet(modelCaches.EnabledModelsBySet)
  128. c.Set(Group, group)
  129. c.Set(Token, token)
  130. c.Set(ModelCaches, modelCaches)
  131. c.Next()
  132. }
  133. func GetGroup(c *gin.Context) model.GroupCache {
  134. v, ok := c.MustGet(Group).(model.GroupCache)
  135. if !ok {
  136. panic(fmt.Sprintf("group cache type error: %T, %v", v, v))
  137. }
  138. return v
  139. }
  140. func GetToken(c *gin.Context) model.TokenCache {
  141. v, ok := c.MustGet(Token).(model.TokenCache)
  142. if !ok {
  143. panic(fmt.Sprintf("token cache type error: %T, %v", v, v))
  144. }
  145. return v
  146. }
  147. func GetModelCaches(c *gin.Context) *model.ModelCaches {
  148. v, ok := c.MustGet(ModelCaches).(*model.ModelCaches)
  149. if !ok {
  150. panic(fmt.Sprintf("model caches type error: %T, %v", v, v))
  151. }
  152. return v
  153. }
  154. func SetLogFieldsFromMeta(m *meta.Meta, fields logrus.Fields) {
  155. SetLogRequestIDField(fields, m.RequestID)
  156. SetLogModeField(fields, m.Mode)
  157. SetLogModelFields(fields, m.OriginModel)
  158. SetLogActualModelFields(fields, m.ActualModel)
  159. SetLogGroupFields(fields, m.Group)
  160. SetLogTokenFields(fields, m.Token, false)
  161. SetLogChannelFields(fields, m.Channel)
  162. }
  163. func SetLogModeField(fields logrus.Fields, mode mode.Mode) {
  164. fields["mode"] = mode.String()
  165. }
  166. func SetLogActualModelFields(fields logrus.Fields, actualModel string) {
  167. fields["actmodel"] = actualModel
  168. }
  169. func SetLogModelFields(fields logrus.Fields, model string) {
  170. fields["model"] = model
  171. }
  172. func SetLogChannelFields(fields logrus.Fields, channel meta.ChannelMeta) {
  173. if channel.ID > 0 {
  174. fields["chid"] = channel.ID
  175. }
  176. if channel.Name != "" {
  177. fields["chname"] = channel.Name
  178. }
  179. if channel.Type > 0 {
  180. fields["chtype"] = int(channel.Type)
  181. fields["chtype_name"] = channel.Type.String()
  182. }
  183. }
  184. func SetLogRequestIDField(fields logrus.Fields, requestID string) {
  185. fields["reqid"] = requestID
  186. }
  187. func SetLogGroupFields(fields logrus.Fields, group model.GroupCache) {
  188. if group.ID != "" {
  189. fields["gid"] = group.ID
  190. }
  191. }
  192. func SetLogTokenFields(fields logrus.Fields, token model.TokenCache, internal bool) {
  193. if token.ID > 0 {
  194. fields["kid"] = token.ID
  195. }
  196. if token.Name != "" {
  197. fields["kname"] = token.Name
  198. }
  199. if token.Key != "" {
  200. fields["key"] = maskTokenKey(token.Key)
  201. }
  202. if internal {
  203. fields["internal"] = "true"
  204. }
  205. }
  206. func maskTokenKey(key string) string {
  207. if len(key) <= 8 {
  208. return "*****"
  209. }
  210. return key[:4] + "*****" + key[len(key)-4:]
  211. }