| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049 |
- package agent
- import (
- "bytes"
- "cmp"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "maps"
- "net/http"
- "os"
- "slices"
- "strings"
- "charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/agent/hyper"
- "github.com/charmbracelet/crush/internal/agent/notify"
- "github.com/charmbracelet/crush/internal/agent/prompt"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/filetracker"
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/hooks"
- "github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/oauth/copilot"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/session"
- "golang.org/x/sync/errgroup"
- "charm.land/fantasy/providers/anthropic"
- "charm.land/fantasy/providers/azure"
- "charm.land/fantasy/providers/bedrock"
- "charm.land/fantasy/providers/google"
- "charm.land/fantasy/providers/openai"
- "charm.land/fantasy/providers/openaicompat"
- "charm.land/fantasy/providers/openrouter"
- "charm.land/fantasy/providers/vercel"
- openaisdk "github.com/charmbracelet/openai-go/option"
- "github.com/qjebbs/go-jsons"
- )
// Coordinator errors returned while building agents and resolving the
// configured models and providers.
var (
	// The coder agent has no entry in the agent configuration.
	errCoderAgentNotConfigured = errors.New("coder agent not configured")

	// The provider of the model used for a call is missing from the config.
	errModelProviderNotConfigured = errors.New("model provider not configured")

	// No model is selected for the large / small slot.
	errLargeModelNotSelected = errors.New("large model not selected")
	errSmallModelNotSelected = errors.New("small model not selected")

	// The provider of the selected large / small model is not configured.
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")

	// The selected model ID is absent from its provider's model list.
	errLargeModelNotFound = errors.New("large model not found in provider config")
	errSmallModelNotFound = errors.New("small model not found in provider config")
)
- type Coordinator interface {
- // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
- // SetMainAgent(string)
- Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
- Cancel(sessionID string)
- CancelAll()
- IsSessionBusy(sessionID string) bool
- IsBusy() bool
- QueuedPrompts(sessionID string) int
- QueuedPromptsList(sessionID string) []string
- ClearQueue(sessionID string)
- Summarize(context.Context, string) error
- Model() Model
- UpdateModels(ctx context.Context) error
- }
- type coordinator struct {
- cfg *config.ConfigStore
- sessions session.Service
- messages message.Service
- permissions permission.Service
- history history.Service
- filetracker filetracker.Service
- lspManager *lsp.Manager
- notify pubsub.Publisher[notify.Notification]
- currentAgent SessionAgent
- agents map[string]SessionAgent
- readyWg errgroup.Group
- }
- func NewCoordinator(
- ctx context.Context,
- cfg *config.ConfigStore,
- sessions session.Service,
- messages message.Service,
- permissions permission.Service,
- history history.Service,
- filetracker filetracker.Service,
- lspManager *lsp.Manager,
- notify pubsub.Publisher[notify.Notification],
- ) (Coordinator, error) {
- c := &coordinator{
- cfg: cfg,
- sessions: sessions,
- messages: messages,
- permissions: permissions,
- history: history,
- filetracker: filetracker,
- lspManager: lspManager,
- notify: notify,
- agents: make(map[string]SessionAgent),
- }
- agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
- if !ok {
- return nil, errCoderAgentNotConfigured
- }
- // TODO: make this dynamic when we support multiple agents
- prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
- if err != nil {
- return nil, err
- }
- agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
- if err != nil {
- return nil, err
- }
- c.currentAgent = agent
- c.agents[config.AgentCoder] = agent
- return c, nil
- }
- // Run implements Coordinator.
- func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
- if err := c.readyWg.Wait(); err != nil {
- return nil, err
- }
- // refresh models before each run
- if err := c.UpdateModels(ctx); err != nil {
- return nil, fmt.Errorf("failed to update models: %w", err)
- }
- model := c.currentAgent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- if !model.CatwalkCfg.SupportsImages && attachments != nil {
- // filter out image attachments
- filteredAttachments := make([]message.Attachment, 0, len(attachments))
- for _, att := range attachments {
- if att.IsText() {
- filteredAttachments = append(filteredAttachments, att)
- }
- }
- attachments = filteredAttachments
- }
- providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return nil, errModelProviderNotConfigured
- }
- mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, err
- }
- }
- run := func() (*fantasy.AgentResult, error) {
- return c.currentAgent.Run(ctx, SessionAgentCall{
- SessionID: sessionID,
- Prompt: prompt,
- Attachments: attachments,
- MaxOutputTokens: maxTokens,
- ProviderOptions: mergedOptions,
- Temperature: temp,
- TopP: topP,
- TopK: topK,
- FrequencyPenalty: freqPenalty,
- PresencePenalty: presPenalty,
- })
- }
- result, originalErr := run()
- if c.isUnauthorized(originalErr) {
- switch {
- case providerCfg.OAuthToken != nil:
- slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
- return run()
- case strings.Contains(providerCfg.APIKeyTemplate, "$"):
- slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
- if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
- return run()
- }
- }
- return result, originalErr
- }
- func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
- options := fantasy.ProviderOptions{}
- cfgOpts := []byte("{}")
- providerCfgOpts := []byte("{}")
- catwalkOpts := []byte("{}")
- if model.ModelCfg.ProviderOptions != nil {
- data, err := json.Marshal(model.ModelCfg.ProviderOptions)
- if err == nil {
- cfgOpts = data
- }
- }
- if providerCfg.ProviderOptions != nil {
- data, err := json.Marshal(providerCfg.ProviderOptions)
- if err == nil {
- providerCfgOpts = data
- }
- }
- if model.CatwalkCfg.Options.ProviderOptions != nil {
- data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
- if err == nil {
- catwalkOpts = data
- }
- }
- readers := []io.Reader{
- bytes.NewReader(catwalkOpts),
- bytes.NewReader(providerCfgOpts),
- bytes.NewReader(cfgOpts),
- }
- got, err := jsons.Merge(readers)
- if err != nil {
- slog.Error("Could not merge call config", "err", err)
- return options
- }
- mergedOptions := make(map[string]any)
- err = json.Unmarshal([]byte(got), &mergedOptions)
- if err != nil {
- slog.Error("Could not create config for call", "err", err)
- return options
- }
- providerType := providerCfg.Type
- if providerType == "hyper" {
- if strings.Contains(model.CatwalkCfg.ID, "claude") {
- providerType = anthropic.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
- providerType = openai.Name
- } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
- providerType = google.Name
- } else {
- providerType = openaicompat.Name
- }
- }
- switch providerType {
- case openai.Name, azure.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- if openai.IsResponsesModel(model.CatwalkCfg.ID) {
- if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
- mergedOptions["reasoning_summary"] = "auto"
- mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
- }
- parsed, err := openai.ParseResponsesOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- } else {
- parsed, err := openai.ParseOptions(mergedOptions)
- if err == nil {
- options[openai.Name] = parsed
- }
- }
- case anthropic.Name:
- var (
- _, hasEffort = mergedOptions["effort"]
- _, hasThink = mergedOptions["thinking"]
- )
- switch {
- case !hasEffort && model.ModelCfg.ReasoningEffort != "":
- mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
- case !hasThink && model.ModelCfg.Think:
- mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
- }
- parsed, err := anthropic.ParseOptions(mergedOptions)
- if err == nil {
- options[anthropic.Name] = parsed
- }
- case openrouter.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := openrouter.ParseOptions(mergedOptions)
- if err == nil {
- options[openrouter.Name] = parsed
- }
- case vercel.Name:
- _, hasReasoning := mergedOptions["reasoning"]
- if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning"] = map[string]any{
- "enabled": true,
- "effort": model.ModelCfg.ReasoningEffort,
- }
- }
- parsed, err := vercel.ParseOptions(mergedOptions)
- if err == nil {
- options[vercel.Name] = parsed
- }
- case google.Name:
- _, hasReasoning := mergedOptions["thinking_config"]
- if !hasReasoning {
- if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
- mergedOptions["thinking_config"] = map[string]any{
- "thinking_budget": 2000,
- "include_thoughts": true,
- }
- } else {
- mergedOptions["thinking_config"] = map[string]any{
- "thinking_level": model.ModelCfg.ReasoningEffort,
- "include_thoughts": true,
- }
- }
- }
- parsed, err := google.ParseOptions(mergedOptions)
- if err == nil {
- options[google.Name] = parsed
- }
- case openaicompat.Name:
- _, hasReasoningEffort := mergedOptions["reasoning_effort"]
- if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
- mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
- }
- parsed, err := openaicompat.ParseOptions(mergedOptions)
- if err == nil {
- options[openaicompat.Name] = parsed
- }
- }
- return options
- }
- func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
- modelOptions := getProviderOptions(model, cfg)
- temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
- topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
- topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
- freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
- presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
- return modelOptions, temp, topP, topK, freqPenalty, presPenalty
- }
- func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
- large, small, err := c.buildAgentModels(ctx, isSubAgent)
- if err != nil {
- return nil, err
- }
- largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
- result := NewSessionAgent(SessionAgentOptions{
- LargeModel: large,
- SmallModel: small,
- SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
- SystemPrompt: "",
- IsSubAgent: isSubAgent,
- DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
- IsYolo: c.permissions.SkipRequests(),
- Sessions: c.sessions,
- Messages: c.messages,
- Tools: nil,
- Notify: c.notify,
- })
- c.readyWg.Go(func() error {
- systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
- if err != nil {
- return err
- }
- result.SetSystemPrompt(systemPrompt)
- return nil
- })
- c.readyWg.Go(func() error {
- tools, err := c.buildTools(ctx, agent)
- if err != nil {
- return err
- }
- result.SetTools(tools)
- return nil
- })
- return result, nil
- }
- func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
- var allTools []fantasy.AgentTool
- if slices.Contains(agent.AllowedTools, AgentToolName) {
- agentTool, err := c.agentTool(ctx)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agentTool)
- }
- if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
- agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
- if err != nil {
- return nil, err
- }
- allTools = append(allTools, agenticFetchTool)
- }
- // Get the model name for the agent
- modelName := ""
- if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
- if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
- modelName = model.Name
- }
- }
- allTools = append(allTools,
- tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
- tools.NewJobOutputTool(),
- tools.NewJobKillTool(),
- tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewGlobTool(c.cfg.WorkingDir()),
- tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
- tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
- tools.NewSourcegraphTool(nil),
- tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
- tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
- )
- // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
- if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
- allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
- }
- if len(c.cfg.Config().MCP) > 0 {
- allTools = append(
- allTools,
- tools.NewListMCPResourcesTool(c.cfg, c.permissions),
- tools.NewReadMCPResourceTool(c.cfg, c.permissions),
- )
- }
- var filteredTools []fantasy.AgentTool
- for _, tool := range allTools {
- if slices.Contains(agent.AllowedTools, tool.Info().Name) {
- filteredTools = append(filteredTools, tool)
- }
- }
- for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
- if agent.AllowedMCP == nil {
- // No MCP restrictions
- filteredTools = append(filteredTools, tool)
- continue
- }
- if len(agent.AllowedMCP) == 0 {
- // No MCPs allowed
- slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
- break
- }
- for mcp, tools := range agent.AllowedMCP {
- if mcp != tool.MCP() {
- continue
- }
- if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
- filteredTools = append(filteredTools, tool)
- break
- }
- slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
- }
- }
- slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
- return strings.Compare(a.Info().Name, b.Info().Name)
- })
- // Wrap tools with hook interception if PreToolUse hooks are configured.
- if preToolHooks := c.cfg.Config().Hooks[hooks.EventPreToolUse]; len(preToolHooks) > 0 {
- runner := hooks.NewRunner(preToolHooks, c.cfg.WorkingDir(), c.cfg.WorkingDir())
- for i, tool := range filteredTools {
- filteredTools[i] = newHookedTool(tool, runner)
- }
- }
- return filteredTools, nil
- }
- // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
- func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
- largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
- if !ok {
- return Model{}, Model{}, errLargeModelNotSelected
- }
- smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
- if !ok {
- return Model{}, Model{}, errSmallModelNotSelected
- }
- largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errLargeModelProviderNotConfigured
- }
- largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
- if !ok {
- return Model{}, Model{}, errSmallModelProviderNotConfigured
- }
- smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
- if err != nil {
- return Model{}, Model{}, err
- }
- var largeCatwalkModel *catwalk.Model
- var smallCatwalkModel *catwalk.Model
- for _, m := range largeProviderCfg.Models {
- if m.ID == largeModelCfg.Model {
- largeCatwalkModel = &m
- }
- }
- for _, m := range smallProviderCfg.Models {
- if m.ID == smallModelCfg.Model {
- smallCatwalkModel = &m
- }
- }
- if largeCatwalkModel == nil {
- return Model{}, Model{}, errLargeModelNotFound
- }
- if smallCatwalkModel == nil {
- return Model{}, Model{}, errSmallModelNotFound
- }
- largeModelID := largeModelCfg.Model
- smallModelID := smallModelCfg.Model
- if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
- largeModelID += ":exacto"
- }
- if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
- smallModelID += ":exacto"
- }
- largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
- if err != nil {
- return Model{}, Model{}, err
- }
- return Model{
- Model: largeModel,
- CatwalkCfg: *largeCatwalkModel,
- ModelCfg: largeModelCfg,
- }, Model{
- Model: smallModel,
- CatwalkCfg: *smallCatwalkModel,
- ModelCfg: smallModelCfg,
- }, nil
- }
- func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
- var opts []anthropic.Option
- switch {
- case strings.HasPrefix(apiKey, "Bearer "):
- // NOTE: Prevent the SDK from picking up the API key from env.
- os.Setenv("ANTHROPIC_API_KEY", "")
- headers["Authorization"] = apiKey
- case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
- // NOTE: Prevent the SDK from picking up the API key from env.
- os.Setenv("ANTHROPIC_API_KEY", "")
- headers["Authorization"] = "Bearer " + apiKey
- case apiKey != "":
- // X-Api-Key header
- opts = append(opts, anthropic.WithAPIKey(apiKey))
- }
- if len(headers) > 0 {
- opts = append(opts, anthropic.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, anthropic.WithBaseURL(baseURL))
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, anthropic.WithHTTPClient(httpClient))
- }
- return anthropic.New(opts...)
- }
- func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openai.Option{
- openai.WithAPIKey(apiKey),
- openai.WithUseResponsesAPI(),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openai.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openai.WithHeaders(headers))
- }
- if baseURL != "" {
- opts = append(opts, openai.WithBaseURL(baseURL))
- }
- return openai.New(opts...)
- }
- func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []openrouter.Option{
- openrouter.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, openrouter.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openrouter.WithHeaders(headers))
- }
- return openrouter.New(opts...)
- }
- func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []vercel.Option{
- vercel.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, vercel.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, vercel.WithHeaders(headers))
- }
- return vercel.New(opts...)
- }
- func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
- opts := []openaicompat.Option{
- openaicompat.WithBaseURL(baseURL),
- openaicompat.WithAPIKey(apiKey),
- }
- // Set HTTP client based on provider and debug mode.
- var httpClient *http.Client
- if providerID == string(catwalk.InferenceProviderCopilot) {
- opts = append(opts, openaicompat.WithUseResponsesAPI())
- httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
- } else if c.cfg.Config().Options.Debug {
- httpClient = log.NewHTTPClient()
- }
- if httpClient != nil {
- opts = append(opts, openaicompat.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, openaicompat.WithHeaders(headers))
- }
- for extraKey, extraValue := range extraBody {
- opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
- }
- return openaicompat.New(opts...)
- }
- func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []azure.Option{
- azure.WithBaseURL(baseURL),
- azure.WithAPIKey(apiKey),
- azure.WithUseResponsesAPI(),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, azure.WithHTTPClient(httpClient))
- }
- if options == nil {
- options = make(map[string]string)
- }
- if apiVersion, ok := options["apiVersion"]; ok {
- opts = append(opts, azure.WithAPIVersion(apiVersion))
- }
- if len(headers) > 0 {
- opts = append(opts, azure.WithHeaders(headers))
- }
- return azure.New(opts...)
- }
- func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
- var opts []bedrock.Option
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, bedrock.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, bedrock.WithHeaders(headers))
- }
- switch {
- case apiKey != "":
- opts = append(opts, bedrock.WithAPIKey(apiKey))
- case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
- opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
- default:
- // Skip, let the SDK do authentication.
- }
- return bedrock.New(opts...)
- }
- func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{
- google.WithBaseURL(baseURL),
- google.WithGeminiAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- return google.New(opts...)
- }
- func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
- opts := []google.Option{}
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, google.WithHTTPClient(httpClient))
- }
- if len(headers) > 0 {
- opts = append(opts, google.WithHeaders(headers))
- }
- project := options["project"]
- location := options["location"]
- opts = append(opts, google.WithVertex(project, location))
- return google.New(opts...)
- }
- func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
- opts := []hyper.Option{
- hyper.WithAPIKey(apiKey),
- }
- if c.cfg.Config().Options.Debug {
- httpClient := log.NewHTTPClient()
- opts = append(opts, hyper.WithHTTPClient(httpClient))
- }
- return hyper.New(opts...)
- }
- func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
- if model.Think {
- return true
- }
- opts, err := anthropic.ParseOptions(model.ProviderOptions)
- return err == nil && opts.Thinking != nil
- }
- func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
- headers := maps.Clone(providerCfg.ExtraHeaders)
- if headers == nil {
- headers = make(map[string]string)
- }
- // handle special headers for anthropic
- if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
- if v, ok := headers["anthropic-beta"]; ok {
- headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
- } else {
- headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
- }
- }
- apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
- baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
- switch providerCfg.Type {
- case openai.Name:
- return c.buildOpenaiProvider(baseURL, apiKey, headers)
- case anthropic.Name:
- return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
- case openrouter.Name:
- return c.buildOpenrouterProvider(baseURL, apiKey, headers)
- case vercel.Name:
- return c.buildVercelProvider(baseURL, apiKey, headers)
- case azure.Name:
- return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
- case bedrock.Name:
- return c.buildBedrockProvider(apiKey, headers)
- case google.Name:
- return c.buildGoogleProvider(baseURL, apiKey, headers)
- case "google-vertex":
- return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
- case openaicompat.Name:
- if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
- if providerCfg.ExtraBody == nil {
- providerCfg.ExtraBody = map[string]any{}
- }
- providerCfg.ExtraBody["tool_stream"] = true
- }
- return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
- case hyper.Name:
- return c.buildHyperProvider(apiKey)
- default:
- return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
- }
- }
// isExactoSupported reports whether modelID is on the fixed allowlist of
// OpenRouter models that get an ":exacto" suffix. Matching is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	switch modelID {
	case "moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder":
		return true
	default:
		return false
	}
}
- func (c *coordinator) Cancel(sessionID string) {
- c.currentAgent.Cancel(sessionID)
- }
- func (c *coordinator) CancelAll() {
- c.currentAgent.CancelAll()
- }
- func (c *coordinator) ClearQueue(sessionID string) {
- c.currentAgent.ClearQueue(sessionID)
- }
- func (c *coordinator) IsBusy() bool {
- return c.currentAgent.IsBusy()
- }
- func (c *coordinator) IsSessionBusy(sessionID string) bool {
- return c.currentAgent.IsSessionBusy(sessionID)
- }
- func (c *coordinator) Model() Model {
- return c.currentAgent.Model()
- }
- func (c *coordinator) UpdateModels(ctx context.Context) error {
- // build the models again so we make sure we get the latest config
- large, small, err := c.buildAgentModels(ctx, false)
- if err != nil {
- return err
- }
- c.currentAgent.SetModels(large, small)
- agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
- if !ok {
- return errCoderAgentNotConfigured
- }
- tools, err := c.buildTools(ctx, agentCfg)
- if err != nil {
- return err
- }
- c.currentAgent.SetTools(tools)
- return nil
- }
- func (c *coordinator) QueuedPrompts(sessionID string) int {
- return c.currentAgent.QueuedPrompts(sessionID)
- }
- func (c *coordinator) QueuedPromptsList(sessionID string) []string {
- return c.currentAgent.QueuedPromptsList(sessionID)
- }
- func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
- providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
- if !ok {
- return errModelProviderNotConfigured
- }
- return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
- }
- func (c *coordinator) isUnauthorized(err error) bool {
- var providerErr *fantasy.ProviderError
- return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
- }
- func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
- if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
- slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
- func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
- newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
- if err != nil {
- slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
- return err
- }
- providerCfg.APIKey = newAPIKey
- c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
- if err := c.UpdateModels(ctx); err != nil {
- return err
- }
- return nil
- }
- // subAgentParams holds the parameters for running a sub-agent.
- type subAgentParams struct {
- Agent SessionAgent
- SessionID string
- AgentMessageID string
- ToolCallID string
- Prompt string
- SessionTitle string
- // SessionSetup is an optional callback invoked after session creation
- // but before agent execution, for custom session configuration.
- SessionSetup func(sessionID string)
- }
- // runSubAgent runs a sub-agent and handles session management and cost accumulation.
- // It creates a sub-session, runs the agent with the given prompt, and propagates
- // the cost to the parent session.
- func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
- // Create sub-session
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
- }
- // Call session setup function if provided
- if params.SessionSetup != nil {
- params.SessionSetup(session.ID)
- }
- // Get model configuration
- model := params.Agent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return fantasy.ToolResponse{}, errModelProviderNotConfigured
- }
- // Run the agent
- result, err := params.Agent.Run(ctx, SessionAgentCall{
- SessionID: session.ID,
- Prompt: params.Prompt,
- MaxOutputTokens: maxTokens,
- ProviderOptions: getProviderOptions(model, providerCfg),
- Temperature: model.ModelCfg.Temperature,
- TopP: model.ModelCfg.TopP,
- TopK: model.ModelCfg.TopK,
- FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
- PresencePenalty: model.ModelCfg.PresencePenalty,
- NonInteractive: true,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse("error generating response"), nil
- }
- // Update parent session cost
- if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
- return fantasy.ToolResponse{}, err
- }
- return fantasy.NewTextResponse(result.Response.Content.Text()), nil
- }
- // updateParentSessionCost accumulates the cost from a child session to its parent session.
- func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
- childSession, err := c.sessions.Get(ctx, childSessionID)
- if err != nil {
- return fmt.Errorf("get child session: %w", err)
- }
- parentSession, err := c.sessions.Get(ctx, parentSessionID)
- if err != nil {
- return fmt.Errorf("get parent session: %w", err)
- }
- parentSession.Cost += childSession.Cost
- if _, err := c.sessions.Save(ctx, parentSession); err != nil {
- return fmt.Errorf("save parent session: %w", err)
- }
- return nil
- }
|