coordinator.go 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "slices"
  15. "strings"
  16. "charm.land/catwalk/pkg/catwalk"
  17. "charm.land/fantasy"
  18. "github.com/charmbracelet/crush/internal/agent/hyper"
  19. "github.com/charmbracelet/crush/internal/agent/notify"
  20. "github.com/charmbracelet/crush/internal/agent/prompt"
  21. "github.com/charmbracelet/crush/internal/agent/tools"
  22. "github.com/charmbracelet/crush/internal/config"
  23. "github.com/charmbracelet/crush/internal/filetracker"
  24. "github.com/charmbracelet/crush/internal/history"
  25. "github.com/charmbracelet/crush/internal/log"
  26. "github.com/charmbracelet/crush/internal/lsp"
  27. "github.com/charmbracelet/crush/internal/message"
  28. "github.com/charmbracelet/crush/internal/oauth/copilot"
  29. "github.com/charmbracelet/crush/internal/permission"
  30. "github.com/charmbracelet/crush/internal/pubsub"
  31. "github.com/charmbracelet/crush/internal/session"
  32. "golang.org/x/sync/errgroup"
  33. "charm.land/fantasy/providers/anthropic"
  34. "charm.land/fantasy/providers/azure"
  35. "charm.land/fantasy/providers/bedrock"
  36. "charm.land/fantasy/providers/google"
  37. "charm.land/fantasy/providers/openai"
  38. "charm.land/fantasy/providers/openaicompat"
  39. "charm.land/fantasy/providers/openrouter"
  40. "charm.land/fantasy/providers/vercel"
  41. openaisdk "github.com/charmbracelet/openai-go/option"
  42. "github.com/qjebbs/go-jsons"
  43. )
  44. // Coordinator errors.
  45. var (
  46. errCoderAgentNotConfigured = errors.New("coder agent not configured")
  47. errModelProviderNotConfigured = errors.New("model provider not configured")
  48. errLargeModelNotSelected = errors.New("large model not selected")
  49. errSmallModelNotSelected = errors.New("small model not selected")
  50. errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  51. errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  52. errLargeModelNotFound = errors.New("large model not found in provider config")
  53. errSmallModelNotFound = errors.New("small model not found in provider config")
  54. )
  55. type Coordinator interface {
  56. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  57. // SetMainAgent(string)
  58. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  59. Cancel(sessionID string)
  60. CancelAll()
  61. IsSessionBusy(sessionID string) bool
  62. IsBusy() bool
  63. QueuedPrompts(sessionID string) int
  64. QueuedPromptsList(sessionID string) []string
  65. ClearQueue(sessionID string)
  66. Summarize(context.Context, string) error
  67. Model() Model
  68. UpdateModels(ctx context.Context) error
  69. }
  70. type coordinator struct {
  71. cfg *config.ConfigStore
  72. sessions session.Service
  73. messages message.Service
  74. permissions permission.Service
  75. history history.Service
  76. filetracker filetracker.Service
  77. lspManager *lsp.Manager
  78. notify pubsub.Publisher[notify.Notification]
  79. currentAgent SessionAgent
  80. agents map[string]SessionAgent
  81. readyWg errgroup.Group
  82. }
  83. func NewCoordinator(
  84. ctx context.Context,
  85. cfg *config.ConfigStore,
  86. sessions session.Service,
  87. messages message.Service,
  88. permissions permission.Service,
  89. history history.Service,
  90. filetracker filetracker.Service,
  91. lspManager *lsp.Manager,
  92. notify pubsub.Publisher[notify.Notification],
  93. ) (Coordinator, error) {
  94. c := &coordinator{
  95. cfg: cfg,
  96. sessions: sessions,
  97. messages: messages,
  98. permissions: permissions,
  99. history: history,
  100. filetracker: filetracker,
  101. lspManager: lspManager,
  102. notify: notify,
  103. agents: make(map[string]SessionAgent),
  104. }
  105. agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
  106. if !ok {
  107. return nil, errCoderAgentNotConfigured
  108. }
  109. // TODO: make this dynamic when we support multiple agents
  110. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  111. if err != nil {
  112. return nil, err
  113. }
  114. agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
  115. if err != nil {
  116. return nil, err
  117. }
  118. c.currentAgent = agent
  119. c.agents[config.AgentCoder] = agent
  120. return c, nil
  121. }
  122. // Run implements Coordinator.
  123. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  124. if err := c.readyWg.Wait(); err != nil {
  125. return nil, err
  126. }
  127. // refresh models before each run
  128. if err := c.UpdateModels(ctx); err != nil {
  129. return nil, fmt.Errorf("failed to update models: %w", err)
  130. }
  131. model := c.currentAgent.Model()
  132. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  133. if model.ModelCfg.MaxTokens != 0 {
  134. maxTokens = model.ModelCfg.MaxTokens
  135. }
  136. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  137. // filter out image attachments
  138. filteredAttachments := make([]message.Attachment, 0, len(attachments))
  139. for _, att := range attachments {
  140. if att.IsText() {
  141. filteredAttachments = append(filteredAttachments, att)
  142. }
  143. }
  144. attachments = filteredAttachments
  145. }
  146. providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
  147. if !ok {
  148. return nil, errModelProviderNotConfigured
  149. }
  150. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  151. if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
  152. slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
  153. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  154. return nil, err
  155. }
  156. }
  157. run := func() (*fantasy.AgentResult, error) {
  158. return c.currentAgent.Run(ctx, SessionAgentCall{
  159. SessionID: sessionID,
  160. Prompt: prompt,
  161. Attachments: attachments,
  162. MaxOutputTokens: maxTokens,
  163. ProviderOptions: mergedOptions,
  164. Temperature: temp,
  165. TopP: topP,
  166. TopK: topK,
  167. FrequencyPenalty: freqPenalty,
  168. PresencePenalty: presPenalty,
  169. })
  170. }
  171. result, originalErr := run()
  172. if c.isUnauthorized(originalErr) {
  173. switch {
  174. case providerCfg.OAuthToken != nil:
  175. slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  176. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  177. return nil, originalErr
  178. }
  179. slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  180. return run()
  181. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  182. slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  183. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  184. return nil, originalErr
  185. }
  186. slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
  187. return run()
  188. }
  189. }
  190. return result, originalErr
  191. }
  192. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  193. options := fantasy.ProviderOptions{}
  194. cfgOpts := []byte("{}")
  195. providerCfgOpts := []byte("{}")
  196. catwalkOpts := []byte("{}")
  197. if model.ModelCfg.ProviderOptions != nil {
  198. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  199. if err == nil {
  200. cfgOpts = data
  201. }
  202. }
  203. if providerCfg.ProviderOptions != nil {
  204. data, err := json.Marshal(providerCfg.ProviderOptions)
  205. if err == nil {
  206. providerCfgOpts = data
  207. }
  208. }
  209. if model.CatwalkCfg.Options.ProviderOptions != nil {
  210. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  211. if err == nil {
  212. catwalkOpts = data
  213. }
  214. }
  215. readers := []io.Reader{
  216. bytes.NewReader(catwalkOpts),
  217. bytes.NewReader(providerCfgOpts),
  218. bytes.NewReader(cfgOpts),
  219. }
  220. got, err := jsons.Merge(readers)
  221. if err != nil {
  222. slog.Error("Could not merge call config", "err", err)
  223. return options
  224. }
  225. mergedOptions := make(map[string]any)
  226. err = json.Unmarshal([]byte(got), &mergedOptions)
  227. if err != nil {
  228. slog.Error("Could not create config for call", "err", err)
  229. return options
  230. }
  231. providerType := providerCfg.Type
  232. if providerType == "hyper" {
  233. if strings.Contains(model.CatwalkCfg.ID, "claude") {
  234. providerType = anthropic.Name
  235. } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
  236. providerType = openai.Name
  237. } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
  238. providerType = google.Name
  239. } else {
  240. providerType = openaicompat.Name
  241. }
  242. }
  243. switch providerType {
  244. case openai.Name, azure.Name:
  245. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  246. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  247. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  248. }
  249. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  250. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  251. mergedOptions["reasoning_summary"] = "auto"
  252. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  253. }
  254. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  255. if err == nil {
  256. options[openai.Name] = parsed
  257. }
  258. } else {
  259. parsed, err := openai.ParseOptions(mergedOptions)
  260. if err == nil {
  261. options[openai.Name] = parsed
  262. }
  263. }
  264. case anthropic.Name:
  265. var (
  266. _, hasEffort = mergedOptions["effort"]
  267. _, hasThink = mergedOptions["thinking"]
  268. )
  269. switch {
  270. case !hasEffort && model.ModelCfg.ReasoningEffort != "":
  271. mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
  272. case !hasThink && model.ModelCfg.Think:
  273. mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
  274. }
  275. parsed, err := anthropic.ParseOptions(mergedOptions)
  276. if err == nil {
  277. options[anthropic.Name] = parsed
  278. }
  279. case openrouter.Name:
  280. _, hasReasoning := mergedOptions["reasoning"]
  281. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  282. mergedOptions["reasoning"] = map[string]any{
  283. "enabled": true,
  284. "effort": model.ModelCfg.ReasoningEffort,
  285. }
  286. }
  287. parsed, err := openrouter.ParseOptions(mergedOptions)
  288. if err == nil {
  289. options[openrouter.Name] = parsed
  290. }
  291. case vercel.Name:
  292. _, hasReasoning := mergedOptions["reasoning"]
  293. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  294. mergedOptions["reasoning"] = map[string]any{
  295. "enabled": true,
  296. "effort": model.ModelCfg.ReasoningEffort,
  297. }
  298. }
  299. parsed, err := vercel.ParseOptions(mergedOptions)
  300. if err == nil {
  301. options[vercel.Name] = parsed
  302. }
  303. case google.Name:
  304. _, hasReasoning := mergedOptions["thinking_config"]
  305. if !hasReasoning {
  306. if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
  307. mergedOptions["thinking_config"] = map[string]any{
  308. "thinking_budget": 2000,
  309. "include_thoughts": true,
  310. }
  311. } else {
  312. mergedOptions["thinking_config"] = map[string]any{
  313. "thinking_level": model.ModelCfg.ReasoningEffort,
  314. "include_thoughts": true,
  315. }
  316. }
  317. }
  318. parsed, err := google.ParseOptions(mergedOptions)
  319. if err == nil {
  320. options[google.Name] = parsed
  321. }
  322. case openaicompat.Name:
  323. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  324. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  325. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  326. }
  327. parsed, err := openaicompat.ParseOptions(mergedOptions)
  328. if err == nil {
  329. options[openaicompat.Name] = parsed
  330. }
  331. }
  332. return options
  333. }
  334. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  335. modelOptions := getProviderOptions(model, cfg)
  336. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  337. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  338. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  339. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  340. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  341. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  342. }
  343. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
  344. large, small, err := c.buildAgentModels(ctx, isSubAgent)
  345. if err != nil {
  346. return nil, err
  347. }
  348. largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
  349. result := NewSessionAgent(SessionAgentOptions{
  350. LargeModel: large,
  351. SmallModel: small,
  352. SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
  353. SystemPrompt: "",
  354. IsSubAgent: isSubAgent,
  355. DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
  356. IsYolo: c.permissions.SkipRequests(),
  357. Sessions: c.sessions,
  358. Messages: c.messages,
  359. Tools: nil,
  360. Notify: c.notify,
  361. })
  362. c.readyWg.Go(func() error {
  363. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
  364. if err != nil {
  365. return err
  366. }
  367. result.SetSystemPrompt(systemPrompt)
  368. return nil
  369. })
  370. c.readyWg.Go(func() error {
  371. tools, err := c.buildTools(ctx, agent)
  372. if err != nil {
  373. return err
  374. }
  375. result.SetTools(tools)
  376. return nil
  377. })
  378. return result, nil
  379. }
  380. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  381. var allTools []fantasy.AgentTool
  382. if slices.Contains(agent.AllowedTools, AgentToolName) {
  383. agentTool, err := c.agentTool(ctx)
  384. if err != nil {
  385. return nil, err
  386. }
  387. allTools = append(allTools, agentTool)
  388. }
  389. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  390. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  391. if err != nil {
  392. return nil, err
  393. }
  394. allTools = append(allTools, agenticFetchTool)
  395. }
  396. // Get the model name for the agent
  397. modelName := ""
  398. if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
  399. if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  400. modelName = model.Name
  401. }
  402. }
  403. allTools = append(allTools,
  404. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
  405. tools.NewJobOutputTool(),
  406. tools.NewJobKillTool(),
  407. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  408. tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  409. tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  410. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  411. tools.NewGlobTool(c.cfg.WorkingDir()),
  412. tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
  413. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
  414. tools.NewSourcegraphTool(nil),
  415. tools.NewTodosTool(c.sessions),
  416. tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
  417. tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  418. )
  419. // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
  420. if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
  421. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
  422. }
  423. if len(c.cfg.Config().MCP) > 0 {
  424. allTools = append(
  425. allTools,
  426. tools.NewListMCPResourcesTool(c.cfg, c.permissions),
  427. tools.NewReadMCPResourceTool(c.cfg, c.permissions),
  428. )
  429. }
  430. var filteredTools []fantasy.AgentTool
  431. for _, tool := range allTools {
  432. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  433. filteredTools = append(filteredTools, tool)
  434. }
  435. }
  436. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
  437. if agent.AllowedMCP == nil {
  438. // No MCP restrictions
  439. filteredTools = append(filteredTools, tool)
  440. continue
  441. }
  442. if len(agent.AllowedMCP) == 0 {
  443. // No MCPs allowed
  444. slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  445. break
  446. }
  447. for mcp, tools := range agent.AllowedMCP {
  448. if mcp != tool.MCP() {
  449. continue
  450. }
  451. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  452. filteredTools = append(filteredTools, tool)
  453. break
  454. }
  455. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  456. }
  457. }
  458. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  459. return strings.Compare(a.Info().Name, b.Info().Name)
  460. })
  461. return filteredTools, nil
  462. }
  463. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  464. func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
  465. largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
  466. if !ok {
  467. return Model{}, Model{}, errLargeModelNotSelected
  468. }
  469. smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
  470. if !ok {
  471. return Model{}, Model{}, errSmallModelNotSelected
  472. }
  473. largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
  474. if !ok {
  475. return Model{}, Model{}, errLargeModelProviderNotConfigured
  476. }
  477. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
  478. if err != nil {
  479. return Model{}, Model{}, err
  480. }
  481. smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
  482. if !ok {
  483. return Model{}, Model{}, errSmallModelProviderNotConfigured
  484. }
  485. smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
  486. if err != nil {
  487. return Model{}, Model{}, err
  488. }
  489. var largeCatwalkModel *catwalk.Model
  490. var smallCatwalkModel *catwalk.Model
  491. for _, m := range largeProviderCfg.Models {
  492. if m.ID == largeModelCfg.Model {
  493. largeCatwalkModel = &m
  494. }
  495. }
  496. for _, m := range smallProviderCfg.Models {
  497. if m.ID == smallModelCfg.Model {
  498. smallCatwalkModel = &m
  499. }
  500. }
  501. if largeCatwalkModel == nil {
  502. return Model{}, Model{}, errLargeModelNotFound
  503. }
  504. if smallCatwalkModel == nil {
  505. return Model{}, Model{}, errSmallModelNotFound
  506. }
  507. largeModelID := largeModelCfg.Model
  508. smallModelID := smallModelCfg.Model
  509. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  510. largeModelID += ":exacto"
  511. }
  512. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  513. smallModelID += ":exacto"
  514. }
  515. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  516. if err != nil {
  517. return Model{}, Model{}, err
  518. }
  519. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  520. if err != nil {
  521. return Model{}, Model{}, err
  522. }
  523. return Model{
  524. Model: largeModel,
  525. CatwalkCfg: *largeCatwalkModel,
  526. ModelCfg: largeModelCfg,
  527. }, Model{
  528. Model: smallModel,
  529. CatwalkCfg: *smallCatwalkModel,
  530. ModelCfg: smallModelCfg,
  531. }, nil
  532. }
  533. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
  534. var opts []anthropic.Option
  535. switch {
  536. case strings.HasPrefix(apiKey, "Bearer "):
  537. // NOTE: Prevent the SDK from picking up the API key from env.
  538. os.Setenv("ANTHROPIC_API_KEY", "")
  539. headers["Authorization"] = apiKey
  540. case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
  541. // NOTE: Prevent the SDK from picking up the API key from env.
  542. os.Setenv("ANTHROPIC_API_KEY", "")
  543. headers["Authorization"] = "Bearer " + apiKey
  544. case apiKey != "":
  545. // X-Api-Key header
  546. opts = append(opts, anthropic.WithAPIKey(apiKey))
  547. }
  548. if len(headers) > 0 {
  549. opts = append(opts, anthropic.WithHeaders(headers))
  550. }
  551. if baseURL != "" {
  552. opts = append(opts, anthropic.WithBaseURL(baseURL))
  553. }
  554. if c.cfg.Config().Options.Debug {
  555. httpClient := log.NewHTTPClient()
  556. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  557. }
  558. return anthropic.New(opts...)
  559. }
  560. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  561. opts := []openai.Option{
  562. openai.WithAPIKey(apiKey),
  563. openai.WithUseResponsesAPI(),
  564. }
  565. if c.cfg.Config().Options.Debug {
  566. httpClient := log.NewHTTPClient()
  567. opts = append(opts, openai.WithHTTPClient(httpClient))
  568. }
  569. if len(headers) > 0 {
  570. opts = append(opts, openai.WithHeaders(headers))
  571. }
  572. if baseURL != "" {
  573. opts = append(opts, openai.WithBaseURL(baseURL))
  574. }
  575. return openai.New(opts...)
  576. }
  577. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  578. opts := []openrouter.Option{
  579. openrouter.WithAPIKey(apiKey),
  580. }
  581. if c.cfg.Config().Options.Debug {
  582. httpClient := log.NewHTTPClient()
  583. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  584. }
  585. if len(headers) > 0 {
  586. opts = append(opts, openrouter.WithHeaders(headers))
  587. }
  588. return openrouter.New(opts...)
  589. }
  590. func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  591. opts := []vercel.Option{
  592. vercel.WithAPIKey(apiKey),
  593. }
  594. if c.cfg.Config().Options.Debug {
  595. httpClient := log.NewHTTPClient()
  596. opts = append(opts, vercel.WithHTTPClient(httpClient))
  597. }
  598. if len(headers) > 0 {
  599. opts = append(opts, vercel.WithHeaders(headers))
  600. }
  601. return vercel.New(opts...)
  602. }
  603. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
  604. opts := []openaicompat.Option{
  605. openaicompat.WithBaseURL(baseURL),
  606. openaicompat.WithAPIKey(apiKey),
  607. }
  608. // Set HTTP client based on provider and debug mode.
  609. var httpClient *http.Client
  610. if providerID == string(catwalk.InferenceProviderCopilot) {
  611. opts = append(opts, openaicompat.WithUseResponsesAPI())
  612. httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
  613. } else if c.cfg.Config().Options.Debug {
  614. httpClient = log.NewHTTPClient()
  615. }
  616. if httpClient != nil {
  617. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  618. }
  619. if len(headers) > 0 {
  620. opts = append(opts, openaicompat.WithHeaders(headers))
  621. }
  622. for extraKey, extraValue := range extraBody {
  623. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  624. }
  625. return openaicompat.New(opts...)
  626. }
  627. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  628. opts := []azure.Option{
  629. azure.WithBaseURL(baseURL),
  630. azure.WithAPIKey(apiKey),
  631. azure.WithUseResponsesAPI(),
  632. }
  633. if c.cfg.Config().Options.Debug {
  634. httpClient := log.NewHTTPClient()
  635. opts = append(opts, azure.WithHTTPClient(httpClient))
  636. }
  637. if options == nil {
  638. options = make(map[string]string)
  639. }
  640. if apiVersion, ok := options["apiVersion"]; ok {
  641. opts = append(opts, azure.WithAPIVersion(apiVersion))
  642. }
  643. if len(headers) > 0 {
  644. opts = append(opts, azure.WithHeaders(headers))
  645. }
  646. return azure.New(opts...)
  647. }
  648. func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
  649. var opts []bedrock.Option
  650. if c.cfg.Config().Options.Debug {
  651. httpClient := log.NewHTTPClient()
  652. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  653. }
  654. if len(headers) > 0 {
  655. opts = append(opts, bedrock.WithHeaders(headers))
  656. }
  657. switch {
  658. case apiKey != "":
  659. opts = append(opts, bedrock.WithAPIKey(apiKey))
  660. case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
  661. opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
  662. default:
  663. // Skip, let the SDK do authentication.
  664. }
  665. return bedrock.New(opts...)
  666. }
  667. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  668. opts := []google.Option{
  669. google.WithBaseURL(baseURL),
  670. google.WithGeminiAPIKey(apiKey),
  671. }
  672. if c.cfg.Config().Options.Debug {
  673. httpClient := log.NewHTTPClient()
  674. opts = append(opts, google.WithHTTPClient(httpClient))
  675. }
  676. if len(headers) > 0 {
  677. opts = append(opts, google.WithHeaders(headers))
  678. }
  679. return google.New(opts...)
  680. }
  681. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  682. opts := []google.Option{}
  683. if c.cfg.Config().Options.Debug {
  684. httpClient := log.NewHTTPClient()
  685. opts = append(opts, google.WithHTTPClient(httpClient))
  686. }
  687. if len(headers) > 0 {
  688. opts = append(opts, google.WithHeaders(headers))
  689. }
  690. project := options["project"]
  691. location := options["location"]
  692. opts = append(opts, google.WithVertex(project, location))
  693. return google.New(opts...)
  694. }
  695. func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
  696. opts := []hyper.Option{
  697. hyper.WithAPIKey(apiKey),
  698. }
  699. if c.cfg.Config().Options.Debug {
  700. httpClient := log.NewHTTPClient()
  701. opts = append(opts, hyper.WithHTTPClient(httpClient))
  702. }
  703. return hyper.New(opts...)
  704. }
  705. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  706. if model.Think {
  707. return true
  708. }
  709. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  710. return err == nil && opts.Thinking != nil
  711. }
  712. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
  713. headers := maps.Clone(providerCfg.ExtraHeaders)
  714. if headers == nil {
  715. headers = make(map[string]string)
  716. }
  717. // handle special headers for anthropic
  718. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  719. if v, ok := headers["anthropic-beta"]; ok {
  720. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  721. } else {
  722. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  723. }
  724. }
  725. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  726. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  727. switch providerCfg.Type {
  728. case openai.Name:
  729. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  730. case anthropic.Name:
  731. return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
  732. case openrouter.Name:
  733. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  734. case vercel.Name:
  735. return c.buildVercelProvider(baseURL, apiKey, headers)
  736. case azure.Name:
  737. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  738. case bedrock.Name:
  739. return c.buildBedrockProvider(apiKey, headers)
  740. case google.Name:
  741. return c.buildGoogleProvider(baseURL, apiKey, headers)
  742. case "google-vertex":
  743. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  744. case openaicompat.Name:
  745. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  746. if providerCfg.ExtraBody == nil {
  747. providerCfg.ExtraBody = map[string]any{}
  748. }
  749. providerCfg.ExtraBody["tool_stream"] = true
  750. }
  751. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
  752. case hyper.Name:
  753. return c.buildHyperProvider(apiKey)
  754. default:
  755. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  756. }
  757. }
  758. func isExactoSupported(modelID string) bool {
  759. supportedModels := []string{
  760. "moonshotai/kimi-k2-0905",
  761. "deepseek/deepseek-v3.1-terminus",
  762. "z-ai/glm-4.6",
  763. "openai/gpt-oss-120b",
  764. "qwen/qwen3-coder",
  765. }
  766. return slices.Contains(supportedModels, modelID)
  767. }
  768. func (c *coordinator) Cancel(sessionID string) {
  769. c.currentAgent.Cancel(sessionID)
  770. }
  771. func (c *coordinator) CancelAll() {
  772. c.currentAgent.CancelAll()
  773. }
  774. func (c *coordinator) ClearQueue(sessionID string) {
  775. c.currentAgent.ClearQueue(sessionID)
  776. }
  777. func (c *coordinator) IsBusy() bool {
  778. return c.currentAgent.IsBusy()
  779. }
  780. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  781. return c.currentAgent.IsSessionBusy(sessionID)
  782. }
  783. func (c *coordinator) Model() Model {
  784. return c.currentAgent.Model()
  785. }
  786. func (c *coordinator) UpdateModels(ctx context.Context) error {
  787. // build the models again so we make sure we get the latest config
  788. large, small, err := c.buildAgentModels(ctx, false)
  789. if err != nil {
  790. return err
  791. }
  792. c.currentAgent.SetModels(large, small)
  793. agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
  794. if !ok {
  795. return errCoderAgentNotConfigured
  796. }
  797. tools, err := c.buildTools(ctx, agentCfg)
  798. if err != nil {
  799. return err
  800. }
  801. c.currentAgent.SetTools(tools)
  802. return nil
  803. }
  804. func (c *coordinator) QueuedPrompts(sessionID string) int {
  805. return c.currentAgent.QueuedPrompts(sessionID)
  806. }
  807. func (c *coordinator) QueuedPromptsList(sessionID string) []string {
  808. return c.currentAgent.QueuedPromptsList(sessionID)
  809. }
  810. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  811. providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  812. if !ok {
  813. return errModelProviderNotConfigured
  814. }
  815. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  816. }
  817. func (c *coordinator) isUnauthorized(err error) bool {
  818. var providerErr *fantasy.ProviderError
  819. return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
  820. }
  821. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  822. if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
  823. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  824. return err
  825. }
  826. if err := c.UpdateModels(ctx); err != nil {
  827. return err
  828. }
  829. return nil
  830. }
  831. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  832. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  833. if err != nil {
  834. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  835. return err
  836. }
  837. providerCfg.APIKey = newAPIKey
  838. c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
  839. if err := c.UpdateModels(ctx); err != nil {
  840. return err
  841. }
  842. return nil
  843. }
// subAgentParams holds the parameters for running a sub-agent via runSubAgent.
type subAgentParams struct {
	// Agent is the session agent whose Run executes Prompt in the sub-session.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it,
	// and the sub-session's cost is added back to it.
	SessionID string
	// AgentMessageID and ToolCallID together derive the sub-session's ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt passed to Agent.Run.
	Prompt string
	// SessionTitle is passed to CreateTaskSession as the sub-session's title.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
  856. // runSubAgent runs a sub-agent and handles session management and cost accumulation.
  857. // It creates a sub-session, runs the agent with the given prompt, and propagates
  858. // the cost to the parent session.
  859. func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
  860. // Create sub-session
  861. agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
  862. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
  863. if err != nil {
  864. return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
  865. }
  866. // Call session setup function if provided
  867. if params.SessionSetup != nil {
  868. params.SessionSetup(session.ID)
  869. }
  870. // Get model configuration
  871. model := params.Agent.Model()
  872. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  873. if model.ModelCfg.MaxTokens != 0 {
  874. maxTokens = model.ModelCfg.MaxTokens
  875. }
  876. providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
  877. if !ok {
  878. return fantasy.ToolResponse{}, errModelProviderNotConfigured
  879. }
  880. // Run the agent
  881. result, err := params.Agent.Run(ctx, SessionAgentCall{
  882. SessionID: session.ID,
  883. Prompt: params.Prompt,
  884. MaxOutputTokens: maxTokens,
  885. ProviderOptions: getProviderOptions(model, providerCfg),
  886. Temperature: model.ModelCfg.Temperature,
  887. TopP: model.ModelCfg.TopP,
  888. TopK: model.ModelCfg.TopK,
  889. FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
  890. PresencePenalty: model.ModelCfg.PresencePenalty,
  891. NonInteractive: true,
  892. })
  893. if err != nil {
  894. return fantasy.NewTextErrorResponse("error generating response"), nil
  895. }
  896. // Update parent session cost
  897. if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
  898. return fantasy.ToolResponse{}, err
  899. }
  900. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  901. }
  902. // updateParentSessionCost accumulates the cost from a child session to its parent session.
  903. func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
  904. childSession, err := c.sessions.Get(ctx, childSessionID)
  905. if err != nil {
  906. return fmt.Errorf("get child session: %w", err)
  907. }
  908. parentSession, err := c.sessions.Get(ctx, parentSessionID)
  909. if err != nil {
  910. return fmt.Errorf("get parent session: %w", err)
  911. }
  912. parentSession.Cost += childSession.Cost
  913. if _, err := c.sessions.Save(ctx, parentSession); err != nil {
  914. return fmt.Errorf("save parent session: %w", err)
  915. }
  916. return nil
  917. }