|
|
@@ -10,6 +10,7 @@ import (
|
|
|
"io"
|
|
|
"log/slog"
|
|
|
"maps"
|
|
|
+ "net/http"
|
|
|
"os"
|
|
|
"slices"
|
|
|
"strings"
|
|
|
@@ -130,32 +131,42 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
|
|
|
|
|
|
mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
|
|
|
|
|
|
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
|
|
|
- slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID)
|
|
|
- if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil {
|
|
|
- slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr)
|
|
|
- return nil, refreshErr
|
|
|
- }
|
|
|
-
|
|
|
- // Rebuild models with refreshed token
|
|
|
- if updateErr := c.UpdateModels(ctx); updateErr != nil {
|
|
|
- slog.Error("Failed to update models after token refresh", "error", updateErr)
|
|
|
- return nil, updateErr
|
|
|
+ run := func() (*fantasy.AgentResult, error) {
|
|
|
+ return c.currentAgent.Run(ctx, SessionAgentCall{
|
|
|
+ SessionID: sessionID,
|
|
|
+ Prompt: prompt,
|
|
|
+ Attachments: attachments,
|
|
|
+ MaxOutputTokens: maxTokens,
|
|
|
+ ProviderOptions: mergedOptions,
|
|
|
+ Temperature: temp,
|
|
|
+ TopP: topP,
|
|
|
+ TopK: topK,
|
|
|
+ FrequencyPenalty: freqPenalty,
|
|
|
+ PresencePenalty: presPenalty,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ result, originalErr := run()
|
|
|
+
|
|
|
+ if c.isUnauthorized(originalErr) {
|
|
|
+ switch {
|
|
|
+ case providerCfg.OAuthToken != nil:
|
|
|
+ slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
|
|
|
+ if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
|
|
|
+ return nil, originalErr
|
|
|
+ }
|
|
|
+ slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
|
|
|
+ return run()
|
|
|
+ case strings.Contains(providerCfg.APIKeyTemplate, "$"):
|
|
|
+ slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
|
|
|
+ if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
|
|
|
+ return nil, originalErr
|
|
|
+ }
|
|
|
+ slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
|
|
|
+ return run()
|
|
|
}
|
|
|
}
|
|
|
- result, err := c.currentAgent.Run(ctx, SessionAgentCall{
|
|
|
- SessionID: sessionID,
|
|
|
- Prompt: prompt,
|
|
|
- Attachments: attachments,
|
|
|
- MaxOutputTokens: maxTokens,
|
|
|
- ProviderOptions: mergedOptions,
|
|
|
- Temperature: temp,
|
|
|
- TopP: topP,
|
|
|
- TopK: topK,
|
|
|
- FrequencyPenalty: freqPenalty,
|
|
|
- PresencePenalty: presPenalty,
|
|
|
- })
|
|
|
- return result, err
|
|
|
+
|
|
|
+ return result, originalErr
|
|
|
}
|
|
|
|
|
|
func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
|
|
|
@@ -773,3 +784,35 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
|
|
|
}
|
|
|
return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
|
|
|
}
|
|
|
+
|
|
|
+func (c *coordinator) isUnauthorized(err error) bool {
|
|
|
+ var providerErr *fantasy.ProviderError
|
|
|
+ return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
|
|
|
+}
|
|
|
+
|
|
|
+func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
|
|
|
+ if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
|
|
|
+ slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if err := c.UpdateModels(ctx); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
|
|
|
+ newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ providerCfg.APIKey = newAPIKey
|
|
|
+ c.cfg.Providers.Set(providerCfg.ID, providerCfg)
|
|
|
+
|
|
|
+ if err := c.UpdateModels(ctx); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|