user_oauth_binding.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package model
  2. import (
  3. "errors"
  4. "time"
  5. )
  // UserOAuthBinding stores the binding relationship between a local user and
// an account on a custom OAuth provider.
//
// Two uniqueness rules are declared as composite unique indexes, so the
// database itself rejects duplicates even when concurrent requests both pass
// the application-level check-then-insert logic:
//   - idx_user_oauth_bindings_user_provider (user_id, provider_id):
//     a user has at most one binding per provider.
//   - idx_user_oauth_bindings_provider_user (provider_id, provider_user_id):
//     one OAuth account can be bound to only one user.
//
// The original single-column indexes on user_id and provider_id are kept so
// the previously declared index set is unchanged.
//
// NOTE(review): if the table is migrated from these tags (e.g. GORM
// AutoMigrate), creating the unique indexes fails while duplicate rows
// already exist — deduplicate existing data before deploying.
type UserOAuthBinding struct {
	Id int `json:"id" gorm:"primaryKey"`
	// User ID
	UserId int `json:"user_id" gorm:"index;uniqueIndex:idx_user_oauth_bindings_user_provider;not null"`
	// Custom OAuth provider ID
	ProviderId int `json:"provider_id" gorm:"index;uniqueIndex:idx_user_oauth_bindings_user_provider;uniqueIndex:idx_user_oauth_bindings_provider_user;not null"`
	// User ID from OAuth provider
	ProviderUserId string    `json:"provider_user_id" gorm:"type:varchar(256);uniqueIndex:idx_user_oauth_bindings_provider_user;not null"`
	CreatedAt      time.Time `json:"created_at"`
}
  16. func (UserOAuthBinding) TableName() string {
  17. return "user_oauth_bindings"
  18. }
  19. // GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
  20. func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
  21. var bindings []*UserOAuthBinding
  22. err := DB.Where("user_id = ?", userId).Find(&bindings).Error
  23. return bindings, err
  24. }
  25. // GetUserOAuthBinding returns a specific binding for a user and provider
  26. func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
  27. var binding UserOAuthBinding
  28. err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
  29. if err != nil {
  30. return nil, err
  31. }
  32. return &binding, nil
  33. }
  34. // GetUserByOAuthBinding finds a user by provider ID and provider user ID
  35. func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
  36. var binding UserOAuthBinding
  37. err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
  38. if err != nil {
  39. return nil, err
  40. }
  41. var user User
  42. err = DB.First(&user, binding.UserId).Error
  43. if err != nil {
  44. return nil, err
  45. }
  46. return &user, nil
  47. }
  48. // IsProviderUserIdTaken checks if a provider user ID is already bound to any user
  49. func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
  50. var count int64
  51. DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
  52. return count > 0
  53. }
  54. // CreateUserOAuthBinding creates a new OAuth binding
  55. func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
  56. if binding.UserId == 0 {
  57. return errors.New("user ID is required")
  58. }
  59. if binding.ProviderId == 0 {
  60. return errors.New("provider ID is required")
  61. }
  62. if binding.ProviderUserId == "" {
  63. return errors.New("provider user ID is required")
  64. }
  65. // Check if this provider user ID is already taken
  66. if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
  67. return errors.New("this OAuth account is already bound to another user")
  68. }
  69. binding.CreatedAt = time.Now()
  70. return DB.Create(binding).Error
  71. }
  72. // UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
  73. func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
  74. // Check if the new provider user ID is already taken by another user
  75. var existingBinding UserOAuthBinding
  76. err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
  77. if err == nil && existingBinding.UserId != userId {
  78. return errors.New("this OAuth account is already bound to another user")
  79. }
  80. // Check if user already has a binding for this provider
  81. var binding UserOAuthBinding
  82. err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
  83. if err != nil {
  84. // No existing binding, create new one
  85. return CreateUserOAuthBinding(&UserOAuthBinding{
  86. UserId: userId,
  87. ProviderId: providerId,
  88. ProviderUserId: newProviderUserId,
  89. })
  90. }
  91. // Update existing binding
  92. return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
  93. }
  94. // DeleteUserOAuthBinding deletes an OAuth binding
  95. func DeleteUserOAuthBinding(userId, providerId int) error {
  96. return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
  97. }
  98. // DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
  99. func DeleteUserOAuthBindingsByUserId(userId int) error {
  100. return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
  101. }
  102. // GetBindingCountByProviderId returns the number of bindings for a provider
  103. func GetBindingCountByProviderId(providerId int) (int64, error) {
  104. var count int64
  105. err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
  106. return count, err
  107. }