| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926 |
- package agent
- import (
- "bytes"
- "cmp"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "maps"
- "net/http"
- "os"
- "slices"
- "strings"
- "charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/agent/hyper"
- "github.com/charmbracelet/crush/internal/agent/prompt"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/filetracker"
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/oauth/copilot"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/session"
- "golang.org/x/sync/errgroup"
- "charm.land/fantasy/providers/anthropic"
- "charm.land/fantasy/providers/azure"
- "charm.land/fantasy/providers/bedrock"
- "charm.land/fantasy/providers/google"
- "charm.land/fantasy/providers/openai"
- "charm.land/fantasy/providers/openaicompat"
- "charm.land/fantasy/providers/openrouter"
- "charm.land/fantasy/providers/vercel"
- openaisdk "github.com/openai/openai-go/v2/option"
- "github.com/qjebbs/go-jsons"
- )
- type Coordinator interface {
- // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
- // SetMainAgent(string)
- Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
- Cancel(sessionID string)
- CancelAll()
- IsSessionBusy(sessionID string) bool
- IsBusy() bool
- QueuedPrompts(sessionID string) int
- QueuedPromptsList(sessionID string) []string
- ClearQueue(sessionID string)
- Summarize(context.Context, string) error
- Model() Model
- UpdateModels(ctx context.Context) error
- }
- type coordinator struct {
- cfg *config.Config
- sessions session.Service
- messages message.Service
- permissions permission.Service
- history history.Service
- filetracker filetracker.Service
- lspManager *lsp.Manager
- currentAgent SessionAgent
- agents map[string]SessionAgent
- readyWg errgroup.Group
- }
- func NewCoordinator(
- ctx context.Context,
- cfg *config.Config,
- sessions session.Service,
- messages message.Service,
- permissions permission.Service,
- history history.Service,
- filetracker filetracker.Service,
- lspManager *lsp.Manager,
- ) (Coordinator, error) {
- c := &coordinator{
- cfg: cfg,
- sessions: sessions,
- messages: messages,
- permissions: permissions,
- history: history,
- filetracker: filetracker,
- lspManager: lspManager,
- agents: make(map[string]SessionAgent),
- }
- agentCfg, ok := cfg.Agents[config.AgentCoder]
- if !ok {
- return nil, errors.New("coder agent not configured")
- }
- // TODO: make this dynamic when we support multiple agents
- prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
- if err != nil {
- return nil, err
- }
- agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
- if err != nil {
- return nil, err
- }
- c.currentAgent = agent
- c.agents[config.AgentCoder] = agent
- return c, nil
- }
- // Run implements Coordinator.
- func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
- if err := c.readyWg.Wait(); err != nil {
- return nil, err
- }
- // refresh models before each run
- if err := c.UpdateModels(ctx); err != nil {
- return nil, fmt.Errorf("failed to update models: %w", err)
- }
- model := c.currentAgent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- if !model.CatwalkCfg.SupportsImages && attachments != nil {
- // filter out image attachments
- filteredAttachments := make([]message.Attachment, 0, len(attachments))
- for _, att := range attachments {
- if att.IsText() {
- filteredAttachments = append(filteredAttachments, att)
- }
- }
- attachments = filteredAttachments
- }
- providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return nil, errors.New("model provider not configured")
- }
- mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, err
- }
- }
- run := func() (*fantasy.AgentResult, error) {
- return c.currentAgent.Run(ctx, SessionAgentCall{
- SessionID: sessionID,
- Prompt: prompt,
- Attachments: attachments,
- MaxOutputTokens: maxTokens,
- ProviderOptions: mergedOptions,
- Temperature: temp,
- TopP: topP,
- TopK: topK,
- FrequencyPenalty: freqPenalty,
- PresencePenalty: presPenalty,
- })
- }
- result, originalErr := run()
- if c.isUnauthorized(originalErr) {
- switch {
- case providerCfg.OAuthToken != nil:
- slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
- return run()
- case strings.Contains(providerCfg.APIKeyTemplate, "$"):
- slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
- if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
- return run()
- }
- }
- return result, originalErr
- }
- func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
- options := fantasy.ProviderOptions{}
- cfgOpts := []byte("{}")
- providerCfgOpts := []byte("{}")
- catwalkOpts := []byte("{}")
- if model.ModelCfg.ProviderOptions != nil {
- data, err := json.Marshal(model.ModelCfg.ProviderOptions)
- if err == nil {
- cfgOpts = data
- }
- }
- if providerCfg.ProviderOptions != nil {
- data, err := json.Marshal(providerCfg.ProviderOptions)
- if err == nil {
- providerCfgOpts = data
- }
- }
- if model.CatwalkCfg.Options.ProviderOptions != nil {
- data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
- if err == nil {
- catwalkOpts = data
- }
- }
- readers := []io.Reader{
- bytes.NewReader(catwalkOpts),
- bytes.NewReader(providerCfgOpts),
- bytes.NewReader(cfgOpts),
- }
- got, err := jsons.Merge(readers)
- if err != nil {
- slog.Error("Could not merge call config", "err", err)
- return options
- }
- mergedOptions := make(map[string]any)
- err = json.Unmarshal([]byte(got), &mergedOptions)
- if err != nil {
- slog.Error("Could not create config for call", "err", err)
- return options
- }
- providerType := providerCfg.Type
- if providerType == "hyper" {
- if strings.Contains(model.CatwalkCfg.ID, "claude") {
- providerType = anthropic.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
- providerType = openai.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
- providerType = google.Name
- } else {
- providerType = openaicompat.Name
- }
- }
- switch providerType {
- case openai.Name, azure.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- if openai.IsResponsesModel(model.CatwalkCfg.ID) {
- if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
- mergedOptions["reasoning_summary"] = "auto"
- mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
- }
- parsed, err := openai.ParseResponsesOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- } else {
- parsed, err := openai.ParseOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- }
- case anthropic.Name:
- _, hasThink := mergedOptions["thinking"]
- if !hasThink && model.ModelCfg.Think {
- mergedOptions["thinking"] = map[string]any{
- // TODO: kujtim see if we need to make this dynamic
- "budget_tokens": 2000,
- }
- }
- parsed, err := anthropic.ParseOptions(mergedOptions)
- if err == nil {
- options[anthropic.Name] = parsed
- }
- case openrouter.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := openrouter.ParseOptions(mergedOptions)
- if err == nil {
- options[openrouter.Name] = parsed
- }
- case vercel.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := vercel.ParseOptions(mergedOptions)
- if err == nil {
- options[vercel.Name] = parsed
- }
- case google.Name:
- _, hasReasoning := mergedOptions["thinking_config"]
- if !hasReasoning {
- mergedOptions["thinking_config"] = map[string]any{
- "thinking_budget": 2000,
- "include_thoughts": true,
- }
- }
- parsed, err := google.ParseOptions(mergedOptions)
- if err == nil {
- options[google.Name] = parsed
- }
- case openaicompat.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- parsed, err := openaicompat.ParseOptions(mergedOptions)
- if err == nil {
- options[openaicompat.Name] = parsed
- }
- }
- return options
- }
- func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
- modelOptions := getProviderOptions(model, cfg)
- temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
- topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
- topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
- freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
- presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
- return modelOptions, temp, topP, topK, freqPenalty, presPenalty
- }
- func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
- large, small, err := c.buildAgentModels(ctx, isSubAgent)
- if err != nil {
- return nil, err
- }
- largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
- result := NewSessionAgent(SessionAgentOptions{
- large,
- small,
- largeProviderCfg.SystemPromptPrefix,
- "",
- isSubAgent,
- c.cfg.Options.DisableAutoSummarize,
- c.permissions.SkipRequests(),
- c.sessions,
- c.messages,
- nil,
- })
- c.readyWg.Go(func() error {
- systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
- if err != nil {
- return err
- }
- result.SetSystemPrompt(systemPrompt)
- return nil
- })
- c.readyWg.Go(func() error {
- tools, err := c.buildTools(ctx, agent)
- if err != nil {
- return err
- }
- result.SetTools(tools)
- return nil
- })
- return result, nil
- }
- func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
- var allTools []fantasy.AgentTool
- if slices.Contains(agent.AllowedTools, AgentToolName) {
- agentTool, err := c.agentTool(ctx)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agentTool)
- }
- if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
- agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agenticFetchTool)
- }
- // Get the model name for the agent
- modelName := ""
- if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
- if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
- modelName = model.Name
- }
- }
- allTools = append(allTools,
- tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
- tools.NewJobOutputTool(),
- tools.NewJobKillTool(),
- tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewGlobTool(c.cfg.WorkingDir()),
- tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
- tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
- tools.NewSourcegraphTool(nil),
- tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
- tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- )
- // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
- if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
- allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
- }
- if len(c.cfg.MCP) > 0 {
- allTools = append(
- allTools,
- tools.NewListMCPResourcesTool(c.cfg, c.permissions),
- tools.NewReadMCPResourceTool(c.cfg, c.permissions),
- )
- }
- var filteredTools []fantasy.AgentTool
- for _, tool := range allTools {
- if slices.Contains(agent.AllowedTools, tool.Info().Name) {
- filteredTools = append(filteredTools, tool)
- }
- }
- for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
- if agent.AllowedMCP == nil {
- // No MCP restrictions
- filteredTools = append(filteredTools, tool)
- continue
- }
- if len(agent.AllowedMCP) == 0 {
- // No MCPs allowed
- slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
- break
- }
- for mcp, tools := range agent.AllowedMCP {
- if mcp != tool.MCP() {
- continue
- }
- if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
- filteredTools = append(filteredTools, tool)
- }
- }
- slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
- }
- slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
- return strings.Compare(a.Info().Name, b.Info().Name)
- })
- return filteredTools, nil
- }
- // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
- func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
- largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
- if !ok {
- return Model{}, Model{}, errors.New("large model not selected")
- }
- smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
- if !ok {
- return Model{}, Model{}, errors.New("small model not selected")
- }
- largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errors.New("large model provider not configured")
- }
- largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errors.New("large model provider not configured")
- }
- smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, true)
- if err != nil {
- return Model{}, Model{}, err
- }
- var largeCatwalkModel *catwalk.Model
- var smallCatwalkModel *catwalk.Model
- for _, m := range largeProviderCfg.Models {
- if m.ID == largeModelCfg.Model {
- largeCatwalkModel = &m
- }
- }
- for _, m := range smallProviderCfg.Models {
- if m.ID == smallModelCfg.Model {
- smallCatwalkModel = &m
- }
- }
- if largeCatwalkModel == nil {
- return Model{}, Model{}, errors.New("large model not found in provider config")
- }
- if smallCatwalkModel == nil {
- return Model{}, Model{}, errors.New("small model not found in provider config")
- }
- largeModelID := largeModelCfg.Model
- smallModelID := smallModelCfg.Model
- if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
- largeModelID += ":exacto"
- }
- if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
- smallModelID += ":exacto"
- }
- largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- return Model{
- Model: largeModel,
- CatwalkCfg: *largeCatwalkModel,
- ModelCfg: largeModelCfg,
- }, Model{
- Model: smallModel,
- CatwalkCfg: *smallCatwalkModel,
- ModelCfg: smallModelCfg,
- }, nil
- }
- func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- var opts []anthropic.Option
- if strings.HasPrefix(apiKey, "Bearer ") {
- // NOTE: Prevent the SDK from picking up the API key from env.
- os.Setenv("ANTHROPIC_API_KEY", "")
- headers["Authorization"] = apiKey
- } else if apiKey != "" {
- // X-Api-Key header
- opts = append(opts, anthropic.WithAPIKey(apiKey))
- }
- if len(headers) > 0 {
- opts = append(opts, anthropic.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, anthropic.WithBaseURL(baseURL))
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, anthropic.WithHTTPClient(httpClient))
- }
- return anthropic.New(opts...)
- }
- func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openai.Option{
- openai.WithAPIKey(apiKey),
- openai.WithUseResponsesAPI(),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openai.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openai.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, openai.WithBaseURL(baseURL))
- }
- return openai.New(opts...)
- }
- func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openrouter.Option{
- openrouter.WithAPIKey(apiKey),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openrouter.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openrouter.WithHeaders(headers))
- }
- return openrouter.New(opts...)
- }
- func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []vercel.Option{
- vercel.WithAPIKey(apiKey),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, vercel.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, vercel.WithHeaders(headers))
- }
- return vercel.New(opts...)
- }
- func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
- opts := []openaicompat.Option{
- openaicompat.WithBaseURL(baseURL),
- openaicompat.WithAPIKey(apiKey),
- }
- // Set HTTP client based on provider and debug mode.
- var httpClient *http.Client
- if providerID == string(catwalk.InferenceProviderCopilot) {
- opts = append(opts, openaicompat.WithUseResponsesAPI())
- httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
- } else if c.cfg.Options.Debug {
- httpClient = log.NewHTTPClient()
- }
- if httpClient != nil {
- opts = append(opts, openaicompat.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openaicompat.WithHeaders(headers))
- }
- for extraKey, extraValue := range extraBody {
- opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
- }
- return openaicompat.New(opts...)
- }
- func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []azure.Option{
- azure.WithBaseURL(baseURL),
- azure.WithAPIKey(apiKey),
- azure.WithUseResponsesAPI(),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, azure.WithHTTPClient(httpClient))
- }
- if options == nil {
- options = make(map[string]string)
- }
- if apiVersion, ok := options["apiVersion"]; ok {
- opts = append(opts, azure.WithAPIVersion(apiVersion))
- }
- if len(headers) > 0 {
- opts = append(opts, azure.WithHeaders(headers))
- }
- return azure.New(opts...)
- }
- func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
- var opts []bedrock.Option
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, bedrock.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, bedrock.WithHeaders(headers))
- }
- bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
- if bearerToken != "" {
- opts = append(opts, bedrock.WithAPIKey(bearerToken))
- }
- return bedrock.New(opts...)
- }
- func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{
- google.WithBaseURL(baseURL),
- google.WithGeminiAPIKey(apiKey),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- return google.New(opts...)
- }
- func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{}
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- project := options["project"]
- location := options["location"]
- opts = append(opts, google.WithVertex(project, location))
- return google.New(opts...)
- }
- func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
- opts := []hyper.Option{
- hyper.WithBaseURL(baseURL),
- hyper.WithAPIKey(apiKey),
- }
- if c.cfg.Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, hyper.WithHTTPClient(httpClient))
- }
- return hyper.New(opts...)
- }
- func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
- if model.Think {
- return true
- }
- if model.ProviderOptions == nil {
- return false
- }
- opts, err := anthropic.ParseOptions(model.ProviderOptions)
- if err != nil {
- return false
- }
- if opts.Thinking != nil {
- return true
- }
- return false
- }
- func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
- headers := maps.Clone(providerCfg.ExtraHeaders)
- if headers == nil {
- headers = make(map[string]string)
- }
- // handle special headers for anthropic
- if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
- if v, ok := headers["anthropic-beta"]; ok {
- headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
- } else {
- headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
- }
- }
- apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
- baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
- switch providerCfg.Type {
- case openai.Name:
- return c.buildOpenaiProvider(baseURL, apiKey, headers)
- case anthropic.Name:
- return c.buildAnthropicProvider(baseURL, apiKey, headers)
- case openrouter.Name:
- return c.buildOpenrouterProvider(baseURL, apiKey, headers)
- case vercel.Name:
- return c.buildVercelProvider(baseURL, apiKey, headers)
- case azure.Name:
- return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
- case bedrock.Name:
- return c.buildBedrockProvider(headers)
- case google.Name:
- return c.buildGoogleProvider(baseURL, apiKey, headers)
- case "google-vertex":
- return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
- case openaicompat.Name:
- if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
- if providerCfg.ExtraBody == nil {
- providerCfg.ExtraBody = map[string]any{}
- }
- providerCfg.ExtraBody["tool_stream"] = true
- }
- return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
- case hyper.Name:
- return c.buildHyperProvider(baseURL, apiKey)
- default:
- return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
- }
- }
// isExactoSupported reports whether modelID is one of the OpenRouter models
// for which the ":exacto" variant should be requested. Matching is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	switch modelID {
	case "moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder":
		return true
	default:
		return false
	}
}
- func (c *coordinator) Cancel(sessionID string) {
- c.currentAgent.Cancel(sessionID)
- }
- func (c *coordinator) CancelAll() {
- c.currentAgent.CancelAll()
- }
- func (c *coordinator) ClearQueue(sessionID string) {
- c.currentAgent.ClearQueue(sessionID)
- }
- func (c *coordinator) IsBusy() bool {
- return c.currentAgent.IsBusy()
- }
- func (c *coordinator) IsSessionBusy(sessionID string) bool {
- return c.currentAgent.IsSessionBusy(sessionID)
- }
- func (c *coordinator) Model() Model {
- return c.currentAgent.Model()
- }
- func (c *coordinator) UpdateModels(ctx context.Context) error {
- // build the models again so we make sure we get the latest config
- large, small, err := c.buildAgentModels(ctx, false)
- if err != nil {
- return err
- }
- c.currentAgent.SetModels(large, small)
- agentCfg, ok := c.cfg.Agents[config.AgentCoder]
- if !ok {
- return errors.New("coder agent not configured")
- }
- tools, err := c.buildTools(ctx, agentCfg)
- if err != nil {
- return err
- }
- c.currentAgent.SetTools(tools)
- return nil
- }
- func (c *coordinator) QueuedPrompts(sessionID string) int {
- return c.currentAgent.QueuedPrompts(sessionID)
- }
- func (c *coordinator) QueuedPromptsList(sessionID string) []string {
- return c.currentAgent.QueuedPromptsList(sessionID)
- }
- func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
- providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
- if !ok {
- return errors.New("model provider not configured")
- }
- return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
- }
- func (c *coordinator) isUnauthorized(err error) bool {
- var providerErr *fantasy.ProviderError
- return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
- }
- func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
- if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
- slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
- func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
- newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
- if err != nil {
- slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- providerCfg.APIKey = newAPIKey
- c.cfg.Providers.Set(providerCfg.ID, providerCfg)
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
|