| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- package model
- import (
- "errors"
- "strings"
- "time"
- )
// CustomOAuthProvider is the persisted configuration of one custom OAuth
// provider. Rows live in the table named by TableName.
type CustomOAuthProvider struct {
	// Id is the primary key.
	Id int `json:"id" gorm:"primaryKey"`

	// Name is the display name, e.g. "GitHub Enterprise".
	Name string `json:"name" gorm:"type:varchar(64);not null"`

	// Slug is the unique URL identifier, e.g. "github-enterprise".
	Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`

	// Enabled reports whether this provider is enabled.
	Enabled bool `json:"enabled" gorm:"default:false"`

	// ClientId is the OAuth client ID.
	ClientId string `json:"client_id" gorm:"type:varchar(256)"`

	// ClientSecret is the OAuth client secret. The "-" JSON tag keeps it out
	// of serialized output, so it is never returned to the frontend.
	ClientSecret string `json:"-" gorm:"type:varchar(512)"`

	// AuthorizationEndpoint is the authorization URL.
	AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"`

	// TokenEndpoint is the token exchange URL.
	TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"`

	// UserInfoEndpoint is the user info URL.
	UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"`

	// Scopes holds the OAuth scopes to request.
	Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"`

	// Field mapping configuration (supports JSONPath via gjson).

	// UserIdField is the user ID field path, e.g. "sub", "id", "data.user.id".
	UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`

	// UsernameField is the username field path.
	UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"`

	// DisplayNameField is the display name field path.
	DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`

	// EmailField is the email field path.
	EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"`

	// Advanced options.

	// WellKnown is the optional OIDC discovery endpoint.
	WellKnown string `json:"well_known" gorm:"type:varchar(512)"`

	// AuthStyle selects how client credentials are sent:
	// 0=auto, 1=params, 2=header (Basic Auth).
	AuthStyle int `json:"auth_style" gorm:"default:0"`

	// Record timestamps.
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
- func (CustomOAuthProvider) TableName() string {
- return "custom_oauth_providers"
- }
- // GetAllCustomOAuthProviders returns all custom OAuth providers
- func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
- var providers []*CustomOAuthProvider
- err := DB.Order("id asc").Find(&providers).Error
- return providers, err
- }
- // GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
- func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
- var providers []*CustomOAuthProvider
- err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
- return providers, err
- }
- // GetCustomOAuthProviderById returns a custom OAuth provider by ID
- func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
- var provider CustomOAuthProvider
- err := DB.First(&provider, id).Error
- if err != nil {
- return nil, err
- }
- return &provider, nil
- }
- // GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
- func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
- var provider CustomOAuthProvider
- err := DB.Where("slug = ?", slug).First(&provider).Error
- if err != nil {
- return nil, err
- }
- return &provider, nil
- }
- // CreateCustomOAuthProvider creates a new custom OAuth provider
- func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
- if err := validateCustomOAuthProvider(provider); err != nil {
- return err
- }
- return DB.Create(provider).Error
- }
- // UpdateCustomOAuthProvider updates an existing custom OAuth provider
- func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
- if err := validateCustomOAuthProvider(provider); err != nil {
- return err
- }
- return DB.Save(provider).Error
- }
- // DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
- func DeleteCustomOAuthProvider(id int) error {
- // First, delete all user bindings for this provider
- if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
- return err
- }
- return DB.Delete(&CustomOAuthProvider{}, id).Error
- }
- // IsSlugTaken checks if a slug is already taken by another provider
- // Returns true on DB errors (fail-closed) to prevent slug conflicts
- func IsSlugTaken(slug string, excludeId int) bool {
- var count int64
- query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
- if excludeId > 0 {
- query = query.Where("id != ?", excludeId)
- }
- res := query.Count(&count)
- if res.Error != nil {
- // Fail-closed: treat DB errors as slug being taken to prevent conflicts
- return true
- }
- return count > 0
- }
- // validateCustomOAuthProvider validates a custom OAuth provider configuration
- func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
- if provider.Name == "" {
- return errors.New("provider name is required")
- }
- if provider.Slug == "" {
- return errors.New("provider slug is required")
- }
- // Slug must be lowercase and contain only alphanumeric characters and hyphens
- slug := strings.ToLower(provider.Slug)
- for _, c := range slug {
- if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
- return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
- }
- }
- provider.Slug = slug
- if provider.ClientId == "" {
- return errors.New("client ID is required")
- }
- if provider.AuthorizationEndpoint == "" {
- return errors.New("authorization endpoint is required")
- }
- if provider.TokenEndpoint == "" {
- return errors.New("token endpoint is required")
- }
- if provider.UserInfoEndpoint == "" {
- return errors.New("user info endpoint is required")
- }
- // Set defaults for field mappings if empty
- if provider.UserIdField == "" {
- provider.UserIdField = "sub"
- }
- if provider.UsernameField == "" {
- provider.UsernameField = "preferred_username"
- }
- if provider.DisplayNameField == "" {
- provider.DisplayNameField = "name"
- }
- if provider.EmailField == "" {
- provider.EmailField = "email"
- }
- if provider.Scopes == "" {
- provider.Scopes = "openid profile email"
- }
- return nil
- }
|