user_oauth_binding.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package model
  2. import (
  3. "errors"
  4. "time"
  5. "gorm.io/gorm"
  6. )
  7. // UserOAuthBinding stores the binding relationship between users and custom OAuth providers
  8. type UserOAuthBinding struct {
  9. Id int `json:"id" gorm:"primaryKey"`
  10. UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
  11. ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
  12. ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
  13. CreatedAt time.Time `json:"created_at"`
  14. }
  15. func (UserOAuthBinding) TableName() string {
  16. return "user_oauth_bindings"
  17. }
  18. // GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
  19. func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
  20. var bindings []*UserOAuthBinding
  21. err := DB.Where("user_id = ?", userId).Find(&bindings).Error
  22. return bindings, err
  23. }
  24. // GetUserOAuthBinding returns a specific binding for a user and provider
  25. func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
  26. var binding UserOAuthBinding
  27. err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
  28. if err != nil {
  29. return nil, err
  30. }
  31. return &binding, nil
  32. }
  33. // GetUserByOAuthBinding finds a user by provider ID and provider user ID
  34. func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
  35. var binding UserOAuthBinding
  36. err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
  37. if err != nil {
  38. return nil, err
  39. }
  40. var user User
  41. err = DB.First(&user, binding.UserId).Error
  42. if err != nil {
  43. return nil, err
  44. }
  45. return &user, nil
  46. }
  47. // IsProviderUserIdTaken checks if a provider user ID is already bound to any user
  48. func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
  49. var count int64
  50. DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
  51. return count > 0
  52. }
  53. // CreateUserOAuthBinding creates a new OAuth binding
  54. func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
  55. if binding.UserId == 0 {
  56. return errors.New("user ID is required")
  57. }
  58. if binding.ProviderId == 0 {
  59. return errors.New("provider ID is required")
  60. }
  61. if binding.ProviderUserId == "" {
  62. return errors.New("provider user ID is required")
  63. }
  64. // Check if this provider user ID is already taken
  65. if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
  66. return errors.New("this OAuth account is already bound to another user")
  67. }
  68. binding.CreatedAt = time.Now()
  69. return DB.Create(binding).Error
  70. }
  71. // CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
  72. func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
  73. if binding.UserId == 0 {
  74. return errors.New("user ID is required")
  75. }
  76. if binding.ProviderId == 0 {
  77. return errors.New("provider ID is required")
  78. }
  79. if binding.ProviderUserId == "" {
  80. return errors.New("provider user ID is required")
  81. }
  82. // Check if this provider user ID is already taken (use tx to check within the same transaction)
  83. var count int64
  84. tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
  85. if count > 0 {
  86. return errors.New("this OAuth account is already bound to another user")
  87. }
  88. binding.CreatedAt = time.Now()
  89. return tx.Create(binding).Error
  90. }
  91. // UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
  92. func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
  93. // Check if the new provider user ID is already taken by another user
  94. var existingBinding UserOAuthBinding
  95. err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
  96. if err == nil && existingBinding.UserId != userId {
  97. return errors.New("this OAuth account is already bound to another user")
  98. }
  99. // Check if user already has a binding for this provider
  100. var binding UserOAuthBinding
  101. err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
  102. if err != nil {
  103. // No existing binding, create new one
  104. return CreateUserOAuthBinding(&UserOAuthBinding{
  105. UserId: userId,
  106. ProviderId: providerId,
  107. ProviderUserId: newProviderUserId,
  108. })
  109. }
  110. // Update existing binding
  111. return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
  112. }
  113. // DeleteUserOAuthBinding deletes an OAuth binding
  114. func DeleteUserOAuthBinding(userId, providerId int) error {
  115. return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
  116. }
  117. // DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
  118. func DeleteUserOAuthBindingsByUserId(userId int) error {
  119. return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
  120. }
  121. // GetBindingCountByProviderId returns the number of bindings for a provider
  122. func GetBindingCountByProviderId(providerId int) (int64, error) {
  123. var count int64
  124. err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
  125. return count, err
  126. }