| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039 |
- package agent
- import (
- "bytes"
- "cmp"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "maps"
- "net/http"
- "os"
- "slices"
- "strings"
- "charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/agent/hyper"
- "github.com/charmbracelet/crush/internal/agent/notify"
- "github.com/charmbracelet/crush/internal/agent/prompt"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/filetracker"
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/oauth/copilot"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/session"
- "golang.org/x/sync/errgroup"
- "charm.land/fantasy/providers/anthropic"
- "charm.land/fantasy/providers/azure"
- "charm.land/fantasy/providers/bedrock"
- "charm.land/fantasy/providers/google"
- "charm.land/fantasy/providers/openai"
- "charm.land/fantasy/providers/openaicompat"
- "charm.land/fantasy/providers/openrouter"
- "charm.land/fantasy/providers/vercel"
- openaisdk "github.com/charmbracelet/openai-go/option"
- "github.com/qjebbs/go-jsons"
- )
// Sentinel errors returned by the coordinator when the agent or model
// configuration is incomplete.
var (
	// errCoderAgentNotConfigured: the config has no entry for the coder agent.
	errCoderAgentNotConfigured = errors.New("coder agent not configured")
	// errModelProviderNotConfigured: the current model's provider is missing
	// from the config.
	errModelProviderNotConfigured = errors.New("model provider not configured")
)

// Errors for the large/small model pair, one per failure kind and size.
var (
	errLargeModelNotSelected           = errors.New("large model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")

	errSmallModelNotSelected           = errors.New("small model not selected")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
- type Coordinator interface {
- // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
- // SetMainAgent(string)
- Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
- Cancel(sessionID string)
- CancelAll()
- IsSessionBusy(sessionID string) bool
- IsBusy() bool
- QueuedPrompts(sessionID string) int
- QueuedPromptsList(sessionID string) []string
- ClearQueue(sessionID string)
- Summarize(context.Context, string) error
- Model() Model
- UpdateModels(ctx context.Context) error
- }
- type coordinator struct {
- cfg *config.ConfigStore
- sessions session.Service
- messages message.Service
- permissions permission.Service
- history history.Service
- filetracker filetracker.Service
- lspManager *lsp.Manager
- notify pubsub.Publisher[notify.Notification]
- currentAgent SessionAgent
- agents map[string]SessionAgent
- readyWg errgroup.Group
- }
- func NewCoordinator(
- ctx context.Context,
- cfg *config.ConfigStore,
- sessions session.Service,
- messages message.Service,
- permissions permission.Service,
- history history.Service,
- filetracker filetracker.Service,
- lspManager *lsp.Manager,
- notify pubsub.Publisher[notify.Notification],
- ) (Coordinator, error) {
- c := &coordinator{
- cfg: cfg,
- sessions: sessions,
- messages: messages,
- permissions: permissions,
- history: history,
- filetracker: filetracker,
- lspManager: lspManager,
- notify: notify,
- agents: make(map[string]SessionAgent),
- }
- agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
- if !ok {
- return nil, errCoderAgentNotConfigured
- }
- // TODO: make this dynamic when we support multiple agents
- prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
- if err != nil {
- return nil, err
- }
- agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
- if err != nil {
- return nil, err
- }
- c.currentAgent = agent
- c.agents[config.AgentCoder] = agent
- return c, nil
- }
- // Run implements Coordinator.
- func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
- if err := c.readyWg.Wait(); err != nil {
- return nil, err
- }
- // refresh models before each run
- if err := c.UpdateModels(ctx); err != nil {
- return nil, fmt.Errorf("failed to update models: %w", err)
- }
- model := c.currentAgent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- if !model.CatwalkCfg.SupportsImages && attachments != nil {
- // filter out image attachments
- filteredAttachments := make([]message.Attachment, 0, len(attachments))
- for _, att := range attachments {
- if att.IsText() {
- filteredAttachments = append(filteredAttachments, att)
- }
- }
- attachments = filteredAttachments
- }
- providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return nil, errModelProviderNotConfigured
- }
- mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, err
- }
- }
- run := func() (*fantasy.AgentResult, error) {
- return c.currentAgent.Run(ctx, SessionAgentCall{
- SessionID: sessionID,
- Prompt: prompt,
- Attachments: attachments,
- MaxOutputTokens: maxTokens,
- ProviderOptions: mergedOptions,
- Temperature: temp,
- TopP: topP,
- TopK: topK,
- FrequencyPenalty: freqPenalty,
- PresencePenalty: presPenalty,
- })
- }
- result, originalErr := run()
- if c.isUnauthorized(originalErr) {
- switch {
- case providerCfg.OAuthToken != nil:
- slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
- return run()
- case strings.Contains(providerCfg.APIKeyTemplate, "$"):
- slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
- if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
- return run()
- }
- }
- return result, originalErr
- }
- func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
- options := fantasy.ProviderOptions{}
- cfgOpts := []byte("{}")
- providerCfgOpts := []byte("{}")
- catwalkOpts := []byte("{}")
- if model.ModelCfg.ProviderOptions != nil {
- data, err := json.Marshal(model.ModelCfg.ProviderOptions)
- if err == nil {
- cfgOpts = data
- }
- }
- if providerCfg.ProviderOptions != nil {
- data, err := json.Marshal(providerCfg.ProviderOptions)
- if err == nil {
- providerCfgOpts = data
- }
- }
- if model.CatwalkCfg.Options.ProviderOptions != nil {
- data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
- if err == nil {
- catwalkOpts = data
- }
- }
- readers := []io.Reader{
- bytes.NewReader(catwalkOpts),
- bytes.NewReader(providerCfgOpts),
- bytes.NewReader(cfgOpts),
- }
- got, err := jsons.Merge(readers)
- if err != nil {
- slog.Error("Could not merge call config", "err", err)
- return options
- }
- mergedOptions := make(map[string]any)
- err = json.Unmarshal([]byte(got), &mergedOptions)
- if err != nil {
- slog.Error("Could not create config for call", "err", err)
- return options
- }
- providerType := providerCfg.Type
- if providerType == "hyper" {
- if strings.Contains(model.CatwalkCfg.ID, "claude") {
- providerType = anthropic.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
- providerType = openai.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
- providerType = google.Name
- } else {
- providerType = openaicompat.Name
- }
- }
- switch providerType {
- case openai.Name, azure.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- if openai.IsResponsesModel(model.CatwalkCfg.ID) {
- if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
- mergedOptions["reasoning_summary"] = "auto"
- mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
- }
- parsed, err := openai.ParseResponsesOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- } else {
- parsed, err := openai.ParseOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- }
- case anthropic.Name:
- var (
- _, hasEffort = mergedOptions["effort"]
- _, hasThink = mergedOptions["thinking"]
- )
- switch {
- case !hasEffort && model.ModelCfg.ReasoningEffort != "":
- mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
- case !hasThink && model.ModelCfg.Think:
- mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
- }
- parsed, err := anthropic.ParseOptions(mergedOptions)
- if err == nil {
- options[anthropic.Name] = parsed
- }
- case openrouter.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := openrouter.ParseOptions(mergedOptions)
- if err == nil {
- options[openrouter.Name] = parsed
- }
- case vercel.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := vercel.ParseOptions(mergedOptions)
- if err == nil {
- options[vercel.Name] = parsed
- }
- case google.Name:
- _, hasReasoning := mergedOptions["thinking_config"]
- if !hasReasoning {
- if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
- mergedOptions["thinking_config"] = map[string]any{
- "thinking_budget": 2000,
- "include_thoughts": true,
- }
- } else {
- mergedOptions["thinking_config"] = map[string]any{
- "thinking_level": model.ModelCfg.ReasoningEffort,
- "include_thoughts": true,
- }
- }
- }
- parsed, err := google.ParseOptions(mergedOptions)
- if err == nil {
- options[google.Name] = parsed
- }
- case openaicompat.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- parsed, err := openaicompat.ParseOptions(mergedOptions)
- if err == nil {
- options[openaicompat.Name] = parsed
- }
- }
- return options
- }
- func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
- modelOptions := getProviderOptions(model, cfg)
- temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
- topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
- topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
- freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
- presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
- return modelOptions, temp, topP, topK, freqPenalty, presPenalty
- }
- func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
- large, small, err := c.buildAgentModels(ctx, isSubAgent)
- if err != nil {
- return nil, err
- }
- largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
- result := NewSessionAgent(SessionAgentOptions{
- LargeModel: large,
- SmallModel: small,
- SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
- SystemPrompt: "",
- IsSubAgent: isSubAgent,
- DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
- IsYolo: c.permissions.SkipRequests(),
- Sessions: c.sessions,
- Messages: c.messages,
- Tools: nil,
- Notify: c.notify,
- })
- c.readyWg.Go(func() error {
- systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
- if err != nil {
- return err
- }
- result.SetSystemPrompt(systemPrompt)
- return nil
- })
- c.readyWg.Go(func() error {
- tools, err := c.buildTools(ctx, agent)
- if err != nil {
- return err
- }
- result.SetTools(tools)
- return nil
- })
- return result, nil
- }
- func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
- var allTools []fantasy.AgentTool
- if slices.Contains(agent.AllowedTools, AgentToolName) {
- agentTool, err := c.agentTool(ctx)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agentTool)
- }
- if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
- agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agenticFetchTool)
- }
- // Get the model name for the agent
- modelName := ""
- if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
- if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
- modelName = model.Name
- }
- }
- allTools = append(allTools,
- tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
- tools.NewJobOutputTool(),
- tools.NewJobKillTool(),
- tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewGlobTool(c.cfg.WorkingDir()),
- tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
- tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
- tools.NewSourcegraphTool(nil),
- tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
- tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- )
- // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
- if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
- allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
- }
- if len(c.cfg.Config().MCP) > 0 {
- allTools = append(
- allTools,
- tools.NewListMCPResourcesTool(c.cfg, c.permissions),
- tools.NewReadMCPResourceTool(c.cfg, c.permissions),
- )
- }
- var filteredTools []fantasy.AgentTool
- for _, tool := range allTools {
- if slices.Contains(agent.AllowedTools, tool.Info().Name) {
- filteredTools = append(filteredTools, tool)
- }
- }
- for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
- if agent.AllowedMCP == nil {
- // No MCP restrictions
- filteredTools = append(filteredTools, tool)
- continue
- }
- if len(agent.AllowedMCP) == 0 {
- // No MCPs allowed
- slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
- break
- }
- for mcp, tools := range agent.AllowedMCP {
- if mcp != tool.MCP() {
- continue
- }
- if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
- filteredTools = append(filteredTools, tool)
- break
- }
- slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
- }
- }
- slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
- return strings.Compare(a.Info().Name, b.Info().Name)
- })
- return filteredTools, nil
- }
- // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
- func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
- largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
- if !ok {
- return Model{}, Model{}, errLargeModelNotSelected
- }
- smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
- if !ok {
- return Model{}, Model{}, errSmallModelNotSelected
- }
- largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errLargeModelProviderNotConfigured
- }
- largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errSmallModelProviderNotConfigured
- }
- smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
- if err != nil {
- return Model{}, Model{}, err
- }
- var largeCatwalkModel *catwalk.Model
- var smallCatwalkModel *catwalk.Model
- for _, m := range largeProviderCfg.Models {
- if m.ID == largeModelCfg.Model {
- largeCatwalkModel = &m
- }
- }
- for _, m := range smallProviderCfg.Models {
- if m.ID == smallModelCfg.Model {
- smallCatwalkModel = &m
- }
- }
- if largeCatwalkModel == nil {
- return Model{}, Model{}, errLargeModelNotFound
- }
- if smallCatwalkModel == nil {
- return Model{}, Model{}, errSmallModelNotFound
- }
- largeModelID := largeModelCfg.Model
- smallModelID := smallModelCfg.Model
- if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
- largeModelID += ":exacto"
- }
- if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
- smallModelID += ":exacto"
- }
- largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- return Model{
- Model: largeModel,
- CatwalkCfg: *largeCatwalkModel,
- ModelCfg: largeModelCfg,
- }, Model{
- Model: smallModel,
- CatwalkCfg: *smallCatwalkModel,
- ModelCfg: smallModelCfg,
- }, nil
- }
- func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
- var opts []anthropic.Option
- switch {
- case strings.HasPrefix(apiKey, "Bearer "):
- // NOTE: Prevent the SDK from picking up the API key from env.
- os.Setenv("ANTHROPIC_API_KEY", "")
- headers["Authorization"] = apiKey
- case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
- // NOTE: Prevent the SDK from picking up the API key from env.
- os.Setenv("ANTHROPIC_API_KEY", "")
- headers["Authorization"] = "Bearer " + apiKey
- case apiKey != "":
- // X-Api-Key header
- opts = append(opts, anthropic.WithAPIKey(apiKey))
- }
- if len(headers) > 0 {
- opts = append(opts, anthropic.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, anthropic.WithBaseURL(baseURL))
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, anthropic.WithHTTPClient(httpClient))
- }
- return anthropic.New(opts...)
- }
- func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openai.Option{
- openai.WithAPIKey(apiKey),
- openai.WithUseResponsesAPI(),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openai.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openai.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, openai.WithBaseURL(baseURL))
- }
- return openai.New(opts...)
- }
- func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openrouter.Option{
- openrouter.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openrouter.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openrouter.WithHeaders(headers))
- }
- return openrouter.New(opts...)
- }
- func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []vercel.Option{
- vercel.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, vercel.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, vercel.WithHeaders(headers))
- }
- return vercel.New(opts...)
- }
- func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
- opts := []openaicompat.Option{
- openaicompat.WithBaseURL(baseURL),
- openaicompat.WithAPIKey(apiKey),
- }
- // Set HTTP client based on provider and debug mode.
- var httpClient *http.Client
- if providerID == string(catwalk.InferenceProviderCopilot) {
- opts = append(opts, openaicompat.WithUseResponsesAPI())
- httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
- } else if c.cfg.Config().Options.Debug {
- httpClient = log.NewHTTPClient()
- }
- if httpClient != nil {
- opts = append(opts, openaicompat.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openaicompat.WithHeaders(headers))
- }
- for extraKey, extraValue := range extraBody {
- opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
- }
- return openaicompat.New(opts...)
- }
- func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []azure.Option{
- azure.WithBaseURL(baseURL),
- azure.WithAPIKey(apiKey),
- azure.WithUseResponsesAPI(),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, azure.WithHTTPClient(httpClient))
- }
- if options == nil {
- options = make(map[string]string)
- }
- if apiVersion, ok := options["apiVersion"]; ok {
- opts = append(opts, azure.WithAPIVersion(apiVersion))
- }
- if len(headers) > 0 {
- opts = append(opts, azure.WithHeaders(headers))
- }
- return azure.New(opts...)
- }
- func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
- var opts []bedrock.Option
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, bedrock.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, bedrock.WithHeaders(headers))
- }
- switch {
- case apiKey != "":
- opts = append(opts, bedrock.WithAPIKey(apiKey))
- case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
- opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
- default:
- // Skip, let the SDK do authentication.
- }
- return bedrock.New(opts...)
- }
- func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{
- google.WithBaseURL(baseURL),
- google.WithGeminiAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- return google.New(opts...)
- }
- func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{}
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- project := options["project"]
- location := options["location"]
- opts = append(opts, google.WithVertex(project, location))
- return google.New(opts...)
- }
- func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
- opts := []hyper.Option{
- hyper.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, hyper.WithHTTPClient(httpClient))
- }
- return hyper.New(opts...)
- }
- func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
- if model.Think {
- return true
- }
- opts, err := anthropic.ParseOptions(model.ProviderOptions)
- return err == nil && opts.Thinking != nil
- }
- func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
- headers := maps.Clone(providerCfg.ExtraHeaders)
- if headers == nil {
- headers = make(map[string]string)
- }
- // handle special headers for anthropic
- if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
- if v, ok := headers["anthropic-beta"]; ok {
- headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
- } else {
- headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
- }
- }
- apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
- baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
- switch providerCfg.Type {
- case openai.Name:
- return c.buildOpenaiProvider(baseURL, apiKey, headers)
- case anthropic.Name:
- return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
- case openrouter.Name:
- return c.buildOpenrouterProvider(baseURL, apiKey, headers)
- case vercel.Name:
- return c.buildVercelProvider(baseURL, apiKey, headers)
- case azure.Name:
- return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
- case bedrock.Name:
- return c.buildBedrockProvider(apiKey, headers)
- case google.Name:
- return c.buildGoogleProvider(baseURL, apiKey, headers)
- case "google-vertex":
- return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
- case openaicompat.Name:
- if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
- if providerCfg.ExtraBody == nil {
- providerCfg.ExtraBody = map[string]any{}
- }
- providerCfg.ExtraBody["tool_stream"] = true
- }
- return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
- case hyper.Name:
- return c.buildHyperProvider(apiKey)
- default:
- return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
- }
- }
// isExactoSupported reports whether an OpenRouter model ID should get the
// ":exacto" suffix. Only an exact match against the list below counts.
func isExactoSupported(modelID string) bool {
	return slices.Contains([]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}, modelID)
}
- func (c *coordinator) Cancel(sessionID string) {
- c.currentAgent.Cancel(sessionID)
- }
- func (c *coordinator) CancelAll() {
- c.currentAgent.CancelAll()
- }
- func (c *coordinator) ClearQueue(sessionID string) {
- c.currentAgent.ClearQueue(sessionID)
- }
- func (c *coordinator) IsBusy() bool {
- return c.currentAgent.IsBusy()
- }
- func (c *coordinator) IsSessionBusy(sessionID string) bool {
- return c.currentAgent.IsSessionBusy(sessionID)
- }
- func (c *coordinator) Model() Model {
- return c.currentAgent.Model()
- }
- func (c *coordinator) UpdateModels(ctx context.Context) error {
- // build the models again so we make sure we get the latest config
- large, small, err := c.buildAgentModels(ctx, false)
- if err != nil {
- return err
- }
- c.currentAgent.SetModels(large, small)
- agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
- if !ok {
- return errCoderAgentNotConfigured
- }
- tools, err := c.buildTools(ctx, agentCfg)
- if err != nil {
- return err
- }
- c.currentAgent.SetTools(tools)
- return nil
- }
- func (c *coordinator) QueuedPrompts(sessionID string) int {
- return c.currentAgent.QueuedPrompts(sessionID)
- }
- func (c *coordinator) QueuedPromptsList(sessionID string) []string {
- return c.currentAgent.QueuedPromptsList(sessionID)
- }
- func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
- providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
- if !ok {
- return errModelProviderNotConfigured
- }
- return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
- }
- func (c *coordinator) isUnauthorized(err error) bool {
- var providerErr *fantasy.ProviderError
- return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
- }
- func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
- if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
- slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
- func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
- newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
- if err != nil {
- slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- providerCfg.APIKey = newAPIKey
- c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
- // subAgentParams holds the parameters for running a sub-agent.
- type subAgentParams struct {
- Agent SessionAgent
- SessionID string
- AgentMessageID string
- ToolCallID string
- Prompt string
- SessionTitle string
- // SessionSetup is an optional callback invoked after session creation
- // but before agent execution, for custom session configuration.
- SessionSetup func(sessionID string)
- }
- // runSubAgent runs a sub-agent and handles session management and cost accumulation.
- // It creates a sub-session, runs the agent with the given prompt, and propagates
- // the cost to the parent session.
- func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
- // Create sub-session
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
- }
- // Call session setup function if provided
- if params.SessionSetup != nil {
- params.SessionSetup(session.ID)
- }
- // Get model configuration
- model := params.Agent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return fantasy.ToolResponse{}, errModelProviderNotConfigured
- }
- // Run the agent
- result, err := params.Agent.Run(ctx, SessionAgentCall{
- SessionID: session.ID,
- Prompt: params.Prompt,
- MaxOutputTokens: maxTokens,
- ProviderOptions: getProviderOptions(model, providerCfg),
- Temperature: model.ModelCfg.Temperature,
- TopP: model.ModelCfg.TopP,
- TopK: model.ModelCfg.TopK,
- FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
- PresencePenalty: model.ModelCfg.PresencePenalty,
- NonInteractive: true,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse("error generating response"), nil
- }
- // Update parent session cost
- if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
- return fantasy.ToolResponse{}, err
- }
- return fantasy.NewTextResponse(result.Response.Content.Text()), nil
- }
- // updateParentSessionCost accumulates the cost from a child session to its parent session.
- func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
- childSession, err := c.sessions.Get(ctx, childSessionID)
- if err != nil {
- return fmt.Errorf("get child session: %w", err)
- }
- parentSession, err := c.sessions.Get(ctx, parentSessionID)
- if err != nil {
- return fmt.Errorf("get parent session: %w", err)
- }
- parentSession.Cost += childSession.Cost
- if _, err := c.sessions.Save(ctx, parentSession); err != nil {
- return fmt.Errorf("save parent session: %w", err)
- }
- return nil
- }
|