| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- package controller
- import (
- "net/http"
- "strconv"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/oauth"
- "github.com/gin-gonic/gin"
- )
// CustomOAuthProviderResponse is the JSON shape returned to API clients for a
// custom OAuth provider. It deliberately has no client secret field, so the
// secret can never be serialized through this type.
type CustomOAuthProviderResponse struct {
	// Identity and state.
	Id      int    `json:"id"`
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Enabled bool   `json:"enabled"`

	// Client and endpoint configuration.
	ClientId              string `json:"client_id"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	// Names of the user-info fields to read each attribute from
	// (presumably keys in the provider's user-info payload; confirm in the oauth package).
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Remaining provider settings, passed through unchanged.
	WellKnown string `json:"well_known"`
	AuthStyle int    `json:"auth_style"`
}
- func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
- return &CustomOAuthProviderResponse{
- Id: p.Id,
- Name: p.Name,
- Slug: p.Slug,
- Enabled: p.Enabled,
- ClientId: p.ClientId,
- AuthorizationEndpoint: p.AuthorizationEndpoint,
- TokenEndpoint: p.TokenEndpoint,
- UserInfoEndpoint: p.UserInfoEndpoint,
- Scopes: p.Scopes,
- UserIdField: p.UserIdField,
- UsernameField: p.UsernameField,
- DisplayNameField: p.DisplayNameField,
- EmailField: p.EmailField,
- WellKnown: p.WellKnown,
- AuthStyle: p.AuthStyle,
- }
- }
- // GetCustomOAuthProviders returns all custom OAuth providers
- func GetCustomOAuthProviders(c *gin.Context) {
- providers, err := model.GetAllCustomOAuthProviders()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- response := make([]*CustomOAuthProviderResponse, len(providers))
- for i, p := range providers {
- response[i] = toCustomOAuthProviderResponse(p)
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": response,
- })
- }
- // GetCustomOAuthProvider returns a single custom OAuth provider by ID
- func GetCustomOAuthProvider(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiErrorMsg(c, "无效的 ID")
- return
- }
- provider, err := model.GetCustomOAuthProviderById(id)
- if err != nil {
- common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": toCustomOAuthProviderResponse(provider),
- })
- }
// CreateCustomOAuthProviderRequest is the JSON body accepted when creating a
// custom OAuth provider. Fields tagged binding:"required" are validated by the
// JSON binding step; the remaining fields are optional.
type CreateCustomOAuthProviderRequest struct {
	// Required identity.
	Name string `json:"name" binding:"required"`
	Slug string `json:"slug" binding:"required"`

	// Optional state; absent means false.
	Enabled bool `json:"enabled"`

	// Required client credentials and endpoints.
	ClientId              string `json:"client_id" binding:"required"`
	ClientSecret          string `json:"client_secret" binding:"required"`
	AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
	TokenEndpoint         string `json:"token_endpoint" binding:"required"`
	UserInfoEndpoint      string `json:"user_info_endpoint" binding:"required"`

	// Optional settings, stored as given.
	Scopes           string `json:"scopes"`
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`
	WellKnown        string `json:"well_known"`
	AuthStyle        int    `json:"auth_style"`
}
- // CreateCustomOAuthProvider creates a new custom OAuth provider
- func CreateCustomOAuthProvider(c *gin.Context) {
- var req CreateCustomOAuthProviderRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
- return
- }
- // Check if slug is already taken
- if model.IsSlugTaken(req.Slug, 0) {
- common.ApiErrorMsg(c, "该 Slug 已被使用")
- return
- }
- // Check if slug conflicts with built-in providers
- if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
- common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
- return
- }
- provider := &model.CustomOAuthProvider{
- Name: req.Name,
- Slug: req.Slug,
- Enabled: req.Enabled,
- ClientId: req.ClientId,
- ClientSecret: req.ClientSecret,
- AuthorizationEndpoint: req.AuthorizationEndpoint,
- TokenEndpoint: req.TokenEndpoint,
- UserInfoEndpoint: req.UserInfoEndpoint,
- Scopes: req.Scopes,
- UserIdField: req.UserIdField,
- UsernameField: req.UsernameField,
- DisplayNameField: req.DisplayNameField,
- EmailField: req.EmailField,
- WellKnown: req.WellKnown,
- AuthStyle: req.AuthStyle,
- }
- if err := model.CreateCustomOAuthProvider(provider); err != nil {
- common.ApiError(c, err)
- return
- }
- // Register the provider in the OAuth registry
- oauth.RegisterOrUpdateCustomProvider(provider)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "创建成功",
- "data": toCustomOAuthProviderResponse(provider),
- })
- }
// UpdateCustomOAuthProviderRequest is the JSON body accepted when updating a
// custom OAuth provider. No field is required: the update handler keeps the
// stored value for every string left empty and every pointer left nil.
type UpdateCustomOAuthProviderRequest struct {
	// String fields: an empty value means "keep existing".
	Name                  string `json:"name"`
	Slug                  string `json:"slug"`
	ClientId              string `json:"client_id"`
	ClientSecret          string `json:"client_secret"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`
	UserIdField           string `json:"user_id_field"`
	UsernameField         string `json:"username_field"`
	DisplayNameField      string `json:"display_name_field"`
	EmailField            string `json:"email_field"`

	// Pointer fields: nil means "keep existing", which lets a client
	// explicitly send a zero value (false, "", 0).
	Enabled   *bool   `json:"enabled"`
	WellKnown *string `json:"well_known"`
	AuthStyle *int    `json:"auth_style"`
}
- // UpdateCustomOAuthProvider updates an existing custom OAuth provider
- func UpdateCustomOAuthProvider(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiErrorMsg(c, "无效的 ID")
- return
- }
- var req UpdateCustomOAuthProviderRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
- return
- }
- // Get existing provider
- provider, err := model.GetCustomOAuthProviderById(id)
- if err != nil {
- common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
- return
- }
- oldSlug := provider.Slug
- // Check if new slug is taken by another provider
- if req.Slug != "" && req.Slug != provider.Slug {
- if model.IsSlugTaken(req.Slug, id) {
- common.ApiErrorMsg(c, "该 Slug 已被使用")
- return
- }
- // Check if slug conflicts with built-in providers
- if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
- common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
- return
- }
- }
- // Update fields
- if req.Name != "" {
- provider.Name = req.Name
- }
- if req.Slug != "" {
- provider.Slug = req.Slug
- }
- if req.Enabled != nil {
- provider.Enabled = *req.Enabled
- }
- if req.ClientId != "" {
- provider.ClientId = req.ClientId
- }
- if req.ClientSecret != "" {
- provider.ClientSecret = req.ClientSecret
- }
- if req.AuthorizationEndpoint != "" {
- provider.AuthorizationEndpoint = req.AuthorizationEndpoint
- }
- if req.TokenEndpoint != "" {
- provider.TokenEndpoint = req.TokenEndpoint
- }
- if req.UserInfoEndpoint != "" {
- provider.UserInfoEndpoint = req.UserInfoEndpoint
- }
- if req.Scopes != "" {
- provider.Scopes = req.Scopes
- }
- if req.UserIdField != "" {
- provider.UserIdField = req.UserIdField
- }
- if req.UsernameField != "" {
- provider.UsernameField = req.UsernameField
- }
- if req.DisplayNameField != "" {
- provider.DisplayNameField = req.DisplayNameField
- }
- if req.EmailField != "" {
- provider.EmailField = req.EmailField
- }
- if req.WellKnown != nil {
- provider.WellKnown = *req.WellKnown
- }
- if req.AuthStyle != nil {
- provider.AuthStyle = *req.AuthStyle
- }
- if err := model.UpdateCustomOAuthProvider(provider); err != nil {
- common.ApiError(c, err)
- return
- }
- // Update the provider in the OAuth registry
- if oldSlug != provider.Slug {
- oauth.UnregisterCustomProvider(oldSlug)
- }
- oauth.RegisterOrUpdateCustomProvider(provider)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "更新成功",
- "data": toCustomOAuthProviderResponse(provider),
- })
- }
- // DeleteCustomOAuthProvider deletes a custom OAuth provider
- func DeleteCustomOAuthProvider(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiErrorMsg(c, "无效的 ID")
- return
- }
- // Get existing provider to get slug
- provider, err := model.GetCustomOAuthProviderById(id)
- if err != nil {
- common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
- return
- }
- // Check if there are any user bindings
- count, err := model.GetBindingCountByProviderId(id)
- if err != nil {
- common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
- common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
- return
- }
- if count > 0 {
- common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
- return
- }
- if err := model.DeleteCustomOAuthProvider(id); err != nil {
- common.ApiError(c, err)
- return
- }
- // Unregister the provider from the OAuth registry
- oauth.UnregisterCustomProvider(provider.Slug)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "删除成功",
- })
- }
- // GetUserOAuthBindings returns all OAuth bindings for the current user
- func GetUserOAuthBindings(c *gin.Context) {
- userId := c.GetInt("id")
- if userId == 0 {
- common.ApiErrorMsg(c, "未登录")
- return
- }
- bindings, err := model.GetUserOAuthBindingsByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // Build response with provider info
- type BindingResponse struct {
- ProviderId int `json:"provider_id"`
- ProviderName string `json:"provider_name"`
- ProviderSlug string `json:"provider_slug"`
- ProviderUserId string `json:"provider_user_id"`
- }
- response := make([]BindingResponse, 0)
- for _, binding := range bindings {
- provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
- if err != nil {
- continue // Skip if provider not found
- }
- response = append(response, BindingResponse{
- ProviderId: binding.ProviderId,
- ProviderName: provider.Name,
- ProviderSlug: provider.Slug,
- ProviderUserId: binding.ProviderUserId,
- })
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": response,
- })
- }
- // UnbindCustomOAuth unbinds a custom OAuth provider from the current user
- func UnbindCustomOAuth(c *gin.Context) {
- userId := c.GetInt("id")
- if userId == 0 {
- common.ApiErrorMsg(c, "未登录")
- return
- }
- providerIdStr := c.Param("provider_id")
- providerId, err := strconv.Atoi(providerIdStr)
- if err != nil {
- common.ApiErrorMsg(c, "无效的提供商 ID")
- return
- }
- if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "解绑成功",
- })
- }
|