codex_oauth.go 6.5 KB


  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel/codex"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/gin-contrib/sessions"
  17. "github.com/gin-gonic/gin"
  18. )
  // codexOAuthCompleteRequest is the JSON body bound by the Codex OAuth
  // completion handlers (completeCodexOAuthWithChannelID).
  19. type codexOAuthCompleteRequest struct {
  // Input is the raw text pasted by the user; parseCodexAuthorizationInput
  // extracts the authorization code and state from it.
  20. Input string `json:"input"`
  21. }
  // codexOAuthSessionKey returns the session key under which one piece of
// Codex OAuth flow state ("state", "verifier", "created_at") is stored for
// channelID, formatted as "codex_oauth_<field>_<channelID>". Channel ID 0 is
// used by the channel-less flow.
func codexOAuthSessionKey(channelID int, field string) string {
	// fmt.Sprint inserts no separators here because every adjacent operand
	// pair includes at least one string.
	return fmt.Sprint("codex_oauth_", field, "_", channelID)
}
  // parseCodexAuthorizationInput extracts the OAuth authorization code and
// state from text pasted by the user. Recognized forms, tried in order:
//
//   - anything containing "code=": a full callback URL
//     ("http://localhost:1455/auth/callback?code=X&state=Y"), a query string
//     with a leading "?" ("?code=X&state=Y"), or a bare query string
//     ("code=X&state=Y"); a trailing "#fragment" is ignored;
//   - "X#Y", split on the first "#" into code X and state Y;
//   - anything else, returned as the code with an empty state.
//
// Surrounding whitespace is trimmed from the input and from each returned
// value. The only error is for empty (or all-whitespace) input; a missing
// code or state is reported by returning an empty string for it.
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
	v := strings.TrimSpace(input)
	if v == "" {
		return "", "", errors.New("empty input")
	}

	// Query-style input is checked before the "code#state" form so that a
	// callback URL carrying a fragment is not split on its "#".
	if strings.Contains(v, "code=") {
		// url.Parse accepts almost any string: for a bare "code=X&state=Y" it
		// puts everything into Path and leaves the query empty. Only trust its
		// result when it actually yielded a code.
		if u, parseErr := url.Parse(v); parseErr == nil {
			if q := u.Query(); q.Get("code") != "" {
				return strings.TrimSpace(q.Get("code")), strings.TrimSpace(q.Get("state")), nil
			}
		}
		// Treat the input as a raw query string; drop a leading "?" and any
		// fragment, which url.ParseQuery would otherwise fold into a value.
		raw := strings.TrimPrefix(v, "?")
		if i := strings.IndexByte(raw, '#'); i >= 0 {
			raw = raw[:i]
		}
		if q, parseErr := url.ParseQuery(raw); parseErr == nil {
			return strings.TrimSpace(q.Get("code")), strings.TrimSpace(q.Get("state")), nil
		}
	}

	if i := strings.IndexByte(v, '#'); i >= 0 {
		return strings.TrimSpace(v[:i]), strings.TrimSpace(v[i+1:]), nil
	}
	return v, "", nil
}
  54. func StartCodexOAuth(c *gin.Context) {
  55. startCodexOAuthWithChannelID(c, 0)
  56. }
  57. func StartCodexOAuthForChannel(c *gin.Context) {
  58. channelID, err := strconv.Atoi(c.Param("id"))
  59. if err != nil {
  60. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  61. return
  62. }
  63. startCodexOAuthWithChannelID(c, channelID)
  64. }
  65. func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
  66. if channelID > 0 {
  67. ch, err := model.GetChannelById(channelID, false)
  68. if err != nil {
  69. common.ApiError(c, err)
  70. return
  71. }
  72. if ch == nil {
  73. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  74. return
  75. }
  76. if ch.Type != constant.ChannelTypeCodex {
  77. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  78. return
  79. }
  80. }
  81. flow, err := service.CreateCodexOAuthAuthorizationFlow()
  82. if err != nil {
  83. common.ApiError(c, err)
  84. return
  85. }
  86. session := sessions.Default(c)
  87. session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
  88. session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
  89. session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
  90. _ = session.Save()
  91. c.JSON(http.StatusOK, gin.H{
  92. "success": true,
  93. "message": "",
  94. "data": gin.H{
  95. "authorize_url": flow.AuthorizeURL,
  96. },
  97. })
  98. }
  99. func CompleteCodexOAuth(c *gin.Context) {
  100. completeCodexOAuthWithChannelID(c, 0)
  101. }
  102. func CompleteCodexOAuthForChannel(c *gin.Context) {
  103. channelID, err := strconv.Atoi(c.Param("id"))
  104. if err != nil {
  105. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  106. return
  107. }
  108. completeCodexOAuthWithChannelID(c, channelID)
  109. }
  110. func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
  111. req := codexOAuthCompleteRequest{}
  112. if err := c.ShouldBindJSON(&req); err != nil {
  113. common.ApiError(c, err)
  114. return
  115. }
  116. code, state, err := parseCodexAuthorizationInput(req.Input)
  117. if err != nil {
  118. common.SysError("failed to parse codex authorization input: " + err.Error())
  119. c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"})
  120. return
  121. }
  122. if strings.TrimSpace(code) == "" {
  123. c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
  124. return
  125. }
  126. if strings.TrimSpace(state) == "" {
  127. c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
  128. return
  129. }
  130. if channelID > 0 {
  131. ch, err := model.GetChannelById(channelID, false)
  132. if err != nil {
  133. common.ApiError(c, err)
  134. return
  135. }
  136. if ch == nil {
  137. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  138. return
  139. }
  140. if ch.Type != constant.ChannelTypeCodex {
  141. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  142. return
  143. }
  144. }
  145. session := sessions.Default(c)
  146. expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
  147. verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
  148. if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
  149. c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
  150. return
  151. }
  152. if state != expectedState {
  153. c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
  154. return
  155. }
  156. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  157. defer cancel()
  158. tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
  159. if err != nil {
  160. common.SysError("failed to exchange codex authorization code: " + err.Error())
  161. c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
  162. return
  163. }
  164. accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
  165. if !ok {
  166. c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
  167. return
  168. }
  169. email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
  170. key := codex.OAuthKey{
  171. AccessToken: tokenRes.AccessToken,
  172. RefreshToken: tokenRes.RefreshToken,
  173. AccountID: accountID,
  174. LastRefresh: time.Now().Format(time.RFC3339),
  175. Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
  176. Email: email,
  177. Type: "codex",
  178. }
  179. encoded, err := common.Marshal(key)
  180. if err != nil {
  181. common.ApiError(c, err)
  182. return
  183. }
  184. session.Delete(codexOAuthSessionKey(channelID, "state"))
  185. session.Delete(codexOAuthSessionKey(channelID, "verifier"))
  186. session.Delete(codexOAuthSessionKey(channelID, "created_at"))
  187. _ = session.Save()
  188. if channelID > 0 {
  189. if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
  190. common.ApiError(c, err)
  191. return
  192. }
  193. model.InitChannelCache()
  194. service.ResetProxyClientCache()
  195. c.JSON(http.StatusOK, gin.H{
  196. "success": true,
  197. "message": "saved",
  198. "data": gin.H{
  199. "channel_id": channelID,
  200. "account_id": accountID,
  201. "email": email,
  202. "expires_at": key.Expired,
  203. "last_refresh": key.LastRefresh,
  204. },
  205. })
  206. return
  207. }
  208. c.JSON(http.StatusOK, gin.H{
  209. "success": true,
  210. "message": "generated",
  211. "data": gin.H{
  212. "key": string(encoded),
  213. "account_id": accountID,
  214. "email": email,
  215. "expires_at": key.Expired,
  216. "last_refresh": key.LastRefresh,
  217. },
  218. })
  219. }