mcp.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package middleware
import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/common/config"
	"github.com/labring/aiproxy/core/common/network"
	"github.com/labring/aiproxy/core/model"
)
  12. func MCPAuth(c *gin.Context) {
  13. log := common.GetLogger(c)
  14. key := c.Request.Header.Get("Authorization")
  15. if key == "" {
  16. key, _ = c.GetQuery("key")
  17. }
  18. key = strings.TrimPrefix(
  19. strings.TrimPrefix(key, "Bearer "),
  20. "sk-",
  21. )
  22. var (
  23. token model.TokenCache
  24. useInternalToken bool
  25. )
  26. if config.AdminKey != "" && config.AdminKey == key ||
  27. config.InternalToken != "" && config.InternalToken == key {
  28. token = model.TokenCache{
  29. Key: key,
  30. }
  31. useInternalToken = true
  32. } else {
  33. tokenCache, err := model.GetAndValidateToken(key)
  34. if err != nil {
  35. AbortLogWithMessage(c, http.StatusUnauthorized, err.Error())
  36. return
  37. }
  38. token = *tokenCache
  39. }
  40. SetLogTokenFields(log.Data, token, useInternalToken)
  41. if len(token.Subnets) > 0 {
  42. if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
  43. AbortLogWithMessage(c, http.StatusInternalServerError, err.Error())
  44. return
  45. } else if !ok {
  46. AbortLogWithMessage(c, http.StatusForbidden,
  47. fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
  48. token.Name,
  49. token.ID,
  50. token.Subnets,
  51. c.ClientIP(),
  52. ),
  53. )
  54. return
  55. }
  56. }
  57. var group model.GroupCache
  58. if useInternalToken {
  59. group = model.GroupCache{
  60. Status: model.GroupStatusInternal,
  61. }
  62. } else {
  63. groupCache, err := model.CacheGetGroup(token.Group)
  64. if err != nil {
  65. AbortLogWithMessage(c, http.StatusInternalServerError, fmt.Sprintf("failed to get group: %v", err))
  66. return
  67. }
  68. group = *groupCache
  69. }
  70. SetLogGroupFields(log.Data, group)
  71. if group.Status != model.GroupStatusEnabled && group.Status != model.GroupStatusInternal {
  72. AbortLogWithMessage(c, http.StatusForbidden, "group is disabled")
  73. return
  74. }
  75. c.Set(Group, group)
  76. c.Set(Token, token)
  77. c.Next()
  78. }