// custom_oauth_provider.go
  1. package model
import (
	"errors"
	"net/url"
	"strings"
	"time"
)
  7. // CustomOAuthProvider stores configuration for custom OAuth providers
  8. type CustomOAuthProvider struct {
  9. Id int `json:"id" gorm:"primaryKey"`
  10. Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
  11. Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
  12. Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
  13. ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
  14. ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
  15. AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
  16. TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
  17. UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
  18. Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
  19. // Field mapping configuration (supports JSONPath via gjson)
  20. UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
  21. UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
  22. DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
  23. EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
  24. // Advanced options
  25. WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
  26. AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
  27. CreatedAt time.Time `json:"created_at"`
  28. UpdatedAt time.Time `json:"updated_at"`
  29. }
  30. func (CustomOAuthProvider) TableName() string {
  31. return "custom_oauth_providers"
  32. }
  33. // GetAllCustomOAuthProviders returns all custom OAuth providers
  34. func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
  35. var providers []*CustomOAuthProvider
  36. err := DB.Order("id asc").Find(&providers).Error
  37. return providers, err
  38. }
  39. // GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
  40. func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
  41. var providers []*CustomOAuthProvider
  42. err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
  43. return providers, err
  44. }
  45. // GetCustomOAuthProviderById returns a custom OAuth provider by ID
  46. func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
  47. var provider CustomOAuthProvider
  48. err := DB.First(&provider, id).Error
  49. if err != nil {
  50. return nil, err
  51. }
  52. return &provider, nil
  53. }
  54. // GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
  55. func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
  56. var provider CustomOAuthProvider
  57. err := DB.Where("slug = ?", slug).First(&provider).Error
  58. if err != nil {
  59. return nil, err
  60. }
  61. return &provider, nil
  62. }
  63. // CreateCustomOAuthProvider creates a new custom OAuth provider
  64. func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  65. if err := validateCustomOAuthProvider(provider); err != nil {
  66. return err
  67. }
  68. return DB.Create(provider).Error
  69. }
  70. // UpdateCustomOAuthProvider updates an existing custom OAuth provider
  71. func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  72. if err := validateCustomOAuthProvider(provider); err != nil {
  73. return err
  74. }
  75. return DB.Save(provider).Error
  76. }
  77. // DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
  78. func DeleteCustomOAuthProvider(id int) error {
  79. // First, delete all user bindings for this provider
  80. if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
  81. return err
  82. }
  83. return DB.Delete(&CustomOAuthProvider{}, id).Error
  84. }
  85. // IsSlugTaken checks if a slug is already taken by another provider
  86. // Returns true on DB errors (fail-closed) to prevent slug conflicts
  87. func IsSlugTaken(slug string, excludeId int) bool {
  88. var count int64
  89. query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
  90. if excludeId > 0 {
  91. query = query.Where("id != ?", excludeId)
  92. }
  93. res := query.Count(&count)
  94. if res.Error != nil {
  95. // Fail-closed: treat DB errors as slug being taken to prevent conflicts
  96. return true
  97. }
  98. return count > 0
  99. }
  100. // validateCustomOAuthProvider validates a custom OAuth provider configuration
  101. func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  102. if provider.Name == "" {
  103. return errors.New("provider name is required")
  104. }
  105. if provider.Slug == "" {
  106. return errors.New("provider slug is required")
  107. }
  108. // Slug must be lowercase and contain only alphanumeric characters and hyphens
  109. slug := strings.ToLower(provider.Slug)
  110. for _, c := range slug {
  111. if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
  112. return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
  113. }
  114. }
  115. provider.Slug = slug
  116. if provider.ClientId == "" {
  117. return errors.New("client ID is required")
  118. }
  119. if provider.AuthorizationEndpoint == "" {
  120. return errors.New("authorization endpoint is required")
  121. }
  122. if provider.TokenEndpoint == "" {
  123. return errors.New("token endpoint is required")
  124. }
  125. if provider.UserInfoEndpoint == "" {
  126. return errors.New("user info endpoint is required")
  127. }
  128. // Set defaults for field mappings if empty
  129. if provider.UserIdField == "" {
  130. provider.UserIdField = "sub"
  131. }
  132. if provider.UsernameField == "" {
  133. provider.UsernameField = "preferred_username"
  134. }
  135. if provider.DisplayNameField == "" {
  136. provider.DisplayNameField = "name"
  137. }
  138. if provider.EmailField == "" {
  139. provider.EmailField = "email"
  140. }
  141. if provider.Scopes == "" {
  142. provider.Scopes = "openid profile email"
  143. }
  144. return nil
  145. }