|
|
@@ -0,0 +1,257 @@
|
|
|
+package controller
|
|
|
+
|
|
|
+import (
|
|
|
+ "net/http"
|
|
|
+ "strconv"
|
|
|
+
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
+ "github.com/QuantumNous/new-api/i18n"
|
|
|
+ "github.com/QuantumNous/new-api/model"
|
|
|
+ "github.com/QuantumNous/new-api/oauth"
|
|
|
+ "github.com/gin-contrib/sessions"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+)
|
|
|
+
|
|
|
+// providerParams returns map with Provider key for i18n templates
|
|
|
+func providerParams(name string) map[string]any {
|
|
|
+ return map[string]any{"Provider": name}
|
|
|
+}
|
|
|
+
|
|
|
+// GenerateOAuthCode generates a state code for OAuth CSRF protection
|
|
|
+func GenerateOAuthCode(c *gin.Context) {
|
|
|
+ session := sessions.Default(c)
|
|
|
+ state := common.GetRandomString(12)
|
|
|
+ affCode := c.Query("aff")
|
|
|
+ if affCode != "" {
|
|
|
+ session.Set("aff", affCode)
|
|
|
+ }
|
|
|
+ session.Set("oauth_state", state)
|
|
|
+ err := session.Save()
|
|
|
+ if err != nil {
|
|
|
+ common.ApiError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": true,
|
|
|
+ "message": "",
|
|
|
+ "data": state,
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// HandleOAuth handles OAuth callback for all standard OAuth providers
|
|
|
+func HandleOAuth(c *gin.Context) {
|
|
|
+ providerName := c.Param("provider")
|
|
|
+ provider := oauth.GetProvider(providerName)
|
|
|
+ if provider == nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ session := sessions.Default(c)
|
|
|
+
|
|
|
+ // 1. Validate state (CSRF protection)
|
|
|
+ state := c.Query("state")
|
|
|
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
|
|
+ c.JSON(http.StatusForbidden, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. Check if user is already logged in (bind flow)
|
|
|
+ username := session.Get("username")
|
|
|
+ if username != nil {
|
|
|
+ handleOAuthBind(c, provider)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. Check if provider is enabled
|
|
|
+ if !provider.IsEnabled() {
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 4. Handle error from provider
|
|
|
+ errorCode := c.Query("error")
|
|
|
+ if errorCode != "" {
|
|
|
+ errorDescription := c.Query("error_description")
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": errorDescription,
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 5. Exchange code for token
|
|
|
+ code := c.Query("code")
|
|
|
+ token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
|
|
+ if err != nil {
|
|
|
+ handleOAuthError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 6. Get user info
|
|
|
+ oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
|
|
+ if err != nil {
|
|
|
+ handleOAuthError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 7. Find or create user
|
|
|
+ user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
|
|
|
+ if err != nil {
|
|
|
+ switch err.(type) {
|
|
|
+ case *OAuthUserDeletedError:
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
|
|
|
+ case *OAuthRegistrationDisabledError:
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
|
|
|
+ default:
|
|
|
+ common.ApiError(c, err)
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 8. Check user status
|
|
|
+ if user.Status != common.UserStatusEnabled {
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 9. Setup login
|
|
|
+ setupLogin(user, c)
|
|
|
+}
|
|
|
+
|
|
|
+// handleOAuthBind handles binding OAuth account to existing user
|
|
|
+func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
|
|
+ if !provider.IsEnabled() {
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Exchange code for token
|
|
|
+ code := c.Query("code")
|
|
|
+ token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
|
|
+ if err != nil {
|
|
|
+ handleOAuthError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Get user info
|
|
|
+ oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
|
|
+ if err != nil {
|
|
|
+ handleOAuthError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if this OAuth account is already bound
|
|
|
+ if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Get current user from session
|
|
|
+ session := sessions.Default(c)
|
|
|
+ id := session.Get("id")
|
|
|
+ user := model.User{Id: id.(int)}
|
|
|
+ err = user.FillUserById()
|
|
|
+ if err != nil {
|
|
|
+ common.ApiError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update user with OAuth ID
|
|
|
+ provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
|
|
|
+ err = user.Update(false)
|
|
|
+ if err != nil {
|
|
|
+ common.ApiError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
|
|
+}
|
|
|
+
|
|
|
+// findOrCreateOAuthUser finds existing user or creates new user
|
|
|
+func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
|
|
|
+ user := &model.User{}
|
|
|
+
|
|
|
+ // Check if user already exists
|
|
|
+ if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
|
|
+ provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
|
|
+ err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ // Check if user has been deleted
|
|
|
+ if user.Id == 0 {
|
|
|
+ return nil, &OAuthUserDeletedError{}
|
|
|
+ }
|
|
|
+ return user, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // User doesn't exist, create new user if registration is enabled
|
|
|
+ if !common.RegisterEnabled {
|
|
|
+ return nil, &OAuthRegistrationDisabledError{}
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set up new user
|
|
|
+ user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
|
|
|
+ if oauthUser.DisplayName != "" {
|
|
|
+ user.DisplayName = oauthUser.DisplayName
|
|
|
+ } else if oauthUser.Username != "" {
|
|
|
+ user.DisplayName = oauthUser.Username
|
|
|
+ } else {
|
|
|
+ user.DisplayName = provider.GetName() + " User"
|
|
|
+ }
|
|
|
+ if oauthUser.Email != "" {
|
|
|
+ user.Email = oauthUser.Email
|
|
|
+ }
|
|
|
+ user.Role = common.RoleCommonUser
|
|
|
+ user.Status = common.UserStatusEnabled
|
|
|
+ provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
|
|
+
|
|
|
+ // Handle affiliate code
|
|
|
+ affCode := session.Get("aff")
|
|
|
+ inviterId := 0
|
|
|
+ if affCode != nil {
|
|
|
+ inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := user.Insert(inviterId); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return user, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Error types for OAuth
|
|
|
+type OAuthUserDeletedError struct{}
|
|
|
+
|
|
|
+func (e *OAuthUserDeletedError) Error() string {
|
|
|
+ return "user has been deleted"
|
|
|
+}
|
|
|
+
|
|
|
+type OAuthRegistrationDisabledError struct{}
|
|
|
+
|
|
|
+func (e *OAuthRegistrationDisabledError) Error() string {
|
|
|
+ return "registration is disabled"
|
|
|
+}
|
|
|
+
|
|
|
+// handleOAuthError handles OAuth errors and returns translated message
|
|
|
+func handleOAuthError(c *gin.Context, err error) {
|
|
|
+ switch e := err.(type) {
|
|
|
+ case *oauth.OAuthError:
|
|
|
+ if e.Params != nil {
|
|
|
+ common.ApiErrorI18n(c, e.MsgKey, e.Params)
|
|
|
+ } else {
|
|
|
+ common.ApiErrorI18n(c, e.MsgKey)
|
|
|
+ }
|
|
|
+ case *oauth.TrustLevelError:
|
|
|
+ common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
|
|
|
+ default:
|
|
|
+ common.ApiError(c, err)
|
|
|
+ }
|
|
|
+}
|