registry.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package oauth
  2. import (
  3. "fmt"
  4. "sync"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/model"
  7. )
  8. var (
  9. providers = make(map[string]Provider)
  10. mu sync.RWMutex
  11. // customProviderSlugs tracks which providers are custom (can be unregistered)
  12. customProviderSlugs = make(map[string]bool)
  13. )
  14. // Register registers an OAuth provider with the given name
  15. func Register(name string, provider Provider) {
  16. mu.Lock()
  17. defer mu.Unlock()
  18. providers[name] = provider
  19. }
  20. // RegisterCustom registers a custom OAuth provider (can be unregistered later)
  21. func RegisterCustom(name string, provider Provider) {
  22. mu.Lock()
  23. defer mu.Unlock()
  24. providers[name] = provider
  25. customProviderSlugs[name] = true
  26. }
  27. // Unregister removes a provider from the registry
  28. func Unregister(name string) {
  29. mu.Lock()
  30. defer mu.Unlock()
  31. delete(providers, name)
  32. delete(customProviderSlugs, name)
  33. }
  34. // GetProvider returns the OAuth provider for the given name
  35. func GetProvider(name string) Provider {
  36. mu.RLock()
  37. defer mu.RUnlock()
  38. return providers[name]
  39. }
  40. // GetAllProviders returns all registered OAuth providers
  41. func GetAllProviders() map[string]Provider {
  42. mu.RLock()
  43. defer mu.RUnlock()
  44. result := make(map[string]Provider, len(providers))
  45. for k, v := range providers {
  46. result[k] = v
  47. }
  48. return result
  49. }
  50. // GetEnabledCustomProviders returns all enabled custom OAuth providers
  51. func GetEnabledCustomProviders() []*GenericOAuthProvider {
  52. mu.RLock()
  53. defer mu.RUnlock()
  54. var result []*GenericOAuthProvider
  55. for name, provider := range providers {
  56. if customProviderSlugs[name] {
  57. if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() {
  58. result = append(result, gp)
  59. }
  60. }
  61. }
  62. return result
  63. }
  64. // IsProviderRegistered checks if a provider is registered
  65. func IsProviderRegistered(name string) bool {
  66. mu.RLock()
  67. defer mu.RUnlock()
  68. _, ok := providers[name]
  69. return ok
  70. }
  71. // IsCustomProvider checks if a provider is a custom provider
  72. func IsCustomProvider(name string) bool {
  73. mu.RLock()
  74. defer mu.RUnlock()
  75. return customProviderSlugs[name]
  76. }
  77. // LoadCustomProviders loads all custom OAuth providers from the database
  78. func LoadCustomProviders() error {
  79. // First, unregister all existing custom providers
  80. mu.Lock()
  81. for name := range customProviderSlugs {
  82. delete(providers, name)
  83. }
  84. customProviderSlugs = make(map[string]bool)
  85. mu.Unlock()
  86. // Load all custom providers from database
  87. customProviders, err := model.GetAllCustomOAuthProviders()
  88. if err != nil {
  89. common.SysError("Failed to load custom OAuth providers: " + err.Error())
  90. return err
  91. }
  92. // Register each custom provider
  93. for _, config := range customProviders {
  94. provider := NewGenericOAuthProvider(config)
  95. RegisterCustom(config.Slug, provider)
  96. common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
  97. }
  98. common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
  99. return nil
  100. }
  101. // ReloadCustomProviders reloads all custom OAuth providers from the database
  102. func ReloadCustomProviders() error {
  103. return LoadCustomProviders()
  104. }
  105. // RegisterOrUpdateCustomProvider registers or updates a single custom provider
  106. func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
  107. provider := NewGenericOAuthProvider(config)
  108. mu.Lock()
  109. defer mu.Unlock()
  110. providers[config.Slug] = provider
  111. customProviderSlugs[config.Slug] = true
  112. }
  113. // UnregisterCustomProvider unregisters a custom provider by slug
  114. func UnregisterCustomProvider(slug string) {
  115. Unregister(slug)
  116. }