custom_oauth.go 12 KB


  1. package controller
  2. import (
  3. "net/http"
  4. "strconv"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/model"
  7. "github.com/QuantumNous/new-api/oauth"
  8. "github.com/gin-gonic/gin"
  9. )
  // CustomOAuthProviderResponse is the API view of a stored custom OAuth
// provider. It carries every provider setting except ClientSecret, so the
// secret is never echoed back to clients.
//
// Field order is significant: it fixes the key order of the encoded JSON.
type CustomOAuthProviderResponse struct {
	// Identity and status.
	Id      int    `json:"id"`
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Enabled bool   `json:"enabled"`

	// Public half of the client credentials (the secret is intentionally absent).
	ClientId string `json:"client_id"`

	// Provider endpoints and requested scopes.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	// Names of fields read from the provider's user-info payload
	// (how they are interpreted is up to the oauth package — not shown here).
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Additional client options, passed through as stored.
	WellKnown string `json:"well_known"`
	AuthStyle int    `json:"auth_style"`
}
  29. func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
  30. return &CustomOAuthProviderResponse{
  31. Id: p.Id,
  32. Name: p.Name,
  33. Slug: p.Slug,
  34. Enabled: p.Enabled,
  35. ClientId: p.ClientId,
  36. AuthorizationEndpoint: p.AuthorizationEndpoint,
  37. TokenEndpoint: p.TokenEndpoint,
  38. UserInfoEndpoint: p.UserInfoEndpoint,
  39. Scopes: p.Scopes,
  40. UserIdField: p.UserIdField,
  41. UsernameField: p.UsernameField,
  42. DisplayNameField: p.DisplayNameField,
  43. EmailField: p.EmailField,
  44. WellKnown: p.WellKnown,
  45. AuthStyle: p.AuthStyle,
  46. }
  47. }
  48. // GetCustomOAuthProviders returns all custom OAuth providers
  49. func GetCustomOAuthProviders(c *gin.Context) {
  50. providers, err := model.GetAllCustomOAuthProviders()
  51. if err != nil {
  52. common.ApiError(c, err)
  53. return
  54. }
  55. response := make([]*CustomOAuthProviderResponse, len(providers))
  56. for i, p := range providers {
  57. response[i] = toCustomOAuthProviderResponse(p)
  58. }
  59. c.JSON(http.StatusOK, gin.H{
  60. "success": true,
  61. "message": "",
  62. "data": response,
  63. })
  64. }
  65. // GetCustomOAuthProvider returns a single custom OAuth provider by ID
  66. func GetCustomOAuthProvider(c *gin.Context) {
  67. idStr := c.Param("id")
  68. id, err := strconv.Atoi(idStr)
  69. if err != nil {
  70. common.ApiErrorMsg(c, "无效的 ID")
  71. return
  72. }
  73. provider, err := model.GetCustomOAuthProviderById(id)
  74. if err != nil {
  75. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  76. return
  77. }
  78. c.JSON(http.StatusOK, gin.H{
  79. "success": true,
  80. "message": "",
  81. "data": toCustomOAuthProviderResponse(provider),
  82. })
  83. }
  // CreateCustomOAuthProviderRequest is the JSON body accepted when creating a
// custom OAuth provider. Fields tagged binding:"required" are checked by gin's
// validator during ShouldBindJSON; all other fields fall back to Go zero values
// when omitted (e.g. Enabled defaults to false).
type CreateCustomOAuthProviderRequest struct {
	// Display name, identifier and initial status.
	Name    string `json:"name" binding:"required"`
	Slug    string `json:"slug" binding:"required"`
	Enabled bool   `json:"enabled"`

	// OAuth client credentials; both are mandatory.
	ClientId     string `json:"client_id" binding:"required"`
	ClientSecret string `json:"client_secret" binding:"required"`

	// Provider endpoints; all three are mandatory.
	AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
	TokenEndpoint         string `json:"token_endpoint" binding:"required"`
	UserInfoEndpoint      string `json:"user_info_endpoint" binding:"required"`

	// Optional scopes, user-info field names and client options.
	Scopes           string `json:"scopes"`
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`
	WellKnown        string `json:"well_known"`
	AuthStyle        int    `json:"auth_style"`
}
  102. // CreateCustomOAuthProvider creates a new custom OAuth provider
  103. func CreateCustomOAuthProvider(c *gin.Context) {
  104. var req CreateCustomOAuthProviderRequest
  105. if err := c.ShouldBindJSON(&req); err != nil {
  106. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  107. return
  108. }
  109. // Check if slug is already taken
  110. if model.IsSlugTaken(req.Slug, 0) {
  111. common.ApiErrorMsg(c, "该 Slug 已被使用")
  112. return
  113. }
  114. // Check if slug conflicts with built-in providers
  115. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  116. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  117. return
  118. }
  119. provider := &model.CustomOAuthProvider{
  120. Name: req.Name,
  121. Slug: req.Slug,
  122. Enabled: req.Enabled,
  123. ClientId: req.ClientId,
  124. ClientSecret: req.ClientSecret,
  125. AuthorizationEndpoint: req.AuthorizationEndpoint,
  126. TokenEndpoint: req.TokenEndpoint,
  127. UserInfoEndpoint: req.UserInfoEndpoint,
  128. Scopes: req.Scopes,
  129. UserIdField: req.UserIdField,
  130. UsernameField: req.UsernameField,
  131. DisplayNameField: req.DisplayNameField,
  132. EmailField: req.EmailField,
  133. WellKnown: req.WellKnown,
  134. AuthStyle: req.AuthStyle,
  135. }
  136. if err := model.CreateCustomOAuthProvider(provider); err != nil {
  137. common.ApiError(c, err)
  138. return
  139. }
  140. // Register the provider in the OAuth registry
  141. oauth.RegisterOrUpdateCustomProvider(provider)
  142. c.JSON(http.StatusOK, gin.H{
  143. "success": true,
  144. "message": "创建成功",
  145. "data": toCustomOAuthProviderResponse(provider),
  146. })
  147. }
  // UpdateCustomOAuthProviderRequest is the JSON body for a partial update of a
// custom OAuth provider. Every field is optional:
//   - plain string fields are applied only when non-empty, so an omitted or
//     "" value keeps the stored setting;
//   - pointer fields are applied whenever present in the body, which allows
//     explicitly setting false / "" / 0. A nil pointer keeps the stored value.
type UpdateCustomOAuthProviderRequest struct {
	Name string `json:"name"`
	Slug string `json:"slug"`
	// nil keeps the existing enabled flag.
	Enabled *bool `json:"enabled"`

	ClientId string `json:"client_id"`
	// Empty keeps the existing secret.
	ClientSecret string `json:"client_secret"`

	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// nil keeps the existing value; a non-nil "" clears it.
	WellKnown *string `json:"well_known"`
	// nil keeps the existing value.
	AuthStyle *int `json:"auth_style"`
}
  166. // UpdateCustomOAuthProvider updates an existing custom OAuth provider
  167. func UpdateCustomOAuthProvider(c *gin.Context) {
  168. idStr := c.Param("id")
  169. id, err := strconv.Atoi(idStr)
  170. if err != nil {
  171. common.ApiErrorMsg(c, "无效的 ID")
  172. return
  173. }
  174. var req UpdateCustomOAuthProviderRequest
  175. if err := c.ShouldBindJSON(&req); err != nil {
  176. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  177. return
  178. }
  179. // Get existing provider
  180. provider, err := model.GetCustomOAuthProviderById(id)
  181. if err != nil {
  182. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  183. return
  184. }
  185. oldSlug := provider.Slug
  186. // Check if new slug is taken by another provider
  187. if req.Slug != "" && req.Slug != provider.Slug {
  188. if model.IsSlugTaken(req.Slug, id) {
  189. common.ApiErrorMsg(c, "该 Slug 已被使用")
  190. return
  191. }
  192. // Check if slug conflicts with built-in providers
  193. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  194. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  195. return
  196. }
  197. }
  198. // Update fields
  199. if req.Name != "" {
  200. provider.Name = req.Name
  201. }
  202. if req.Slug != "" {
  203. provider.Slug = req.Slug
  204. }
  205. if req.Enabled != nil {
  206. provider.Enabled = *req.Enabled
  207. }
  208. if req.ClientId != "" {
  209. provider.ClientId = req.ClientId
  210. }
  211. if req.ClientSecret != "" {
  212. provider.ClientSecret = req.ClientSecret
  213. }
  214. if req.AuthorizationEndpoint != "" {
  215. provider.AuthorizationEndpoint = req.AuthorizationEndpoint
  216. }
  217. if req.TokenEndpoint != "" {
  218. provider.TokenEndpoint = req.TokenEndpoint
  219. }
  220. if req.UserInfoEndpoint != "" {
  221. provider.UserInfoEndpoint = req.UserInfoEndpoint
  222. }
  223. if req.Scopes != "" {
  224. provider.Scopes = req.Scopes
  225. }
  226. if req.UserIdField != "" {
  227. provider.UserIdField = req.UserIdField
  228. }
  229. if req.UsernameField != "" {
  230. provider.UsernameField = req.UsernameField
  231. }
  232. if req.DisplayNameField != "" {
  233. provider.DisplayNameField = req.DisplayNameField
  234. }
  235. if req.EmailField != "" {
  236. provider.EmailField = req.EmailField
  237. }
  238. if req.WellKnown != nil {
  239. provider.WellKnown = *req.WellKnown
  240. }
  241. if req.AuthStyle != nil {
  242. provider.AuthStyle = *req.AuthStyle
  243. }
  244. if err := model.UpdateCustomOAuthProvider(provider); err != nil {
  245. common.ApiError(c, err)
  246. return
  247. }
  248. // Update the provider in the OAuth registry
  249. if oldSlug != provider.Slug {
  250. oauth.UnregisterCustomProvider(oldSlug)
  251. }
  252. oauth.RegisterOrUpdateCustomProvider(provider)
  253. c.JSON(http.StatusOK, gin.H{
  254. "success": true,
  255. "message": "更新成功",
  256. "data": toCustomOAuthProviderResponse(provider),
  257. })
  258. }
  259. // DeleteCustomOAuthProvider deletes a custom OAuth provider
  260. func DeleteCustomOAuthProvider(c *gin.Context) {
  261. idStr := c.Param("id")
  262. id, err := strconv.Atoi(idStr)
  263. if err != nil {
  264. common.ApiErrorMsg(c, "无效的 ID")
  265. return
  266. }
  267. // Get existing provider to get slug
  268. provider, err := model.GetCustomOAuthProviderById(id)
  269. if err != nil {
  270. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  271. return
  272. }
  273. // Check if there are any user bindings
  274. count, err := model.GetBindingCountByProviderId(id)
  275. if err != nil {
  276. common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
  277. common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
  278. return
  279. }
  280. if count > 0 {
  281. common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
  282. return
  283. }
  284. if err := model.DeleteCustomOAuthProvider(id); err != nil {
  285. common.ApiError(c, err)
  286. return
  287. }
  288. // Unregister the provider from the OAuth registry
  289. oauth.UnregisterCustomProvider(provider.Slug)
  290. c.JSON(http.StatusOK, gin.H{
  291. "success": true,
  292. "message": "删除成功",
  293. })
  294. }
  295. // GetUserOAuthBindings returns all OAuth bindings for the current user
  296. func GetUserOAuthBindings(c *gin.Context) {
  297. userId := c.GetInt("id")
  298. if userId == 0 {
  299. common.ApiErrorMsg(c, "未登录")
  300. return
  301. }
  302. bindings, err := model.GetUserOAuthBindingsByUserId(userId)
  303. if err != nil {
  304. common.ApiError(c, err)
  305. return
  306. }
  307. // Build response with provider info
  308. type BindingResponse struct {
  309. ProviderId int `json:"provider_id"`
  310. ProviderName string `json:"provider_name"`
  311. ProviderSlug string `json:"provider_slug"`
  312. ProviderUserId string `json:"provider_user_id"`
  313. }
  314. response := make([]BindingResponse, 0)
  315. for _, binding := range bindings {
  316. provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
  317. if err != nil {
  318. continue // Skip if provider not found
  319. }
  320. response = append(response, BindingResponse{
  321. ProviderId: binding.ProviderId,
  322. ProviderName: provider.Name,
  323. ProviderSlug: provider.Slug,
  324. ProviderUserId: binding.ProviderUserId,
  325. })
  326. }
  327. c.JSON(http.StatusOK, gin.H{
  328. "success": true,
  329. "message": "",
  330. "data": response,
  331. })
  332. }
  333. // UnbindCustomOAuth unbinds a custom OAuth provider from the current user
  334. func UnbindCustomOAuth(c *gin.Context) {
  335. userId := c.GetInt("id")
  336. if userId == 0 {
  337. common.ApiErrorMsg(c, "未登录")
  338. return
  339. }
  340. providerIdStr := c.Param("provider_id")
  341. providerId, err := strconv.Atoi(providerIdStr)
  342. if err != nil {
  343. common.ApiErrorMsg(c, "无效的提供商 ID")
  344. return
  345. }
  346. if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
  347. common.ApiError(c, err)
  348. return
  349. }
  350. c.JSON(http.StatusOK, gin.H{
  351. "success": true,
  352. "message": "解绑成功",
  353. })
  354. }