| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- package controller
- import (
- "net/http"
- "strconv"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/i18n"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/oauth"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- )
// providerParams builds the parameter map passed to i18n message templates:
// a single "Provider" entry whose value is name.
func providerParams(name string) map[string]any {
	params := make(map[string]any, 1)
	params["Provider"] = name
	return params
}
- // GenerateOAuthCode generates a state code for OAuth CSRF protection
- func GenerateOAuthCode(c *gin.Context) {
- session := sessions.Default(c)
- state := common.GetRandomString(12)
- affCode := c.Query("aff")
- if affCode != "" {
- session.Set("aff", affCode)
- }
- session.Set("oauth_state", state)
- err := session.Save()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": state,
- })
- }
- // HandleOAuth handles OAuth callback for all standard OAuth providers
- func HandleOAuth(c *gin.Context) {
- providerName := c.Param("provider")
- provider := oauth.GetProvider(providerName)
- if provider == nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
- })
- return
- }
- session := sessions.Default(c)
- // 1. Validate state (CSRF protection)
- state := c.Query("state")
- if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": i18n.T(c, i18n.MsgOAuthStateInvalid),
- })
- return
- }
- // 2. Check if user is already logged in (bind flow)
- username := session.Get("username")
- if username != nil {
- handleOAuthBind(c, provider)
- return
- }
- // 3. Check if provider is enabled
- if !provider.IsEnabled() {
- common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
- return
- }
- // 4. Handle error from provider
- errorCode := c.Query("error")
- if errorCode != "" {
- errorDescription := c.Query("error_description")
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": errorDescription,
- })
- return
- }
- // 5. Exchange code for token
- code := c.Query("code")
- token, err := provider.ExchangeToken(c.Request.Context(), code, c)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // 6. Get user info
- oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // 7. Find or create user
- user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
- if err != nil {
- switch err.(type) {
- case *OAuthUserDeletedError:
- common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
- case *OAuthRegistrationDisabledError:
- common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
- default:
- common.ApiError(c, err)
- }
- return
- }
- // 8. Check user status
- if user.Status != common.UserStatusEnabled {
- common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
- return
- }
- // 9. Setup login
- setupLogin(user, c)
- }
- // handleOAuthBind handles binding OAuth account to existing user
- func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
- if !provider.IsEnabled() {
- common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
- return
- }
- // Exchange code for token
- code := c.Query("code")
- token, err := provider.ExchangeToken(c.Request.Context(), code, c)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // Get user info
- oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
- if err != nil {
- handleOAuthError(c, err)
- return
- }
- // Check if this OAuth account is already bound
- if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
- common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
- return
- }
- // Get current user from session
- session := sessions.Default(c)
- id := session.Get("id")
- user := model.User{Id: id.(int)}
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // Update user with OAuth ID
- provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
- }
- // findOrCreateOAuthUser finds existing user or creates new user
- func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
- user := &model.User{}
- // Check if user already exists
- if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
- provider.SetProviderUserID(user, oauthUser.ProviderUserID)
- err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
- if err != nil {
- return nil, err
- }
- // Check if user has been deleted
- if user.Id == 0 {
- return nil, &OAuthUserDeletedError{}
- }
- return user, nil
- }
- // User doesn't exist, create new user if registration is enabled
- if !common.RegisterEnabled {
- return nil, &OAuthRegistrationDisabledError{}
- }
- // Set up new user
- user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
- if oauthUser.DisplayName != "" {
- user.DisplayName = oauthUser.DisplayName
- } else if oauthUser.Username != "" {
- user.DisplayName = oauthUser.Username
- } else {
- user.DisplayName = provider.GetName() + " User"
- }
- if oauthUser.Email != "" {
- user.Email = oauthUser.Email
- }
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
- provider.SetProviderUserID(user, oauthUser.ProviderUserID)
- // Handle affiliate code
- affCode := session.Get("aff")
- inviterId := 0
- if affCode != nil {
- inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
- }
- if err := user.Insert(inviterId); err != nil {
- return nil, err
- }
- return user, nil
- }
// Error types for OAuth

// OAuthUserDeletedError is the error findOrCreateOAuthUser returns when the
// provider ID is bound but the loaded user has Id 0.
type OAuthUserDeletedError struct{}

// Error implements the error interface.
func (*OAuthUserDeletedError) Error() string {
	const msg = "user has been deleted"
	return msg
}
// OAuthRegistrationDisabledError is the error findOrCreateOAuthUser returns
// when a new user would have to be created but common.RegisterEnabled is false.
type OAuthRegistrationDisabledError struct{}

// Error implements the error interface.
func (*OAuthRegistrationDisabledError) Error() string {
	const msg = "registration is disabled"
	return msg
}
- // handleOAuthError handles OAuth errors and returns translated message
- func handleOAuthError(c *gin.Context, err error) {
- switch e := err.(type) {
- case *oauth.OAuthError:
- if e.Params != nil {
- common.ApiErrorI18n(c, e.MsgKey, e.Params)
- } else {
- common.ApiErrorI18n(c, e.MsgKey)
- }
- case *oauth.TrustLevelError:
- common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
- default:
- common.ApiError(c, err)
- }
- }
|