auth.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package middleware
  import (
  	"crypto/subtle"
  	"fmt"
  	"maps"
  	"net/http"
  	"slices"
  	"strings"

  	"github.com/gin-gonic/gin"
  	"github.com/labring/aiproxy/core/common"
  	"github.com/labring/aiproxy/core/common/config"
  	"github.com/labring/aiproxy/core/common/network"
  	"github.com/labring/aiproxy/core/model"
  	"github.com/labring/aiproxy/core/relay/meta"
  	"github.com/labring/aiproxy/core/relay/mode"
  	"github.com/sirupsen/logrus"
  )
  // APIResponse is the JSON envelope written by SuccessResponse and
  // ErrorResponse. Data and Message are omitted from the encoding when they
  // hold their zero value; Success is always encoded.
  type APIResponse struct {
  	// Data is the payload of a successful response.
  	Data any `json:"data,omitempty"`
  	// Message is the error text of a failed response.
  	Message string `json:"message,omitempty"`
  	// Success reports whether the request succeeded.
  	Success bool `json:"success"`
  }
  22. func SuccessResponse(c *gin.Context, data any) {
  23. c.JSON(http.StatusOK, &APIResponse{
  24. Success: true,
  25. Data: data,
  26. })
  27. }
  28. func ErrorResponse(c *gin.Context, code int, message string) {
  29. c.JSON(code, &APIResponse{
  30. Success: false,
  31. Message: message,
  32. })
  33. }
  34. func AdminAuth(c *gin.Context) {
  35. if config.AdminKey == "" {
  36. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, admin key is not set")
  37. c.Abort()
  38. return
  39. }
  40. accessToken := c.Request.Header.Get("Authorization")
  41. if accessToken == "" {
  42. accessToken = c.Query("key")
  43. }
  44. accessToken = strings.TrimPrefix(accessToken, "Bearer ")
  45. accessToken = strings.TrimPrefix(accessToken, "sk-")
  46. if accessToken != config.AdminKey {
  47. ErrorResponse(c, http.StatusUnauthorized, "unauthorized, no access token provided")
  48. c.Abort()
  49. return
  50. }
  51. c.Set(Token, &model.TokenCache{
  52. Key: config.AdminKey,
  53. })
  54. group := c.Param("group")
  55. if group != "" {
  56. log := common.GetLogger(c)
  57. log.Data["gid"] = group
  58. }
  59. c.Next()
  60. }
  61. func TokenAuth(c *gin.Context) {
  62. log := common.GetLogger(c)
  63. key := c.Request.Header.Get("Authorization")
  64. if key == "" {
  65. key = c.Request.Header.Get("X-Api-Key")
  66. }
  67. if key == "" {
  68. key = c.Request.Header.Get("X-Goog-Api-Key")
  69. }
  70. key = strings.TrimPrefix(
  71. strings.TrimPrefix(key, "Bearer "),
  72. "sk-",
  73. )
  74. var (
  75. token model.TokenCache
  76. useInternalToken bool
  77. )
  78. if config.AdminKey != "" && config.AdminKey == key ||
  79. config.InternalToken != "" && config.InternalToken == key {
  80. token = model.TokenCache{
  81. Key: key,
  82. }
  83. useInternalToken = true
  84. } else {
  85. tokenCache, err := model.GetAndValidateToken(key)
  86. if err != nil {
  87. AbortLogWithMessage(c, http.StatusUnauthorized, err.Error())
  88. return
  89. }
  90. token = *tokenCache
  91. }
  92. SetLogTokenFields(log.Data, token, useInternalToken)
  93. if len(token.Subnets) > 0 {
  94. if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
  95. AbortLogWithMessage(c, http.StatusInternalServerError, err.Error())
  96. return
  97. } else if !ok {
  98. AbortLogWithMessage(c, http.StatusForbidden,
  99. fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
  100. token.Name,
  101. token.ID,
  102. token.Subnets,
  103. c.ClientIP(),
  104. ),
  105. )
  106. return
  107. }
  108. }
  109. modelCaches := model.LoadModelCaches()
  110. var group model.GroupCache
  111. if useInternalToken {
  112. group = model.GroupCache{
  113. Status: model.GroupStatusInternal,
  114. AvailableSets: slices.Collect(maps.Keys(modelCaches.EnabledModelsBySet)),
  115. }
  116. } else {
  117. groupCache, err := model.CacheGetGroup(token.Group)
  118. if err != nil {
  119. AbortLogWithMessage(c, http.StatusInternalServerError, fmt.Sprintf("failed to get group: %v", err))
  120. return
  121. }
  122. group = *groupCache
  123. }
  124. c.Header("Group", group.ID)
  125. SetLogGroupFields(log.Data, group)
  126. if group.Status != model.GroupStatusEnabled && group.Status != model.GroupStatusInternal {
  127. AbortLogWithMessage(c, http.StatusForbidden, "group is disabled")
  128. return
  129. }
  130. token.SetAvailableSets(group.GetAvailableSets())
  131. token.SetModelsBySet(modelCaches.EnabledModelsBySet)
  132. c.Set(Group, group)
  133. c.Set(Token, token)
  134. c.Set(ModelCaches, modelCaches)
  135. c.Next()
  136. }
  137. func GetGroup(c *gin.Context) model.GroupCache {
  138. v, ok := c.MustGet(Group).(model.GroupCache)
  139. if !ok {
  140. panic(fmt.Sprintf("group cache type error: %T, %v", v, v))
  141. }
  142. return v
  143. }
  144. func GetToken(c *gin.Context) model.TokenCache {
  145. v, ok := c.MustGet(Token).(model.TokenCache)
  146. if !ok {
  147. panic(fmt.Sprintf("token cache type error: %T, %v", v, v))
  148. }
  149. return v
  150. }
  151. func GetModelCaches(c *gin.Context) *model.ModelCaches {
  152. v, ok := c.MustGet(ModelCaches).(*model.ModelCaches)
  153. if !ok {
  154. panic(fmt.Sprintf("model caches type error: %T, %v", v, v))
  155. }
  156. return v
  157. }
  158. func SetLogFieldsFromMeta(m *meta.Meta, fields logrus.Fields) {
  159. SetLogRequestIDField(fields, m.RequestID)
  160. SetLogModeField(fields, m.Mode)
  161. SetLogModelFields(fields, m.OriginModel)
  162. SetLogActualModelFields(fields, m.ActualModel)
  163. SetLogGroupFields(fields, m.Group)
  164. SetLogTokenFields(fields, m.Token, false)
  165. SetLogChannelFields(fields, m.Channel)
  166. }
  167. func SetLogModeField(fields logrus.Fields, mode mode.Mode) {
  168. fields["mode"] = mode.String()
  169. }
  170. func SetLogActualModelFields(fields logrus.Fields, actualModel string) {
  171. fields["actmodel"] = actualModel
  172. }
  173. func SetLogModelFields(fields logrus.Fields, model string) {
  174. fields["model"] = model
  175. }
  176. func SetLogChannelFields(fields logrus.Fields, channel meta.ChannelMeta) {
  177. if channel.ID > 0 {
  178. fields["chid"] = channel.ID
  179. }
  180. if channel.Name != "" {
  181. fields["chname"] = channel.Name
  182. }
  183. if channel.Type > 0 {
  184. fields["chtype"] = int(channel.Type)
  185. fields["chtype_name"] = channel.Type.String()
  186. }
  187. }
  188. func SetLogRequestIDField(fields logrus.Fields, requestID string) {
  189. fields["reqid"] = requestID
  190. }
  191. func SetLogGroupFields(fields logrus.Fields, group model.GroupCache) {
  192. if group.ID != "" {
  193. fields["gid"] = group.ID
  194. }
  195. }
  196. func SetLogTokenFields(fields logrus.Fields, token model.TokenCache, internal bool) {
  197. if token.ID > 0 {
  198. fields["kid"] = token.ID
  199. }
  200. if token.Name != "" {
  201. fields["kname"] = token.Name
  202. }
  203. if token.Key != "" {
  204. fields["key"] = maskTokenKey(token.Key)
  205. }
  206. if internal {
  207. fields["internal"] = "true"
  208. }
  209. }
  // maskTokenKey returns a log-safe rendering of key.
  //
  // Long keys are shown as their first and last four bytes around a fixed
  // "*****" marker. Keys shorter than 16 bytes become just "*****", because
  // showing eight bytes of them would expose more of the key than stays
  // hidden. For example, a 9-byte key would show 8 of its 9 bytes.
  func maskTokenKey(key string) string {
  	const (
  		mask    = "*****"
  		visible = 4 // bytes revealed at each end
  		// Reveal only when the hidden middle (len-2*visible) is at least as
  		// long as the revealed ends (2*visible).
  		minRevealLen = 4 * visible
  	)
  	if len(key) < minRevealLen {
  		return mask
  	}
  	return key[:visible] + mask + key[len(key)-visible:]
  }