| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- package controller
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "net/url"
- "strconv"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel/codex"
- "github.com/QuantumNous/new-api/service"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- )
// codexOAuthCompleteRequest is the JSON body bound by the OAuth completion
// handlers.
type codexOAuthCompleteRequest struct {
	// Input is the raw text supplied by the caller; it is handed to
	// parseCodexAuthorizationInput to recover the authorization code and state.
	Input string `json:"input"`
}
// codexOAuthSessionKey builds the session key under which one field of an
// OAuth flow is stored. The result has the form
// "codex_oauth_<field>_<channelID>".
func codexOAuthSessionKey(channelID int, field string) string {
	const prefix = "codex_oauth_"
	suffix := fmt.Sprint(channelID)
	return prefix + field + "_" + suffix
}
// parseCodexAuthorizationInput extracts the authorization code and state from
// user-supplied text. Surrounding whitespace is ignored. Accepted forms, tried
// in this order:
//
//  1. anything containing "code=": a URL ("https://host/cb?code=C&state=S",
//     optionally followed by a "#fragment") or a bare query string
//     ("code=C&state=S", with or without a leading "?");
//  2. "CODE#STATE";
//  3. a bare code, returned with an empty state.
//
// An error is returned only when the input is empty after trimming. Callers
// are expected to validate that code and state are non-empty.
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
	v := strings.TrimSpace(input)
	if v == "" {
		return "", "", errors.New("empty input")
	}

	// Query-style input is handled before the "#" split so that a callback URL
	// carrying a fragment is not cut in half at the "#".
	if strings.Contains(v, "code=") {
		if u, parseErr := url.Parse(v); parseErr == nil {
			q := u.Query()
			if c := strings.TrimSpace(q.Get("code")); c != "" {
				return c, strings.TrimSpace(q.Get("state")), nil
			}
		}
		// url.Parse accepts a bare "code=...&state=..." as a path with an empty
		// query, so the pairs have to be decoded explicitly.
		if q, parseErr := url.ParseQuery(strings.TrimPrefix(v, "?")); parseErr == nil {
			if c := strings.TrimSpace(q.Get("code")); c != "" {
				return c, strings.TrimSpace(q.Get("state")), nil
			}
		}
	}

	if i := strings.Index(v, "#"); i >= 0 {
		return strings.TrimSpace(v[:i]), strings.TrimSpace(v[i+1:]), nil
	}

	return v, "", nil
}
- func StartCodexOAuth(c *gin.Context) {
- startCodexOAuthWithChannelID(c, 0)
- }
- func StartCodexOAuthForChannel(c *gin.Context) {
- channelID, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
- return
- }
- startCodexOAuthWithChannelID(c, channelID)
- }
- func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
- if channelID > 0 {
- ch, err := model.GetChannelById(channelID, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if ch == nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
- return
- }
- if ch.Type != constant.ChannelTypeCodex {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
- return
- }
- }
- flow, err := service.CreateCodexOAuthAuthorizationFlow()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- session := sessions.Default(c)
- session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
- session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
- session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
- _ = session.Save()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "authorize_url": flow.AuthorizeURL,
- },
- })
- }
- func CompleteCodexOAuth(c *gin.Context) {
- completeCodexOAuthWithChannelID(c, 0)
- }
- func CompleteCodexOAuthForChannel(c *gin.Context) {
- channelID, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
- return
- }
- completeCodexOAuthWithChannelID(c, channelID)
- }
- func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
- req := codexOAuthCompleteRequest{}
- if err := c.ShouldBindJSON(&req); err != nil {
- common.ApiError(c, err)
- return
- }
- code, state, err := parseCodexAuthorizationInput(req.Input)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- if strings.TrimSpace(code) == "" {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
- return
- }
- if strings.TrimSpace(state) == "" {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
- return
- }
- if channelID > 0 {
- ch, err := model.GetChannelById(channelID, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if ch == nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
- return
- }
- if ch.Type != constant.ChannelTypeCodex {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
- return
- }
- }
- session := sessions.Default(c)
- expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
- verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
- if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
- return
- }
- if state != expectedState {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
- return
- }
- ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
- defer cancel()
- tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
- if !ok {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
- return
- }
- email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
- key := codex.OAuthKey{
- AccessToken: tokenRes.AccessToken,
- RefreshToken: tokenRes.RefreshToken,
- AccountID: accountID,
- LastRefresh: time.Now().Format(time.RFC3339),
- Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
- Email: email,
- Type: "codex",
- }
- encoded, err := common.Marshal(key)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- session.Delete(codexOAuthSessionKey(channelID, "state"))
- session.Delete(codexOAuthSessionKey(channelID, "verifier"))
- session.Delete(codexOAuthSessionKey(channelID, "created_at"))
- _ = session.Save()
- if channelID > 0 {
- if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- service.ResetProxyClientCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "saved",
- "data": gin.H{
- "channel_id": channelID,
- "account_id": accountID,
- "email": email,
- "expires_at": key.Expired,
- "last_refresh": key.LastRefresh,
- },
- })
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "generated",
- "data": gin.H{
- "key": string(encoded),
- "account_id": accountID,
- "email": email,
- "expires_at": key.Expired,
- "last_refresh": key.LastRefresh,
- },
- })
- }
|