| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- package oauth
import (
	"fmt"
	"sort"
	"sync"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/model"
)
- var (
- providers = make(map[string]Provider)
- mu sync.RWMutex
- // customProviderSlugs tracks which providers are custom (can be unregistered)
- customProviderSlugs = make(map[string]bool)
- )
- // Register registers an OAuth provider with the given name
- func Register(name string, provider Provider) {
- mu.Lock()
- defer mu.Unlock()
- providers[name] = provider
- }
- // RegisterCustom registers a custom OAuth provider (can be unregistered later)
- func RegisterCustom(name string, provider Provider) {
- mu.Lock()
- defer mu.Unlock()
- providers[name] = provider
- customProviderSlugs[name] = true
- }
- // Unregister removes a provider from the registry
- func Unregister(name string) {
- mu.Lock()
- defer mu.Unlock()
- delete(providers, name)
- delete(customProviderSlugs, name)
- }
- // GetProvider returns the OAuth provider for the given name
- func GetProvider(name string) Provider {
- mu.RLock()
- defer mu.RUnlock()
- return providers[name]
- }
- // GetAllProviders returns all registered OAuth providers
- func GetAllProviders() map[string]Provider {
- mu.RLock()
- defer mu.RUnlock()
- result := make(map[string]Provider, len(providers))
- for k, v := range providers {
- result[k] = v
- }
- return result
- }
- // GetEnabledCustomProviders returns all enabled custom OAuth providers
- func GetEnabledCustomProviders() []*GenericOAuthProvider {
- mu.RLock()
- defer mu.RUnlock()
- var result []*GenericOAuthProvider
- for name, provider := range providers {
- if customProviderSlugs[name] {
- if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() {
- result = append(result, gp)
- }
- }
- }
- return result
- }
- // IsProviderRegistered checks if a provider is registered
- func IsProviderRegistered(name string) bool {
- mu.RLock()
- defer mu.RUnlock()
- _, ok := providers[name]
- return ok
- }
- // IsCustomProvider checks if a provider is a custom provider
- func IsCustomProvider(name string) bool {
- mu.RLock()
- defer mu.RUnlock()
- return customProviderSlugs[name]
- }
- // LoadCustomProviders loads all custom OAuth providers from the database
- func LoadCustomProviders() error {
- // First, unregister all existing custom providers
- mu.Lock()
- for name := range customProviderSlugs {
- delete(providers, name)
- }
- customProviderSlugs = make(map[string]bool)
- mu.Unlock()
- // Load all custom providers from database
- customProviders, err := model.GetAllCustomOAuthProviders()
- if err != nil {
- common.SysError("Failed to load custom OAuth providers: " + err.Error())
- return err
- }
- // Register each custom provider
- for _, config := range customProviders {
- provider := NewGenericOAuthProvider(config)
- RegisterCustom(config.Slug, provider)
- common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
- }
- common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
- return nil
- }
- // ReloadCustomProviders reloads all custom OAuth providers from the database
- func ReloadCustomProviders() error {
- return LoadCustomProviders()
- }
- // RegisterOrUpdateCustomProvider registers or updates a single custom provider
- func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
- provider := NewGenericOAuthProvider(config)
- mu.Lock()
- defer mu.Unlock()
- providers[config.Slug] = provider
- customProviderSlugs[config.Slug] = true
- }
- // UnregisterCustomProvider unregisters a custom provider by slug
- func UnregisterCustomProvider(slug string) {
- Unregister(slug)
- }
|