coordinator.go 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "path/filepath"
  15. "slices"
  16. "strings"
  17. "charm.land/catwalk/pkg/catwalk"
  18. "charm.land/fantasy"
  19. "github.com/charmbracelet/crush/internal/agent/hyper"
  20. "github.com/charmbracelet/crush/internal/agent/notify"
  21. "github.com/charmbracelet/crush/internal/agent/prompt"
  22. "github.com/charmbracelet/crush/internal/agent/tools"
  23. "github.com/charmbracelet/crush/internal/config"
  24. "github.com/charmbracelet/crush/internal/filetracker"
  25. "github.com/charmbracelet/crush/internal/history"
  26. "github.com/charmbracelet/crush/internal/home"
  27. "github.com/charmbracelet/crush/internal/log"
  28. "github.com/charmbracelet/crush/internal/lsp"
  29. "github.com/charmbracelet/crush/internal/message"
  30. "github.com/charmbracelet/crush/internal/oauth/copilot"
  31. "github.com/charmbracelet/crush/internal/permission"
  32. "github.com/charmbracelet/crush/internal/pubsub"
  33. "github.com/charmbracelet/crush/internal/session"
  34. "github.com/charmbracelet/crush/internal/skills"
  35. "golang.org/x/sync/errgroup"
  36. "charm.land/fantasy/providers/anthropic"
  37. "charm.land/fantasy/providers/azure"
  38. "charm.land/fantasy/providers/bedrock"
  39. "charm.land/fantasy/providers/google"
  40. "charm.land/fantasy/providers/openai"
  41. "charm.land/fantasy/providers/openaicompat"
  42. "charm.land/fantasy/providers/openrouter"
  43. "charm.land/fantasy/providers/vercel"
  44. openaisdk "github.com/charmbracelet/openai-go/option"
  45. "github.com/qjebbs/go-jsons"
  46. )
  47. // Coordinator errors.
  48. var (
  49. errCoderAgentNotConfigured = errors.New("coder agent not configured")
  50. errModelProviderNotConfigured = errors.New("model provider not configured")
  51. errLargeModelNotSelected = errors.New("large model not selected")
  52. errSmallModelNotSelected = errors.New("small model not selected")
  53. errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  54. errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  55. errLargeModelNotFound = errors.New("large model not found in provider config")
  56. errSmallModelNotFound = errors.New("small model not found in provider config")
  57. )
  58. type Coordinator interface {
  59. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  60. // SetMainAgent(string)
  61. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  62. Cancel(sessionID string)
  63. CancelAll()
  64. IsSessionBusy(sessionID string) bool
  65. IsBusy() bool
  66. QueuedPrompts(sessionID string) int
  67. QueuedPromptsList(sessionID string) []string
  68. ClearQueue(sessionID string)
  69. Summarize(context.Context, string) error
  70. Model() Model
  71. UpdateModels(ctx context.Context) error
  72. }
  73. type coordinator struct {
  74. cfg *config.ConfigStore
  75. sessions session.Service
  76. messages message.Service
  77. permissions permission.Service
  78. history history.Service
  79. filetracker filetracker.Service
  80. lspManager *lsp.Manager
  81. notify pubsub.Publisher[notify.Notification]
  82. currentAgent SessionAgent
  83. agents map[string]SessionAgent
  84. // Skills discovery results (session-start snapshot).
  85. allSkills []*skills.Skill // Pre-filter: all discovered after dedup.
  86. activeSkills []*skills.Skill // Post-filter: active skills only.
  87. skillTracker *skills.Tracker
  88. readyWg errgroup.Group
  89. }
  90. func NewCoordinator(
  91. ctx context.Context,
  92. cfg *config.ConfigStore,
  93. sessions session.Service,
  94. messages message.Service,
  95. permissions permission.Service,
  96. history history.Service,
  97. filetracker filetracker.Service,
  98. lspManager *lsp.Manager,
  99. notify pubsub.Publisher[notify.Notification],
  100. ) (Coordinator, error) {
  101. // Discover skills once at session start.
  102. allSkills, activeSkills := discoverSkills(cfg)
  103. skillTracker := skills.NewTracker(activeSkills)
  104. c := &coordinator{
  105. cfg: cfg,
  106. sessions: sessions,
  107. messages: messages,
  108. permissions: permissions,
  109. history: history,
  110. filetracker: filetracker,
  111. lspManager: lspManager,
  112. notify: notify,
  113. agents: make(map[string]SessionAgent),
  114. allSkills: allSkills,
  115. activeSkills: activeSkills,
  116. skillTracker: skillTracker,
  117. }
  118. agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
  119. if !ok {
  120. return nil, errCoderAgentNotConfigured
  121. }
  122. // TODO: make this dynamic when we support multiple agents
  123. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  124. if err != nil {
  125. return nil, err
  126. }
  127. agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
  128. if err != nil {
  129. return nil, err
  130. }
  131. c.currentAgent = agent
  132. c.agents[config.AgentCoder] = agent
  133. return c, nil
  134. }
  135. // Run implements Coordinator.
  136. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  137. if err := c.readyWg.Wait(); err != nil {
  138. return nil, err
  139. }
  140. // refresh models before each run
  141. if err := c.UpdateModels(ctx); err != nil {
  142. return nil, fmt.Errorf("failed to update models: %w", err)
  143. }
  144. model := c.currentAgent.Model()
  145. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  146. if model.ModelCfg.MaxTokens != 0 {
  147. maxTokens = model.ModelCfg.MaxTokens
  148. }
  149. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  150. // filter out image attachments
  151. filteredAttachments := make([]message.Attachment, 0, len(attachments))
  152. for _, att := range attachments {
  153. if att.IsText() {
  154. filteredAttachments = append(filteredAttachments, att)
  155. }
  156. }
  157. attachments = filteredAttachments
  158. }
  159. providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
  160. if !ok {
  161. return nil, errModelProviderNotConfigured
  162. }
  163. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  164. if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
  165. slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
  166. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  167. return nil, err
  168. }
  169. }
  170. run := func() (*fantasy.AgentResult, error) {
  171. return c.currentAgent.Run(ctx, SessionAgentCall{
  172. SessionID: sessionID,
  173. Prompt: prompt,
  174. Attachments: attachments,
  175. MaxOutputTokens: maxTokens,
  176. ProviderOptions: mergedOptions,
  177. Temperature: temp,
  178. TopP: topP,
  179. TopK: topK,
  180. FrequencyPenalty: freqPenalty,
  181. PresencePenalty: presPenalty,
  182. })
  183. }
  184. result, originalErr := run()
  185. if c.isUnauthorized(originalErr) {
  186. switch {
  187. case providerCfg.OAuthToken != nil:
  188. slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  189. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  190. return nil, originalErr
  191. }
  192. slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  193. return run()
  194. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  195. slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  196. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  197. return nil, originalErr
  198. }
  199. slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
  200. return run()
  201. }
  202. }
  203. return result, originalErr
  204. }
  205. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  206. options := fantasy.ProviderOptions{}
  207. cfgOpts := []byte("{}")
  208. providerCfgOpts := []byte("{}")
  209. catwalkOpts := []byte("{}")
  210. if model.ModelCfg.ProviderOptions != nil {
  211. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  212. if err == nil {
  213. cfgOpts = data
  214. }
  215. }
  216. if providerCfg.ProviderOptions != nil {
  217. data, err := json.Marshal(providerCfg.ProviderOptions)
  218. if err == nil {
  219. providerCfgOpts = data
  220. }
  221. }
  222. if model.CatwalkCfg.Options.ProviderOptions != nil {
  223. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  224. if err == nil {
  225. catwalkOpts = data
  226. }
  227. }
  228. readers := []io.Reader{
  229. bytes.NewReader(catwalkOpts),
  230. bytes.NewReader(providerCfgOpts),
  231. bytes.NewReader(cfgOpts),
  232. }
  233. got, err := jsons.Merge(readers)
  234. if err != nil {
  235. slog.Error("Could not merge call config", "err", err)
  236. return options
  237. }
  238. mergedOptions := make(map[string]any)
  239. err = json.Unmarshal([]byte(got), &mergedOptions)
  240. if err != nil {
  241. slog.Error("Could not create config for call", "err", err)
  242. return options
  243. }
  244. providerType := providerCfg.Type
  245. if providerType == "hyper" {
  246. if strings.Contains(model.CatwalkCfg.ID, "claude") {
  247. providerType = anthropic.Name
  248. } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
  249. providerType = openai.Name
  250. } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
  251. providerType = google.Name
  252. } else {
  253. providerType = openaicompat.Name
  254. }
  255. }
  256. switch providerType {
  257. case openai.Name, azure.Name:
  258. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  259. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  260. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  261. }
  262. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  263. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  264. mergedOptions["reasoning_summary"] = "auto"
  265. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  266. }
  267. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  268. if err == nil {
  269. options[openai.Name] = parsed
  270. }
  271. } else {
  272. parsed, err := openai.ParseOptions(mergedOptions)
  273. if err == nil {
  274. options[openai.Name] = parsed
  275. }
  276. }
  277. case anthropic.Name:
  278. var (
  279. _, hasEffort = mergedOptions["effort"]
  280. _, hasThink = mergedOptions["thinking"]
  281. )
  282. switch {
  283. case !hasEffort && model.ModelCfg.ReasoningEffort != "":
  284. mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
  285. case !hasThink && model.ModelCfg.Think:
  286. mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
  287. }
  288. parsed, err := anthropic.ParseOptions(mergedOptions)
  289. if err == nil {
  290. options[anthropic.Name] = parsed
  291. }
  292. case openrouter.Name:
  293. _, hasReasoning := mergedOptions["reasoning"]
  294. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  295. mergedOptions["reasoning"] = map[string]any{
  296. "enabled": true,
  297. "effort": model.ModelCfg.ReasoningEffort,
  298. }
  299. }
  300. parsed, err := openrouter.ParseOptions(mergedOptions)
  301. if err == nil {
  302. options[openrouter.Name] = parsed
  303. }
  304. case vercel.Name:
  305. _, hasReasoning := mergedOptions["reasoning"]
  306. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  307. mergedOptions["reasoning"] = map[string]any{
  308. "enabled": true,
  309. "effort": model.ModelCfg.ReasoningEffort,
  310. }
  311. }
  312. parsed, err := vercel.ParseOptions(mergedOptions)
  313. if err == nil {
  314. options[vercel.Name] = parsed
  315. }
  316. case google.Name:
  317. _, hasReasoning := mergedOptions["thinking_config"]
  318. if !hasReasoning {
  319. if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
  320. mergedOptions["thinking_config"] = map[string]any{
  321. "thinking_budget": 2000,
  322. "include_thoughts": true,
  323. }
  324. } else {
  325. mergedOptions["thinking_config"] = map[string]any{
  326. "thinking_level": model.ModelCfg.ReasoningEffort,
  327. "include_thoughts": true,
  328. }
  329. }
  330. }
  331. parsed, err := google.ParseOptions(mergedOptions)
  332. if err == nil {
  333. options[google.Name] = parsed
  334. }
  335. case openaicompat.Name:
  336. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  337. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  338. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  339. }
  340. parsed, err := openaicompat.ParseOptions(mergedOptions)
  341. if err == nil {
  342. options[openaicompat.Name] = parsed
  343. }
  344. }
  345. return options
  346. }
  347. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  348. modelOptions := getProviderOptions(model, cfg)
  349. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  350. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  351. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  352. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  353. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  354. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  355. }
  356. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
  357. large, small, err := c.buildAgentModels(ctx, isSubAgent)
  358. if err != nil {
  359. return nil, err
  360. }
  361. largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
  362. result := NewSessionAgent(SessionAgentOptions{
  363. LargeModel: large,
  364. SmallModel: small,
  365. SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
  366. SystemPrompt: "",
  367. IsSubAgent: isSubAgent,
  368. DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
  369. IsYolo: c.permissions.SkipRequests(),
  370. Sessions: c.sessions,
  371. Messages: c.messages,
  372. Tools: nil,
  373. Notify: c.notify,
  374. })
  375. c.readyWg.Go(func() error {
  376. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
  377. if err != nil {
  378. return err
  379. }
  380. result.SetSystemPrompt(systemPrompt)
  381. return nil
  382. })
  383. c.readyWg.Go(func() error {
  384. tools, err := c.buildTools(ctx, agent)
  385. if err != nil {
  386. return err
  387. }
  388. result.SetTools(tools)
  389. return nil
  390. })
  391. return result, nil
  392. }
  393. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  394. var allTools []fantasy.AgentTool
  395. if slices.Contains(agent.AllowedTools, AgentToolName) {
  396. agentTool, err := c.agentTool(ctx)
  397. if err != nil {
  398. return nil, err
  399. }
  400. allTools = append(allTools, agentTool)
  401. }
  402. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  403. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  404. if err != nil {
  405. return nil, err
  406. }
  407. allTools = append(allTools, agenticFetchTool)
  408. }
  409. // Get the model name for the agent
  410. modelName := ""
  411. if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
  412. if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  413. modelName = model.Name
  414. }
  415. }
  416. logFile := filepath.Join(c.cfg.Config().Options.DataDirectory, "logs", "crush.log")
  417. allTools = append(allTools,
  418. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
  419. tools.NewCrushInfoTool(c.cfg, c.lspManager, c.allSkills, c.activeSkills, c.skillTracker),
  420. tools.NewCrushLogsTool(logFile),
  421. tools.NewJobOutputTool(),
  422. tools.NewJobKillTool(),
  423. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  424. tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  425. tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  426. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  427. tools.NewGlobTool(c.cfg.WorkingDir()),
  428. tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
  429. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
  430. tools.NewSourcegraphTool(nil),
  431. tools.NewTodosTool(c.sessions, c.cfg),
  432. tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
  433. tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  434. )
  435. // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
  436. if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
  437. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
  438. }
  439. if len(c.cfg.Config().MCP) > 0 {
  440. allTools = append(
  441. allTools,
  442. tools.NewListMCPResourcesTool(c.cfg, c.permissions),
  443. tools.NewReadMCPResourceTool(c.cfg, c.permissions),
  444. )
  445. }
  446. var filteredTools []fantasy.AgentTool
  447. for _, tool := range allTools {
  448. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  449. filteredTools = append(filteredTools, tool)
  450. }
  451. }
  452. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
  453. if agent.AllowedMCP == nil {
  454. // No MCP restrictions
  455. filteredTools = append(filteredTools, tool)
  456. continue
  457. }
  458. if len(agent.AllowedMCP) == 0 {
  459. // No MCPs allowed
  460. slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  461. break
  462. }
  463. for mcp, tools := range agent.AllowedMCP {
  464. if mcp != tool.MCP() {
  465. continue
  466. }
  467. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  468. filteredTools = append(filteredTools, tool)
  469. break
  470. }
  471. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  472. }
  473. }
  474. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  475. return strings.Compare(a.Info().Name, b.Info().Name)
  476. })
  477. return filteredTools, nil
  478. }
  479. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  480. func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
  481. largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
  482. if !ok {
  483. return Model{}, Model{}, errLargeModelNotSelected
  484. }
  485. smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
  486. if !ok {
  487. return Model{}, Model{}, errSmallModelNotSelected
  488. }
  489. largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
  490. if !ok {
  491. return Model{}, Model{}, errLargeModelProviderNotConfigured
  492. }
  493. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
  494. if err != nil {
  495. return Model{}, Model{}, err
  496. }
  497. smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
  498. if !ok {
  499. return Model{}, Model{}, errSmallModelProviderNotConfigured
  500. }
  501. smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
  502. if err != nil {
  503. return Model{}, Model{}, err
  504. }
  505. var largeCatwalkModel *catwalk.Model
  506. var smallCatwalkModel *catwalk.Model
  507. for _, m := range largeProviderCfg.Models {
  508. if m.ID == largeModelCfg.Model {
  509. largeCatwalkModel = &m
  510. }
  511. }
  512. for _, m := range smallProviderCfg.Models {
  513. if m.ID == smallModelCfg.Model {
  514. smallCatwalkModel = &m
  515. }
  516. }
  517. if largeCatwalkModel == nil {
  518. return Model{}, Model{}, errLargeModelNotFound
  519. }
  520. if smallCatwalkModel == nil {
  521. return Model{}, Model{}, errSmallModelNotFound
  522. }
  523. largeModelID := largeModelCfg.Model
  524. smallModelID := smallModelCfg.Model
  525. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  526. largeModelID += ":exacto"
  527. }
  528. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  529. smallModelID += ":exacto"
  530. }
  531. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  532. if err != nil {
  533. return Model{}, Model{}, err
  534. }
  535. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  536. if err != nil {
  537. return Model{}, Model{}, err
  538. }
  539. return Model{
  540. Model: largeModel,
  541. CatwalkCfg: *largeCatwalkModel,
  542. ModelCfg: largeModelCfg,
  543. }, Model{
  544. Model: smallModel,
  545. CatwalkCfg: *smallCatwalkModel,
  546. ModelCfg: smallModelCfg,
  547. }, nil
  548. }
  549. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
  550. var opts []anthropic.Option
  551. switch {
  552. case strings.HasPrefix(apiKey, "Bearer "):
  553. // NOTE: Prevent the SDK from picking up the API key from env.
  554. os.Setenv("ANTHROPIC_API_KEY", "")
  555. headers["Authorization"] = apiKey
  556. case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
  557. // NOTE: Prevent the SDK from picking up the API key from env.
  558. os.Setenv("ANTHROPIC_API_KEY", "")
  559. headers["Authorization"] = "Bearer " + apiKey
  560. case apiKey != "":
  561. // X-Api-Key header
  562. opts = append(opts, anthropic.WithAPIKey(apiKey))
  563. }
  564. if len(headers) > 0 {
  565. opts = append(opts, anthropic.WithHeaders(headers))
  566. }
  567. if baseURL != "" {
  568. opts = append(opts, anthropic.WithBaseURL(baseURL))
  569. }
  570. if c.cfg.Config().Options.Debug {
  571. httpClient := log.NewHTTPClient()
  572. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  573. }
  574. return anthropic.New(opts...)
  575. }
  576. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  577. opts := []openai.Option{
  578. openai.WithAPIKey(apiKey),
  579. openai.WithUseResponsesAPI(),
  580. }
  581. if c.cfg.Config().Options.Debug {
  582. httpClient := log.NewHTTPClient()
  583. opts = append(opts, openai.WithHTTPClient(httpClient))
  584. }
  585. if len(headers) > 0 {
  586. opts = append(opts, openai.WithHeaders(headers))
  587. }
  588. if baseURL != "" {
  589. opts = append(opts, openai.WithBaseURL(baseURL))
  590. }
  591. return openai.New(opts...)
  592. }
  593. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  594. opts := []openrouter.Option{
  595. openrouter.WithAPIKey(apiKey),
  596. }
  597. if c.cfg.Config().Options.Debug {
  598. httpClient := log.NewHTTPClient()
  599. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  600. }
  601. if len(headers) > 0 {
  602. opts = append(opts, openrouter.WithHeaders(headers))
  603. }
  604. return openrouter.New(opts...)
  605. }
  606. func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  607. opts := []vercel.Option{
  608. vercel.WithAPIKey(apiKey),
  609. }
  610. if c.cfg.Config().Options.Debug {
  611. httpClient := log.NewHTTPClient()
  612. opts = append(opts, vercel.WithHTTPClient(httpClient))
  613. }
  614. if len(headers) > 0 {
  615. opts = append(opts, vercel.WithHeaders(headers))
  616. }
  617. return vercel.New(opts...)
  618. }
  619. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
  620. opts := []openaicompat.Option{
  621. openaicompat.WithBaseURL(baseURL),
  622. openaicompat.WithAPIKey(apiKey),
  623. }
  624. // Set HTTP client based on provider and debug mode.
  625. var httpClient *http.Client
  626. if providerID == string(catwalk.InferenceProviderCopilot) {
  627. opts = append(opts, openaicompat.WithUseResponsesAPI())
  628. httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
  629. } else if c.cfg.Config().Options.Debug {
  630. httpClient = log.NewHTTPClient()
  631. }
  632. if httpClient != nil {
  633. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  634. }
  635. if len(headers) > 0 {
  636. opts = append(opts, openaicompat.WithHeaders(headers))
  637. }
  638. for extraKey, extraValue := range extraBody {
  639. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  640. }
  641. return openaicompat.New(opts...)
  642. }
  643. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  644. opts := []azure.Option{
  645. azure.WithBaseURL(baseURL),
  646. azure.WithAPIKey(apiKey),
  647. azure.WithUseResponsesAPI(),
  648. }
  649. if c.cfg.Config().Options.Debug {
  650. httpClient := log.NewHTTPClient()
  651. opts = append(opts, azure.WithHTTPClient(httpClient))
  652. }
  653. if options == nil {
  654. options = make(map[string]string)
  655. }
  656. if apiVersion, ok := options["apiVersion"]; ok {
  657. opts = append(opts, azure.WithAPIVersion(apiVersion))
  658. }
  659. if len(headers) > 0 {
  660. opts = append(opts, azure.WithHeaders(headers))
  661. }
  662. return azure.New(opts...)
  663. }
  664. func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
  665. var opts []bedrock.Option
  666. if c.cfg.Config().Options.Debug {
  667. httpClient := log.NewHTTPClient()
  668. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  669. }
  670. if len(headers) > 0 {
  671. opts = append(opts, bedrock.WithHeaders(headers))
  672. }
  673. switch {
  674. case apiKey != "":
  675. opts = append(opts, bedrock.WithAPIKey(apiKey))
  676. case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
  677. opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
  678. default:
  679. // Skip, let the SDK do authentication.
  680. }
  681. return bedrock.New(opts...)
  682. }
  683. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  684. opts := []google.Option{
  685. google.WithBaseURL(baseURL),
  686. google.WithGeminiAPIKey(apiKey),
  687. }
  688. if c.cfg.Config().Options.Debug {
  689. httpClient := log.NewHTTPClient()
  690. opts = append(opts, google.WithHTTPClient(httpClient))
  691. }
  692. if len(headers) > 0 {
  693. opts = append(opts, google.WithHeaders(headers))
  694. }
  695. return google.New(opts...)
  696. }
  697. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  698. opts := []google.Option{}
  699. if c.cfg.Config().Options.Debug {
  700. httpClient := log.NewHTTPClient()
  701. opts = append(opts, google.WithHTTPClient(httpClient))
  702. }
  703. if len(headers) > 0 {
  704. opts = append(opts, google.WithHeaders(headers))
  705. }
  706. project := options["project"]
  707. location := options["location"]
  708. opts = append(opts, google.WithVertex(project, location))
  709. return google.New(opts...)
  710. }
  711. func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
  712. opts := []hyper.Option{
  713. hyper.WithAPIKey(apiKey),
  714. }
  715. if c.cfg.Config().Options.Debug {
  716. httpClient := log.NewHTTPClient()
  717. opts = append(opts, hyper.WithHTTPClient(httpClient))
  718. }
  719. return hyper.New(opts...)
  720. }
  721. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  722. if model.Think {
  723. return true
  724. }
  725. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  726. return err == nil && opts.Thinking != nil
  727. }
  728. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
  729. headers := maps.Clone(providerCfg.ExtraHeaders)
  730. if headers == nil {
  731. headers = make(map[string]string)
  732. }
  733. // handle special headers for anthropic
  734. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  735. if v, ok := headers["anthropic-beta"]; ok {
  736. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  737. } else {
  738. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  739. }
  740. }
  741. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  742. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  743. switch providerCfg.Type {
  744. case openai.Name:
  745. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  746. case anthropic.Name:
  747. return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
  748. case openrouter.Name:
  749. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  750. case vercel.Name:
  751. return c.buildVercelProvider(baseURL, apiKey, headers)
  752. case azure.Name:
  753. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  754. case bedrock.Name:
  755. return c.buildBedrockProvider(apiKey, headers)
  756. case google.Name:
  757. return c.buildGoogleProvider(baseURL, apiKey, headers)
  758. case "google-vertex":
  759. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  760. case openaicompat.Name:
  761. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  762. if providerCfg.ExtraBody == nil {
  763. providerCfg.ExtraBody = map[string]any{}
  764. }
  765. providerCfg.ExtraBody["tool_stream"] = true
  766. }
  767. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
  768. case hyper.Name:
  769. return c.buildHyperProvider(apiKey)
  770. default:
  771. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  772. }
  773. }
  774. func isExactoSupported(modelID string) bool {
  775. supportedModels := []string{
  776. "moonshotai/kimi-k2-0905",
  777. "deepseek/deepseek-v3.1-terminus",
  778. "z-ai/glm-4.6",
  779. "openai/gpt-oss-120b",
  780. "qwen/qwen3-coder",
  781. }
  782. return slices.Contains(supportedModels, modelID)
  783. }
  784. func (c *coordinator) Cancel(sessionID string) {
  785. c.currentAgent.Cancel(sessionID)
  786. }
  787. func (c *coordinator) CancelAll() {
  788. c.currentAgent.CancelAll()
  789. }
  790. func (c *coordinator) ClearQueue(sessionID string) {
  791. c.currentAgent.ClearQueue(sessionID)
  792. }
  793. func (c *coordinator) IsBusy() bool {
  794. return c.currentAgent.IsBusy()
  795. }
  796. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  797. return c.currentAgent.IsSessionBusy(sessionID)
  798. }
  799. func (c *coordinator) Model() Model {
  800. return c.currentAgent.Model()
  801. }
  802. func (c *coordinator) UpdateModels(ctx context.Context) error {
  803. // build the models again so we make sure we get the latest config
  804. large, small, err := c.buildAgentModels(ctx, false)
  805. if err != nil {
  806. return err
  807. }
  808. c.currentAgent.SetModels(large, small)
  809. agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
  810. if !ok {
  811. return errCoderAgentNotConfigured
  812. }
  813. tools, err := c.buildTools(ctx, agentCfg)
  814. if err != nil {
  815. return err
  816. }
  817. c.currentAgent.SetTools(tools)
  818. return nil
  819. }
  820. func (c *coordinator) QueuedPrompts(sessionID string) int {
  821. return c.currentAgent.QueuedPrompts(sessionID)
  822. }
  823. func (c *coordinator) QueuedPromptsList(sessionID string) []string {
  824. return c.currentAgent.QueuedPromptsList(sessionID)
  825. }
  826. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  827. providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  828. if !ok {
  829. return errModelProviderNotConfigured
  830. }
  831. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  832. }
  833. func (c *coordinator) isUnauthorized(err error) bool {
  834. var providerErr *fantasy.ProviderError
  835. return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) ||
  836. errors.Is(err, hyper.ErrUnauthorized)
  837. }
  838. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  839. if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
  840. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  841. return err
  842. }
  843. if err := c.UpdateModels(ctx); err != nil {
  844. return err
  845. }
  846. return nil
  847. }
  848. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  849. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  850. if err != nil {
  851. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  852. return err
  853. }
  854. providerCfg.APIKey = newAPIKey
  855. c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
  856. if err := c.UpdateModels(ctx); err != nil {
  857. return err
  858. }
  859. return nil
  860. }
  861. // subAgentParams holds the parameters for running a sub-agent.
  862. type subAgentParams struct {
  863. Agent SessionAgent
  864. SessionID string
  865. AgentMessageID string
  866. ToolCallID string
  867. Prompt string
  868. SessionTitle string
  869. // SessionSetup is an optional callback invoked after session creation
  870. // but before agent execution, for custom session configuration.
  871. SessionSetup func(sessionID string)
  872. }
  873. // runSubAgent runs a sub-agent and handles session management and cost accumulation.
  874. // It creates a sub-session, runs the agent with the given prompt, and propagates
  875. // the cost to the parent session.
  876. func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
  877. // Create sub-session
  878. agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
  879. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
  880. if err != nil {
  881. return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
  882. }
  883. // Call session setup function if provided
  884. if params.SessionSetup != nil {
  885. params.SessionSetup(session.ID)
  886. }
  887. // Get model configuration
  888. model := params.Agent.Model()
  889. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  890. if model.ModelCfg.MaxTokens != 0 {
  891. maxTokens = model.ModelCfg.MaxTokens
  892. }
  893. providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
  894. if !ok {
  895. return fantasy.ToolResponse{}, errModelProviderNotConfigured
  896. }
  897. // Run the agent
  898. result, err := params.Agent.Run(ctx, SessionAgentCall{
  899. SessionID: session.ID,
  900. Prompt: params.Prompt,
  901. MaxOutputTokens: maxTokens,
  902. ProviderOptions: getProviderOptions(model, providerCfg),
  903. Temperature: model.ModelCfg.Temperature,
  904. TopP: model.ModelCfg.TopP,
  905. TopK: model.ModelCfg.TopK,
  906. FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
  907. PresencePenalty: model.ModelCfg.PresencePenalty,
  908. NonInteractive: true,
  909. })
  910. if err != nil {
  911. return fantasy.NewTextErrorResponse("error generating response"), nil
  912. }
  913. // Update parent session cost
  914. if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
  915. return fantasy.ToolResponse{}, err
  916. }
  917. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  918. }
  919. // updateParentSessionCost accumulates the cost from a child session to its parent session.
  920. func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
  921. childSession, err := c.sessions.Get(ctx, childSessionID)
  922. if err != nil {
  923. return fmt.Errorf("get child session: %w", err)
  924. }
  925. parentSession, err := c.sessions.Get(ctx, parentSessionID)
  926. if err != nil {
  927. return fmt.Errorf("get parent session: %w", err)
  928. }
  929. parentSession.Cost += childSession.Cost
  930. models := maps.Clone(c.cfg.Config().Models)
  931. for name, model := range models {
  932. model.ProviderOptions = maps.Clone(model.ProviderOptions)
  933. models[name] = model
  934. }
  935. if _, err := c.sessions.SaveWithModels(ctx, parentSession, models); err != nil {
  936. return fmt.Errorf("save parent session: %w", err)
  937. }
  938. return nil
  939. }
  940. // discoverSkills runs the skill discovery pipeline and returns both the
  941. // pre-filter (all discovered, after dedup) and post-filter (active) lists.
  942. func discoverSkills(cfg *config.ConfigStore) (allSkills, activeSkills []*skills.Skill) {
  943. discovered := skills.DiscoverBuiltin()
  944. opts := cfg.Config().Options
  945. if opts != nil && len(opts.SkillsPaths) > 0 {
  946. expandedPaths := make([]string, 0, len(opts.SkillsPaths))
  947. for _, pth := range opts.SkillsPaths {
  948. expanded := home.Long(pth)
  949. if strings.HasPrefix(expanded, "$") {
  950. if resolved, err := cfg.Resolver().ResolveValue(expanded); err == nil {
  951. expanded = resolved
  952. }
  953. }
  954. expandedPaths = append(expandedPaths, expanded)
  955. }
  956. discovered = append(discovered, skills.Discover(expandedPaths)...)
  957. }
  958. allSkills = skills.Deduplicate(discovered)
  959. var disabledSkills []string
  960. if opts != nil {
  961. disabledSkills = opts.DisabledSkills
  962. }
  963. activeSkills = skills.Filter(allSkills, disabledSkills)
  964. return allSkills, activeSkills
  965. }