custom_oauth.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package controller
  2. import (
  3. "net/http"
  4. "strconv"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/model"
  7. "github.com/QuantumNous/new-api/oauth"
  8. "github.com/gin-gonic/gin"
  9. )
  // CustomOAuthProviderResponse is the JSON shape returned to API clients for a
// custom OAuth provider. It intentionally has no client-secret field, so the
// ClientSecret stored on model.CustomOAuthProvider is never serialized.
// Field order is significant: encoding/json emits keys in declaration order.
type CustomOAuthProviderResponse struct {
	// Identity and state.
	Id      int    `json:"id"`
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Enabled bool   `json:"enabled"`

	// Client identifier (the matching secret is deliberately omitted).
	ClientId string `json:"client_id"`

	// Endpoints used by the OAuth flow.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`

	// Requested scopes, stored as a single string.
	Scopes string `json:"scopes"`

	// Field-name mapping; presumably keys into the provider's user-info
	// payload — confirm against the oauth package.
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	WellKnown string `json:"well_known"`
	AuthStyle int    `json:"auth_style"`
}
  29. func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
  30. return &CustomOAuthProviderResponse{
  31. Id: p.Id,
  32. Name: p.Name,
  33. Slug: p.Slug,
  34. Enabled: p.Enabled,
  35. ClientId: p.ClientId,
  36. AuthorizationEndpoint: p.AuthorizationEndpoint,
  37. TokenEndpoint: p.TokenEndpoint,
  38. UserInfoEndpoint: p.UserInfoEndpoint,
  39. Scopes: p.Scopes,
  40. UserIdField: p.UserIdField,
  41. UsernameField: p.UsernameField,
  42. DisplayNameField: p.DisplayNameField,
  43. EmailField: p.EmailField,
  44. WellKnown: p.WellKnown,
  45. AuthStyle: p.AuthStyle,
  46. }
  47. }
  // GetCustomOAuthProviders responds with every custom OAuth provider. Each one
// is converted through toCustomOAuthProviderResponse so client secrets are not
// exposed. "data" is always a non-nil slice, so it encodes as a JSON array
// (never null) even when there are no providers.
func GetCustomOAuthProviders(c *gin.Context) {
	providers, err := model.GetAllCustomOAuthProviders()
	if err != nil {
		common.ApiError(c, err)
		return
	}

	data := make([]*CustomOAuthProviderResponse, 0, len(providers))
	for _, provider := range providers {
		data = append(data, toCustomOAuthProviderResponse(provider))
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    data,
	})
}
  // GetCustomOAuthProvider responds with a single custom OAuth provider, selected
// by the integer "id" path parameter. Secrets are stripped via
// toCustomOAuthProviderResponse.
//
// A non-integer id yields an "invalid ID" message. Any error from the lookup is
// reported as "not found"; the error value is not inspected, so storage
// failures are reported the same way.
func GetCustomOAuthProvider(c *gin.Context) {
	id, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiErrorMsg(c, "无效的 ID")
		return
	}

	provider, lookupErr := model.GetCustomOAuthProviderById(id)
	if lookupErr != nil {
		common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
		return
	}

	payload := toCustomOAuthProviderResponse(provider)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    payload,
	})
}
  // CreateCustomOAuthProviderRequest is the JSON body accepted when creating a
// custom OAuth provider. Fields tagged `binding:"required"` are enforced by
// gin's validator during ShouldBindJSON. The remaining fields are optional and
// default to their zero values.
type CreateCustomOAuthProviderRequest struct {
	// Identity and state.
	Name    string `json:"name" binding:"required"`
	Slug    string `json:"slug" binding:"required"`
	Enabled bool   `json:"enabled"`

	// Client credentials; both are mandatory on creation.
	ClientId     string `json:"client_id" binding:"required"`
	ClientSecret string `json:"client_secret" binding:"required"`

	// Endpoints; all mandatory on creation.
	AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
	TokenEndpoint         string `json:"token_endpoint" binding:"required"`
	UserInfoEndpoint      string `json:"user_info_endpoint" binding:"required"`

	// Optional settings.
	Scopes           string `json:"scopes"`
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`
	WellKnown        string `json:"well_known"`
	AuthStyle        int    `json:"auth_style"`
}
  // CreateCustomOAuthProvider handles creation of a custom OAuth provider.
//
// Steps:
//  1. Bind the JSON body into CreateCustomOAuthProviderRequest. gin enforces the
//     `binding:"required"` fields.
//  2. Reject a slug already used by a stored custom provider, or one that
//     collides with a registered built-in (non-custom) provider.
//  3. Store the provider via model.CreateCustomOAuthProvider, register it with
//     the oauth registry, and echo it back without its client secret.
func CreateCustomOAuthProvider(c *gin.Context) {
	var req CreateCustomOAuthProviderRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		common.ApiErrorMsg(c, "无效的请求参数: "+bindErr.Error())
		return
	}

	// Cases are evaluated top to bottom, so the stored-slug check runs first
	// and the built-in collision check runs only if it passes.
	// NOTE(review): the 0 passed to IsSlugTaken presumably means "exclude no
	// provider ID" — confirm in the model package.
	switch {
	case model.IsSlugTaken(req.Slug, 0):
		common.ApiErrorMsg(c, "该 Slug 已被使用")
		return
	case oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug):
		common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
		return
	}

	provider := &model.CustomOAuthProvider{
		Name:                  req.Name,
		Slug:                  req.Slug,
		Enabled:               req.Enabled,
		ClientId:              req.ClientId,
		ClientSecret:          req.ClientSecret,
		AuthorizationEndpoint: req.AuthorizationEndpoint,
		TokenEndpoint:         req.TokenEndpoint,
		UserInfoEndpoint:      req.UserInfoEndpoint,
		Scopes:                req.Scopes,
		UserIdField:           req.UserIdField,
		UsernameField:         req.UsernameField,
		DisplayNameField:      req.DisplayNameField,
		EmailField:            req.EmailField,
		WellKnown:             req.WellKnown,
		AuthStyle:             req.AuthStyle,
	}
	if createErr := model.CreateCustomOAuthProvider(provider); createErr != nil {
		common.ApiError(c, createErr)
		return
	}

	// Register only after the provider was stored successfully.
	oauth.RegisterOrUpdateCustomProvider(provider)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "创建成功",
		"data":    toCustomOAuthProviderResponse(provider),
	})
}
  // UpdateCustomOAuthProviderRequest is the JSON body accepted when updating a
// custom OAuth provider. No field has a binding constraint. How values are
// merged is decided by the update handler: in this file, empty strings leave
// the stored value unchanged, while Enabled, WellKnown and AuthStyle are
// always overwritten.
type UpdateCustomOAuthProviderRequest struct {
	// Identity and state.
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Enabled bool   `json:"enabled"`

	// Client credentials.
	ClientId     string `json:"client_id"`
	ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing

	// Endpoints.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`

	// Remaining settings.
	Scopes           string `json:"scopes"`
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`
	WellKnown        string `json:"well_known"`
	AuthStyle        int    `json:"auth_style"`
}
  // UpdateCustomOAuthProvider updates the custom OAuth provider identified by the
// integer "id" path parameter, using an UpdateCustomOAuthProviderRequest body.
//
// Merge rules:
//   - Name, Slug, ClientId, ClientSecret, the three endpoints, Scopes and the
//     four *Field mappings replace the stored value only when non-empty. For
//     example, an empty ClientSecret keeps the existing secret.
//   - Enabled, WellKnown and AuthStyle are always overwritten. A request that
//     omits them therefore sets them to false / "" / 0.
//
// A changed slug is checked against stored providers and built-in providers,
// as on creation. After a successful save the oauth registry is refreshed, and
// the old slug is unregistered if it changed.
func UpdateCustomOAuthProvider(c *gin.Context) {
	id, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiErrorMsg(c, "无效的 ID")
		return
	}

	var req UpdateCustomOAuthProviderRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		common.ApiErrorMsg(c, "无效的请求参数: "+bindErr.Error())
		return
	}

	provider, lookupErr := model.GetCustomOAuthProviderById(id)
	if lookupErr != nil {
		common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
		return
	}
	previousSlug := provider.Slug

	// Slug validation only applies when the caller actually changes the slug.
	if req.Slug != "" && req.Slug != previousSlug {
		if model.IsSlugTaken(req.Slug, id) {
			common.ApiErrorMsg(c, "该 Slug 已被使用")
			return
		}
		if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
			common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
			return
		}
	}

	// replaceIfSet overwrites *dst only when the caller supplied a value.
	replaceIfSet := func(dst *string, val string) {
		if val != "" {
			*dst = val
		}
	}
	replaceIfSet(&provider.Name, req.Name)
	replaceIfSet(&provider.Slug, req.Slug)
	replaceIfSet(&provider.ClientId, req.ClientId)
	replaceIfSet(&provider.ClientSecret, req.ClientSecret)
	replaceIfSet(&provider.AuthorizationEndpoint, req.AuthorizationEndpoint)
	replaceIfSet(&provider.TokenEndpoint, req.TokenEndpoint)
	replaceIfSet(&provider.UserInfoEndpoint, req.UserInfoEndpoint)
	replaceIfSet(&provider.Scopes, req.Scopes)
	replaceIfSet(&provider.UserIdField, req.UserIdField)
	replaceIfSet(&provider.UsernameField, req.UsernameField)
	replaceIfSet(&provider.DisplayNameField, req.DisplayNameField)
	replaceIfSet(&provider.EmailField, req.EmailField)

	// These are assigned unconditionally (see merge rules above).
	provider.Enabled = req.Enabled
	provider.WellKnown = req.WellKnown
	provider.AuthStyle = req.AuthStyle

	if saveErr := model.UpdateCustomOAuthProvider(provider); saveErr != nil {
		common.ApiError(c, saveErr)
		return
	}

	// Refresh the registry only after the save succeeded. A renamed provider
	// must not remain reachable under its old slug.
	if provider.Slug != previousSlug {
		oauth.UnregisterCustomProvider(previousSlug)
	}
	oauth.RegisterOrUpdateCustomProvider(provider)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "更新成功",
		"data":    toCustomOAuthProviderResponse(provider),
	})
}
  // DeleteCustomOAuthProvider deletes the custom OAuth provider identified by the
// integer "id" path parameter and then removes it from the oauth registry.
//
// Deletion is refused while any user bindings still reference the provider.
// If the binding count cannot be determined, the request fails with that error
// and nothing is deleted. Ignoring the error would make a failed count look
// like "0 bindings" and allow deletion of a provider that is still in use.
func DeleteCustomOAuthProvider(c *gin.Context) {
	idStr := c.Param("id")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		common.ApiErrorMsg(c, "无效的 ID")
		return
	}

	// Fetch the provider first: its slug is needed to unregister it afterwards.
	provider, err := model.GetCustomOAuthProviderById(id)
	if err != nil {
		common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
		return
	}

	// Refuse to delete while users are still bound. A failed count is an
	// error, not an implicit zero.
	count, err := model.GetBindingCountByProviderId(id)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	if count > 0 {
		common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
		return
	}

	if err := model.DeleteCustomOAuthProvider(id); err != nil {
		common.ApiError(c, err)
		return
	}

	// Unregister only after the delete succeeded, so a failed delete leaves the
	// provider registered.
	oauth.UnregisterCustomProvider(provider.Slug)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "删除成功",
	})
}
  // GetUserOAuthBindings responds with the custom OAuth bindings of the current
// user. The user ID is read from the gin context key "id", and 0 is treated as
// "not logged in". Each binding is enriched with the bound provider's name and
// slug.
//
// Bindings whose provider lookup fails are omitted rather than failing the
// whole request. Any lookup error is handled this way, not only "not found".
// TODO(review): consider surfacing storage errors instead of silently
// dropping entries.
func GetUserOAuthBindings(c *gin.Context) {
	userId := c.GetInt("id")
	if userId == 0 {
		common.ApiErrorMsg(c, "未登录")
		return
	}

	bindings, err := model.GetUserOAuthBindingsByUserId(userId)
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// BindingResponse is the per-binding JSON entry returned to the client.
	type BindingResponse struct {
		ProviderId     int    `json:"provider_id"`
		ProviderName   string `json:"provider_name"`
		ProviderSlug   string `json:"provider_slug"`
		ProviderUserId string `json:"provider_user_id"`
	}

	// len(bindings) is an upper bound on the number of entries. Pre-sizing
	// avoids regrowth, and the slice stays non-nil so an empty result encodes
	// as [] rather than null.
	response := make([]BindingResponse, 0, len(bindings))
	for _, binding := range bindings {
		provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
		if err != nil {
			continue // best-effort: skip bindings whose provider can't be loaded
		}
		response = append(response, BindingResponse{
			ProviderId:     binding.ProviderId,
			ProviderName:   provider.Name,
			ProviderSlug:   provider.Slug,
			ProviderUserId: binding.ProviderUserId,
		})
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    response,
	})
}
  // UnbindCustomOAuth removes the current user's binding to the custom OAuth
// provider given by the integer "provider_id" path parameter. The user ID comes
// from the gin context key "id", and 0 is treated as "not logged in".
func UnbindCustomOAuth(c *gin.Context) {
	currentUser := c.GetInt("id")
	if currentUser == 0 {
		common.ApiErrorMsg(c, "未登录")
		return
	}

	providerId, parseErr := strconv.Atoi(c.Param("provider_id"))
	if parseErr != nil {
		common.ApiErrorMsg(c, "无效的提供商 ID")
		return
	}

	if delErr := model.DeleteUserOAuthBinding(currentUser, providerId); delErr != nil {
		common.ApiError(c, delErr)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "解绑成功",
	})
}