Sfoglia il codice sorgente

ccm,ocm: unify HTTP request retry with fast retry and exponential backoff

世界 1 mese fa
parent
commit
f6821be8a3

+ 12 - 10
service/ccm/credential.go

@@ -2,6 +2,7 @@ package ccm
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -142,7 +143,7 @@ func (c *oauthCredentials) needsRefresh() bool {
 	return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
 }
 
-func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
+func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
 	if credentials.RefreshToken == "" {
 		return nil, E.New("refresh token is empty")
 	}
@@ -156,15 +157,16 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
 		return nil, E.Cause(err, "marshal request")
 	}
 
-	request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
-	if err != nil {
-		return nil, err
-	}
-	request.Header.Set("Content-Type", "application/json")
-	request.Header.Set("Accept", "application/json")
-	request.Header.Set("User-Agent", ccmUserAgentValue)
-
-	response, err := httpClient.Do(request)
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
+		request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
+		if err != nil {
+			return nil, err
+		}
+		request.Header.Set("Content-Type", "application/json")
+		request.Header.Set("Accept", "application/json")
+		request.Header.Set("User-Agent", ccmUserAgentValue)
+		return request, nil
+	})
 	if err != nil {
 		return nil, err
 	}

+ 8 - 8
service/ccm/credential_external.go

@@ -449,14 +449,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
 		Timeout:   5 * time.Second,
 	}
 
-	request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
-	if err != nil {
-		c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
-		return
-	}
-	request.Header.Set("Authorization", "Bearer "+c.token)
-
-	response, err := httpClient.Do(request)
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
+		request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
+		if err != nil {
+			return nil, err
+		}
+		request.Header.Set("Authorization", "Bearer "+c.token)
+		return request, nil
+	})
 	if err != nil {
 		c.logger.Error("poll usage for ", c.tag, ": ", err)
 		c.stateMutex.Lock()

+ 40 - 29
service/ccm/credential_state.go

@@ -5,7 +5,6 @@ import (
 	"context"
 	stdTLS "crypto/tls"
 	"encoding/json"
-	"errors"
 	"io"
 	"math"
 	"math/rand/v2"
@@ -29,6 +28,38 @@ import (
 
 const defaultPollInterval = 60 * time.Minute
 
+// Retry policy used by doHTTPWithRetry.
+const (
+	// httpRetryMaxAttempts is the total number of attempts, the first one included.
+	httpRetryMaxAttempts  = 3
+	// httpRetryInitialDelay is the wait before the first retry; it is doubled for each later retry.
+	httpRetryInitialDelay = 200 * time.Millisecond
+)
+
+// doHTTPWithRetry sends the request produced by buildRequest through client,
+// retrying whenever client.Do returns an error.
+//
+// At most httpRetryMaxAttempts attempts are made. Before each retry it waits
+// httpRetryInitialDelay, doubled for every further retry, and gives up early
+// if ctx is done. buildRequest is invoked once per attempt so that every
+// attempt gets a fresh request and body reader; an error from buildRequest is
+// returned immediately without retrying.
+//
+// Any response obtained without a transport error is returned as-is,
+// whatever its status code; the caller must close its body. On failure the
+// error of the most recent attempt is returned.
+func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
+	var lastError error
+	for attempt := range httpRetryMaxAttempts {
+		if attempt > 0 {
+			delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
+			// Use an explicit timer so it can be released as soon as ctx is
+			// done instead of lingering until it fires.
+			timer := time.NewTimer(delay)
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return nil, lastError
+			case <-timer.C:
+			}
+		}
+		request, err := buildRequest()
+		if err != nil {
+			return nil, err
+		}
+		// A request built without a context (http.NewRequest) would keep
+		// running after ctx is cancelled; bind it to ctx. A context chosen
+		// explicitly by buildRequest is left untouched.
+		if request.Context() == context.Background() {
+			request = request.WithContext(ctx)
+		}
+		response, err := client.Do(request)
+		if err == nil {
+			return response, nil
+		}
+		lastError = err
+		// Do not schedule another attempt once the caller has given up.
+		if ctx.Err() != nil {
+			return nil, lastError
+		}
+	}
+	return nil, lastError
+}
+
 type credentialState struct {
 	fiveHourUtilization       float64
 	fiveHourReset             time.Time
@@ -46,6 +77,7 @@ type credentialState struct {
 
 type defaultCredential struct {
 	tag                string
+	serviceContext     context.Context
 	credentialPath     string
 	credentialFilePath string
 	credentials        *oauthCredentials
@@ -151,6 +183,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
 	requestContext, cancelRequests := context.WithCancel(context.Background())
 	credential := &defaultCredential{
 		tag:            tag,
+		serviceContext: ctx,
 		credentialPath: options.CredentialPath,
 		reserve5h:      reserve5h,
 		reserveWeekly:  reserveWeekly,
@@ -231,7 +264,7 @@ func (c *defaultCredential) getAccessToken() (string, error) {
 	}
 
 	baseCredentials := cloneCredentials(c.credentials)
-	newCredentials, err := refreshToken(c.httpClient, c.credentials)
+	newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials)
 	if err != nil {
 		return "", err
 	}
@@ -498,16 +531,6 @@ func (c *defaultCredential) earliestReset() time.Time {
 	return earliest
 }
 
-const pollUsageMaxRetries = 3
-
-func isTimeoutError(err error) bool {
-	var netErr net.Error
-	if errors.As(err, &netErr) {
-		return netErr.Timeout()
-	}
-	return false
-}
-
 func (c *defaultCredential) pollUsage(ctx context.Context) {
 	if !c.pollAccess.TryLock() {
 		return
@@ -531,30 +554,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
 		Timeout:   5 * time.Second,
 	}
 
-	var response *http.Response
-	for attempt := range pollUsageMaxRetries {
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
 		request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
 		if err != nil {
-			c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
-			return
+			return nil, err
 		}
 		request.Header.Set("Authorization", "Bearer "+accessToken)
 		request.Header.Set("Content-Type", "application/json")
 		request.Header.Set("User-Agent", ccmUserAgentValue)
 		request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
-
-		response, err = httpClient.Do(request)
-		if err == nil {
-			break
-		}
-		if !isTimeoutError(err) {
-			c.logger.Error("poll usage for ", c.tag, ": ", err)
-			return
-		}
-		if attempt < pollUsageMaxRetries-1 {
-			c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")")
-			continue
-		}
+		return request, nil
+	})
+	if err != nil {
 		c.logger.Error("poll usage for ", c.tag, ": ", err)
 		return
 	}

+ 11 - 9
service/ocm/credential.go

@@ -2,6 +2,7 @@ package ocm
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -118,7 +119,7 @@ func (c *oauthCredentials) needsRefresh() bool {
 	return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
 }
 
-func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
+func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
 	if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
 		return nil, E.New("refresh token is empty")
 	}
@@ -133,14 +134,15 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
 		return nil, E.Cause(err, "marshal request")
 	}
 
-	request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
-	if err != nil {
-		return nil, err
-	}
-	request.Header.Set("Content-Type", "application/json")
-	request.Header.Set("Accept", "application/json")
-
-	response, err := httpClient.Do(request)
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
+		request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
+		if err != nil {
+			return nil, err
+		}
+		request.Header.Set("Content-Type", "application/json")
+		request.Header.Set("Accept", "application/json")
+		return request, nil
+	})
 	if err != nil {
 		return nil, err
 	}

+ 8 - 8
service/ocm/credential_external.go

@@ -485,14 +485,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
 		Timeout:   5 * time.Second,
 	}
 
-	request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
-	if err != nil {
-		c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
-		return
-	}
-	request.Header.Set("Authorization", "Bearer "+c.token)
-
-	response, err := httpClient.Do(request)
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
+		request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
+		if err != nil {
+			return nil, err
+		}
+		request.Header.Set("Authorization", "Bearer "+c.token)
+		return request, nil
+	})
 	if err != nil {
 		c.logger.Error("poll usage for ", c.tag, ": ", err)
 		c.stateMutex.Lock()

+ 40 - 29
service/ocm/credential_state.go

@@ -5,7 +5,6 @@ import (
 	"context"
 	stdTLS "crypto/tls"
 	"encoding/json"
-	"errors"
 	"io"
 	"math/rand/v2"
 	"net"
@@ -29,6 +28,38 @@ import (
 
 const defaultPollInterval = 60 * time.Minute
 
+// Retry policy used by doHTTPWithRetry.
+const (
+	// httpRetryMaxAttempts is the total number of attempts, the first one included.
+	httpRetryMaxAttempts  = 3
+	// httpRetryInitialDelay is the wait before the first retry; it is doubled for each later retry.
+	httpRetryInitialDelay = 200 * time.Millisecond
+)
+
+// doHTTPWithRetry sends the request produced by buildRequest through client,
+// retrying whenever client.Do returns an error.
+//
+// At most httpRetryMaxAttempts attempts are made. Before each retry it waits
+// httpRetryInitialDelay, doubled for every further retry, and gives up early
+// if ctx is done. buildRequest is invoked once per attempt so that every
+// attempt gets a fresh request and body reader; an error from buildRequest is
+// returned immediately without retrying.
+//
+// Any response obtained without a transport error is returned as-is,
+// whatever its status code; the caller must close its body. On failure the
+// error of the most recent attempt is returned.
+func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
+	var lastError error
+	for attempt := range httpRetryMaxAttempts {
+		if attempt > 0 {
+			delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
+			// Use an explicit timer so it can be released as soon as ctx is
+			// done instead of lingering until it fires.
+			timer := time.NewTimer(delay)
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return nil, lastError
+			case <-timer.C:
+			}
+		}
+		request, err := buildRequest()
+		if err != nil {
+			return nil, err
+		}
+		// A request built without a context (http.NewRequest) would keep
+		// running after ctx is cancelled; bind it to ctx. A context chosen
+		// explicitly by buildRequest is left untouched.
+		if request.Context() == context.Background() {
+			request = request.WithContext(ctx)
+		}
+		response, err := client.Do(request)
+		if err == nil {
+			return response, nil
+		}
+		lastError = err
+		// Do not schedule another attempt once the caller has given up.
+		if ctx.Err() != nil {
+			return nil, lastError
+		}
+	}
+	return nil, lastError
+}
+
 type credentialState struct {
 	fiveHourUtilization       float64
 	fiveHourReset             time.Time
@@ -46,6 +77,7 @@ type credentialState struct {
 
 type defaultCredential struct {
 	tag                string
+	serviceContext     context.Context
 	credentialPath     string
 	credentialFilePath string
 	credentials        *oauthCredentials
@@ -159,6 +191,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef
 	requestContext, cancelRequests := context.WithCancel(context.Background())
 	credential := &defaultCredential{
 		tag:            tag,
+		serviceContext: ctx,
 		credentialPath: options.CredentialPath,
 		reserve5h:      reserve5h,
 		reserveWeekly:  reserveWeekly,
@@ -240,7 +273,7 @@ func (c *defaultCredential) getAccessToken() (string, error) {
 	}
 
 	baseCredentials := cloneCredentials(c.credentials)
-	newCredentials, err := refreshToken(c.httpClient, c.credentials)
+	newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials)
 	if err != nil {
 		return "", err
 	}
@@ -507,16 +540,6 @@ func (c *defaultCredential) earliestReset() time.Time {
 	return earliest
 }
 
-const pollUsageMaxRetries = 3
-
-func isTimeoutError(err error) bool {
-	var netErr net.Error
-	if errors.As(err, &netErr) {
-		return netErr.Timeout()
-	}
-	return false
-}
-
 func (c *defaultCredential) pollUsage(ctx context.Context) {
 	if !c.pollAccess.TryLock() {
 		return
@@ -551,30 +574,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
 		Timeout:   5 * time.Second,
 	}
 
-	var response *http.Response
-	for attempt := range pollUsageMaxRetries {
+	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
 		request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
 		if err != nil {
-			c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
-			return
+			return nil, err
 		}
 		request.Header.Set("Authorization", "Bearer "+accessToken)
 		if accountID != "" {
 			request.Header.Set("ChatGPT-Account-Id", accountID)
 		}
-
-		response, err = httpClient.Do(request)
-		if err == nil {
-			break
-		}
-		if !isTimeoutError(err) {
-			c.logger.Error("poll usage for ", c.tag, ": ", err)
-			return
-		}
-		if attempt < pollUsageMaxRetries-1 {
-			c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")")
-			continue
-		}
+		return request, nil
+	})
+	if err != nil {
 		c.logger.Error("poll usage for ", c.tag, ": ", err)
 		return
 	}