2
0

codex_oauth.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel/codex"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/gin-contrib/sessions"
  17. "github.com/gin-gonic/gin"
  18. )
  // codexOAuthCompleteRequest is the JSON body bound by the Codex OAuth
  // completion handlers.
  type codexOAuthCompleteRequest struct {
  	// Input is the raw text the user pasted after authorizing; see
  	// parseCodexAuthorizationInput for the accepted formats.
  	Input string `json:"input"`
  }
  // codexOAuthSessionKey returns the session key under which one field of a
  // pending Codex OAuth flow is stored, e.g. "codex_oauth_state_12". The key
  // embeds the channel ID, so flows for different channels use distinct entries.
  func codexOAuthSessionKey(channelID int, field string) string {
  	const prefix = "codex_oauth_"
  	return prefix + field + "_" + fmt.Sprint(channelID)
  }
  // parseCodexAuthorizationInput extracts the authorization code and state from
  // the text a user pasted after authorizing. Accepted forms:
  //
  //   - anything containing "code=": a callback URL with a query component
  //     ("https://host/callback?code=C&state=S", optionally followed by a
  //     "#fragment"), or a bare query string ("code=C&state=S" / "?code=C&state=S");
  //   - "C#S": code and state separated by the first '#';
  //   - anything else is returned verbatim as the code with an empty state.
  //
  // Surrounding whitespace is trimmed from the input and from each result. The
  // only error is for empty input. A missing code or state is returned as ""
  // so the caller can report which piece is absent.
  //
  // The "code=" form is checked before '#' so that a pasted callback URL
  // carrying a fragment is not misread as "C#S".
  func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
  	v := strings.TrimSpace(input)
  	if v == "" {
  		return "", "", errors.New("empty input")
  	}
  	if strings.Contains(v, "code=") {
  		q := parseCodexCallbackQuery(v)
  		return strings.TrimSpace(q.Get("code")), strings.TrimSpace(q.Get("state")), nil
  	}
  	if i := strings.IndexByte(v, '#'); i >= 0 {
  		return strings.TrimSpace(v[:i]), strings.TrimSpace(v[i+1:]), nil
  	}
  	return v, "", nil
  }

  // parseCodexCallbackQuery returns the query parameters carried by v. v is
  // either a URL with a query component or a bare query string, optionally
  // prefixed with '?' and/or followed by a '#' fragment.
  func parseCodexCallbackQuery(v string) url.Values {
  	// url.Parse accepts a bare "code=C&state=S" as a relative path with no
  	// query, so its result is only used when it actually found a query.
  	if u, err := url.Parse(v); err == nil && u.RawQuery != "" {
  		return u.Query()
  	}
  	raw := strings.TrimPrefix(v, "?")
  	if i := strings.IndexByte(raw, '#'); i >= 0 {
  		raw = raw[:i]
  	}
  	// ParseQuery always returns a non-nil map of every valid pair it found,
  	// even when it reports an error for a malformed one; keep those pairs.
  	q, _ := url.ParseQuery(raw)
  	return q
  }
  54. func StartCodexOAuth(c *gin.Context) {
  55. startCodexOAuthWithChannelID(c, 0)
  56. }
  57. func StartCodexOAuthForChannel(c *gin.Context) {
  58. channelID, err := strconv.Atoi(c.Param("id"))
  59. if err != nil {
  60. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  61. return
  62. }
  63. startCodexOAuthWithChannelID(c, channelID)
  64. }
  65. func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
  66. if channelID > 0 {
  67. ch, err := model.GetChannelById(channelID, false)
  68. if err != nil {
  69. common.ApiError(c, err)
  70. return
  71. }
  72. if ch == nil {
  73. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  74. return
  75. }
  76. if ch.Type != constant.ChannelTypeCodex {
  77. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  78. return
  79. }
  80. }
  81. flow, err := service.CreateCodexOAuthAuthorizationFlow()
  82. if err != nil {
  83. common.ApiError(c, err)
  84. return
  85. }
  86. session := sessions.Default(c)
  87. session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
  88. session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
  89. session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
  90. _ = session.Save()
  91. c.JSON(http.StatusOK, gin.H{
  92. "success": true,
  93. "message": "",
  94. "data": gin.H{
  95. "authorize_url": flow.AuthorizeURL,
  96. },
  97. })
  98. }
  99. func CompleteCodexOAuth(c *gin.Context) {
  100. completeCodexOAuthWithChannelID(c, 0)
  101. }
  102. func CompleteCodexOAuthForChannel(c *gin.Context) {
  103. channelID, err := strconv.Atoi(c.Param("id"))
  104. if err != nil {
  105. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  106. return
  107. }
  108. completeCodexOAuthWithChannelID(c, channelID)
  109. }
  110. func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
  111. req := codexOAuthCompleteRequest{}
  112. if err := c.ShouldBindJSON(&req); err != nil {
  113. common.ApiError(c, err)
  114. return
  115. }
  116. code, state, err := parseCodexAuthorizationInput(req.Input)
  117. if err != nil {
  118. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  119. return
  120. }
  121. if strings.TrimSpace(code) == "" {
  122. c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
  123. return
  124. }
  125. if strings.TrimSpace(state) == "" {
  126. c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
  127. return
  128. }
  129. if channelID > 0 {
  130. ch, err := model.GetChannelById(channelID, false)
  131. if err != nil {
  132. common.ApiError(c, err)
  133. return
  134. }
  135. if ch == nil {
  136. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  137. return
  138. }
  139. if ch.Type != constant.ChannelTypeCodex {
  140. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  141. return
  142. }
  143. }
  144. session := sessions.Default(c)
  145. expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
  146. verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
  147. if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
  148. c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
  149. return
  150. }
  151. if state != expectedState {
  152. c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
  153. return
  154. }
  155. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  156. defer cancel()
  157. tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
  158. if err != nil {
  159. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  160. return
  161. }
  162. accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
  163. if !ok {
  164. c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
  165. return
  166. }
  167. email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
  168. key := codex.OAuthKey{
  169. AccessToken: tokenRes.AccessToken,
  170. RefreshToken: tokenRes.RefreshToken,
  171. AccountID: accountID,
  172. LastRefresh: time.Now().Format(time.RFC3339),
  173. Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
  174. Email: email,
  175. Type: "codex",
  176. }
  177. encoded, err := common.Marshal(key)
  178. if err != nil {
  179. common.ApiError(c, err)
  180. return
  181. }
  182. session.Delete(codexOAuthSessionKey(channelID, "state"))
  183. session.Delete(codexOAuthSessionKey(channelID, "verifier"))
  184. session.Delete(codexOAuthSessionKey(channelID, "created_at"))
  185. _ = session.Save()
  186. if channelID > 0 {
  187. if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
  188. common.ApiError(c, err)
  189. return
  190. }
  191. model.InitChannelCache()
  192. service.ResetProxyClientCache()
  193. c.JSON(http.StatusOK, gin.H{
  194. "success": true,
  195. "message": "saved",
  196. "data": gin.H{
  197. "channel_id": channelID,
  198. "account_id": accountID,
  199. "email": email,
  200. "expires_at": key.Expired,
  201. "last_refresh": key.LastRefresh,
  202. },
  203. })
  204. return
  205. }
  206. c.JSON(http.StatusOK, gin.H{
  207. "success": true,
  208. "message": "generated",
  209. "data": gin.H{
  210. "key": string(encoded),
  211. "account_id": accountID,
  212. "email": email,
  213. "expires_at": key.Expired,
  214. "last_refresh": key.LastRefresh,
  215. },
  216. })
  217. }