coordinator.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "slices"
  15. "strings"
  16. "charm.land/catwalk/pkg/catwalk"
  17. "charm.land/fantasy"
  18. "github.com/charmbracelet/crush/internal/agent/hyper"
  19. "github.com/charmbracelet/crush/internal/agent/prompt"
  20. "github.com/charmbracelet/crush/internal/agent/tools"
  21. "github.com/charmbracelet/crush/internal/config"
  22. "github.com/charmbracelet/crush/internal/filetracker"
  23. "github.com/charmbracelet/crush/internal/history"
  24. "github.com/charmbracelet/crush/internal/log"
  25. "github.com/charmbracelet/crush/internal/lsp"
  26. "github.com/charmbracelet/crush/internal/message"
  27. "github.com/charmbracelet/crush/internal/oauth/copilot"
  28. "github.com/charmbracelet/crush/internal/permission"
  29. "github.com/charmbracelet/crush/internal/session"
  30. "golang.org/x/sync/errgroup"
  31. "charm.land/fantasy/providers/anthropic"
  32. "charm.land/fantasy/providers/azure"
  33. "charm.land/fantasy/providers/bedrock"
  34. "charm.land/fantasy/providers/google"
  35. "charm.land/fantasy/providers/openai"
  36. "charm.land/fantasy/providers/openaicompat"
  37. "charm.land/fantasy/providers/openrouter"
  38. "charm.land/fantasy/providers/vercel"
  39. openaisdk "github.com/openai/openai-go/v2/option"
  40. "github.com/qjebbs/go-jsons"
  41. )
  42. type Coordinator interface {
  43. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  44. // SetMainAgent(string)
  45. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  46. Cancel(sessionID string)
  47. CancelAll()
  48. IsSessionBusy(sessionID string) bool
  49. IsBusy() bool
  50. QueuedPrompts(sessionID string) int
  51. QueuedPromptsList(sessionID string) []string
  52. ClearQueue(sessionID string)
  53. Summarize(context.Context, string) error
  54. Model() Model
  55. UpdateModels(ctx context.Context) error
  56. }
// coordinator is the default Coordinator implementation. It currently drives
// a single coder agent, created in NewCoordinator.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	// currentAgent receives every delegated call (Run, Cancel, queue queries, ...).
	currentAgent SessionAgent
	// agents indexes the built agents by configured agent name.
	agents map[string]SessionAgent
	// readyWg tracks the background work started by buildAgent (system prompt
	// and tool construction); Run waits on it before executing.
	readyWg errgroup.Group
}
  69. func NewCoordinator(
  70. ctx context.Context,
  71. cfg *config.Config,
  72. sessions session.Service,
  73. messages message.Service,
  74. permissions permission.Service,
  75. history history.Service,
  76. filetracker filetracker.Service,
  77. lspManager *lsp.Manager,
  78. ) (Coordinator, error) {
  79. c := &coordinator{
  80. cfg: cfg,
  81. sessions: sessions,
  82. messages: messages,
  83. permissions: permissions,
  84. history: history,
  85. filetracker: filetracker,
  86. lspManager: lspManager,
  87. agents: make(map[string]SessionAgent),
  88. }
  89. agentCfg, ok := cfg.Agents[config.AgentCoder]
  90. if !ok {
  91. return nil, errors.New("coder agent not configured")
  92. }
  93. // TODO: make this dynamic when we support multiple agents
  94. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  95. if err != nil {
  96. return nil, err
  97. }
  98. agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
  99. if err != nil {
  100. return nil, err
  101. }
  102. c.currentAgent = agent
  103. c.agents[config.AgentCoder] = agent
  104. return c, nil
  105. }
  106. // Run implements Coordinator.
  107. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  108. if err := c.readyWg.Wait(); err != nil {
  109. return nil, err
  110. }
  111. // refresh models before each run
  112. if err := c.UpdateModels(ctx); err != nil {
  113. return nil, fmt.Errorf("failed to update models: %w", err)
  114. }
  115. model := c.currentAgent.Model()
  116. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  117. if model.ModelCfg.MaxTokens != 0 {
  118. maxTokens = model.ModelCfg.MaxTokens
  119. }
  120. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  121. // filter out image attachments
  122. filteredAttachments := make([]message.Attachment, 0, len(attachments))
  123. for _, att := range attachments {
  124. if att.IsText() {
  125. filteredAttachments = append(filteredAttachments, att)
  126. }
  127. }
  128. attachments = filteredAttachments
  129. }
  130. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  131. if !ok {
  132. return nil, errors.New("model provider not configured")
  133. }
  134. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  135. if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
  136. slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
  137. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  138. return nil, err
  139. }
  140. }
  141. run := func() (*fantasy.AgentResult, error) {
  142. return c.currentAgent.Run(ctx, SessionAgentCall{
  143. SessionID: sessionID,
  144. Prompt: prompt,
  145. Attachments: attachments,
  146. MaxOutputTokens: maxTokens,
  147. ProviderOptions: mergedOptions,
  148. Temperature: temp,
  149. TopP: topP,
  150. TopK: topK,
  151. FrequencyPenalty: freqPenalty,
  152. PresencePenalty: presPenalty,
  153. })
  154. }
  155. result, originalErr := run()
  156. if c.isUnauthorized(originalErr) {
  157. switch {
  158. case providerCfg.OAuthToken != nil:
  159. slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  160. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  161. return nil, originalErr
  162. }
  163. slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  164. return run()
  165. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  166. slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  167. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  168. return nil, originalErr
  169. }
  170. slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
  171. return run()
  172. }
  173. }
  174. return result, originalErr
  175. }
  176. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  177. options := fantasy.ProviderOptions{}
  178. cfgOpts := []byte("{}")
  179. providerCfgOpts := []byte("{}")
  180. catwalkOpts := []byte("{}")
  181. if model.ModelCfg.ProviderOptions != nil {
  182. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  183. if err == nil {
  184. cfgOpts = data
  185. }
  186. }
  187. if providerCfg.ProviderOptions != nil {
  188. data, err := json.Marshal(providerCfg.ProviderOptions)
  189. if err == nil {
  190. providerCfgOpts = data
  191. }
  192. }
  193. if model.CatwalkCfg.Options.ProviderOptions != nil {
  194. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  195. if err == nil {
  196. catwalkOpts = data
  197. }
  198. }
  199. readers := []io.Reader{
  200. bytes.NewReader(catwalkOpts),
  201. bytes.NewReader(providerCfgOpts),
  202. bytes.NewReader(cfgOpts),
  203. }
  204. got, err := jsons.Merge(readers)
  205. if err != nil {
  206. slog.Error("Could not merge call config", "err", err)
  207. return options
  208. }
  209. mergedOptions := make(map[string]any)
  210. err = json.Unmarshal([]byte(got), &mergedOptions)
  211. if err != nil {
  212. slog.Error("Could not create config for call", "err", err)
  213. return options
  214. }
  215. providerType := providerCfg.Type
  216. if providerType == "hyper" {
  217. if strings.Contains(model.CatwalkCfg.ID, "claude") {
  218. providerType = anthropic.Name
  219. } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
  220. providerType = openai.Name
  221. } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
  222. providerType = google.Name
  223. } else {
  224. providerType = openaicompat.Name
  225. }
  226. }
  227. switch providerType {
  228. case openai.Name, azure.Name:
  229. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  230. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  231. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  232. }
  233. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  234. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  235. mergedOptions["reasoning_summary"] = "auto"
  236. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  237. }
  238. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  239. if err == nil {
  240. options[openai.Name] = parsed
  241. }
  242. } else {
  243. parsed, err := openai.ParseOptions(mergedOptions)
  244. if err == nil {
  245. options[openai.Name] = parsed
  246. }
  247. }
  248. case anthropic.Name:
  249. var (
  250. _, hasEffort = mergedOptions["effort"]
  251. _, hasThink = mergedOptions["thinking"]
  252. )
  253. switch {
  254. case !hasEffort && model.ModelCfg.ReasoningEffort != "":
  255. mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
  256. case !hasThink && model.ModelCfg.Think:
  257. mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
  258. }
  259. parsed, err := anthropic.ParseOptions(mergedOptions)
  260. if err == nil {
  261. options[anthropic.Name] = parsed
  262. }
  263. case openrouter.Name:
  264. _, hasReasoning := mergedOptions["reasoning"]
  265. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  266. mergedOptions["reasoning"] = map[string]any{
  267. "enabled": true,
  268. "effort": model.ModelCfg.ReasoningEffort,
  269. }
  270. }
  271. parsed, err := openrouter.ParseOptions(mergedOptions)
  272. if err == nil {
  273. options[openrouter.Name] = parsed
  274. }
  275. case vercel.Name:
  276. _, hasReasoning := mergedOptions["reasoning"]
  277. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  278. mergedOptions["reasoning"] = map[string]any{
  279. "enabled": true,
  280. "effort": model.ModelCfg.ReasoningEffort,
  281. }
  282. }
  283. parsed, err := vercel.ParseOptions(mergedOptions)
  284. if err == nil {
  285. options[vercel.Name] = parsed
  286. }
  287. case google.Name:
  288. _, hasReasoning := mergedOptions["thinking_config"]
  289. if !hasReasoning {
  290. if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
  291. mergedOptions["thinking_config"] = map[string]any{
  292. "thinking_budget": 2000,
  293. "include_thoughts": true,
  294. }
  295. } else {
  296. mergedOptions["thinking_config"] = map[string]any{
  297. "thinking_level": model.ModelCfg.ReasoningEffort,
  298. "include_thoughts": true,
  299. }
  300. }
  301. }
  302. parsed, err := google.ParseOptions(mergedOptions)
  303. if err == nil {
  304. options[google.Name] = parsed
  305. }
  306. case openaicompat.Name:
  307. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  308. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  309. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  310. }
  311. parsed, err := openaicompat.ParseOptions(mergedOptions)
  312. if err == nil {
  313. options[openaicompat.Name] = parsed
  314. }
  315. }
  316. return options
  317. }
  318. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  319. modelOptions := getProviderOptions(model, cfg)
  320. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  321. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  322. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  323. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  324. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  325. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  326. }
  327. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
  328. large, small, err := c.buildAgentModels(ctx, isSubAgent)
  329. if err != nil {
  330. return nil, err
  331. }
  332. largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
  333. result := NewSessionAgent(SessionAgentOptions{
  334. large,
  335. small,
  336. largeProviderCfg.SystemPromptPrefix,
  337. "",
  338. isSubAgent,
  339. c.cfg.Options.DisableAutoSummarize,
  340. c.permissions.SkipRequests(),
  341. c.sessions,
  342. c.messages,
  343. nil,
  344. })
  345. c.readyWg.Go(func() error {
  346. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
  347. if err != nil {
  348. return err
  349. }
  350. result.SetSystemPrompt(systemPrompt)
  351. return nil
  352. })
  353. c.readyWg.Go(func() error {
  354. tools, err := c.buildTools(ctx, agent)
  355. if err != nil {
  356. return err
  357. }
  358. result.SetTools(tools)
  359. return nil
  360. })
  361. return result, nil
  362. }
  363. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  364. var allTools []fantasy.AgentTool
  365. if slices.Contains(agent.AllowedTools, AgentToolName) {
  366. agentTool, err := c.agentTool(ctx)
  367. if err != nil {
  368. return nil, err
  369. }
  370. allTools = append(allTools, agentTool)
  371. }
  372. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  373. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  374. if err != nil {
  375. return nil, err
  376. }
  377. allTools = append(allTools, agenticFetchTool)
  378. }
  379. // Get the model name for the agent
  380. modelName := ""
  381. if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
  382. if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  383. modelName = model.Name
  384. }
  385. }
  386. allTools = append(allTools,
  387. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
  388. tools.NewJobOutputTool(),
  389. tools.NewJobKillTool(),
  390. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  391. tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  392. tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  393. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  394. tools.NewGlobTool(c.cfg.WorkingDir()),
  395. tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
  396. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
  397. tools.NewSourcegraphTool(nil),
  398. tools.NewTodosTool(c.sessions),
  399. tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
  400. tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  401. )
  402. // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
  403. if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
  404. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
  405. }
  406. if len(c.cfg.MCP) > 0 {
  407. allTools = append(
  408. allTools,
  409. tools.NewListMCPResourcesTool(c.cfg, c.permissions),
  410. tools.NewReadMCPResourceTool(c.cfg, c.permissions),
  411. )
  412. }
  413. var filteredTools []fantasy.AgentTool
  414. for _, tool := range allTools {
  415. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  416. filteredTools = append(filteredTools, tool)
  417. }
  418. }
  419. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
  420. if agent.AllowedMCP == nil {
  421. // No MCP restrictions
  422. filteredTools = append(filteredTools, tool)
  423. continue
  424. }
  425. if len(agent.AllowedMCP) == 0 {
  426. // No MCPs allowed
  427. slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  428. break
  429. }
  430. for mcp, tools := range agent.AllowedMCP {
  431. if mcp != tool.MCP() {
  432. continue
  433. }
  434. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  435. filteredTools = append(filteredTools, tool)
  436. break
  437. }
  438. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  439. }
  440. }
  441. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  442. return strings.Compare(a.Info().Name, b.Info().Name)
  443. })
  444. return filteredTools, nil
  445. }
  446. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  447. func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
  448. largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
  449. if !ok {
  450. return Model{}, Model{}, errors.New("large model not selected")
  451. }
  452. smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
  453. if !ok {
  454. return Model{}, Model{}, errors.New("small model not selected")
  455. }
  456. largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
  457. if !ok {
  458. return Model{}, Model{}, errors.New("large model provider not configured")
  459. }
  460. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
  461. if err != nil {
  462. return Model{}, Model{}, err
  463. }
  464. smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
  465. if !ok {
  466. return Model{}, Model{}, errors.New("small model provider not configured")
  467. }
  468. smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
  469. if err != nil {
  470. return Model{}, Model{}, err
  471. }
  472. var largeCatwalkModel *catwalk.Model
  473. var smallCatwalkModel *catwalk.Model
  474. for _, m := range largeProviderCfg.Models {
  475. if m.ID == largeModelCfg.Model {
  476. largeCatwalkModel = &m
  477. }
  478. }
  479. for _, m := range smallProviderCfg.Models {
  480. if m.ID == smallModelCfg.Model {
  481. smallCatwalkModel = &m
  482. }
  483. }
  484. if largeCatwalkModel == nil {
  485. return Model{}, Model{}, errors.New("large model not found in provider config")
  486. }
  487. if smallCatwalkModel == nil {
  488. return Model{}, Model{}, errors.New("small model not found in provider config")
  489. }
  490. largeModelID := largeModelCfg.Model
  491. smallModelID := smallModelCfg.Model
  492. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  493. largeModelID += ":exacto"
  494. }
  495. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  496. smallModelID += ":exacto"
  497. }
  498. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  499. if err != nil {
  500. return Model{}, Model{}, err
  501. }
  502. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  503. if err != nil {
  504. return Model{}, Model{}, err
  505. }
  506. return Model{
  507. Model: largeModel,
  508. CatwalkCfg: *largeCatwalkModel,
  509. ModelCfg: largeModelCfg,
  510. }, Model{
  511. Model: smallModel,
  512. CatwalkCfg: *smallCatwalkModel,
  513. ModelCfg: smallModelCfg,
  514. }, nil
  515. }
  516. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
  517. var opts []anthropic.Option
  518. switch {
  519. case strings.HasPrefix(apiKey, "Bearer "):
  520. // NOTE: Prevent the SDK from picking up the API key from env.
  521. os.Setenv("ANTHROPIC_API_KEY", "")
  522. headers["Authorization"] = apiKey
  523. case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
  524. // NOTE: Prevent the SDK from picking up the API key from env.
  525. os.Setenv("ANTHROPIC_API_KEY", "")
  526. headers["Authorization"] = "Bearer " + apiKey
  527. case apiKey != "":
  528. // X-Api-Key header
  529. opts = append(opts, anthropic.WithAPIKey(apiKey))
  530. }
  531. if len(headers) > 0 {
  532. opts = append(opts, anthropic.WithHeaders(headers))
  533. }
  534. if baseURL != "" {
  535. opts = append(opts, anthropic.WithBaseURL(baseURL))
  536. }
  537. if c.cfg.Options.Debug {
  538. httpClient := log.NewHTTPClient()
  539. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  540. }
  541. return anthropic.New(opts...)
  542. }
  543. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  544. opts := []openai.Option{
  545. openai.WithAPIKey(apiKey),
  546. openai.WithUseResponsesAPI(),
  547. }
  548. if c.cfg.Options.Debug {
  549. httpClient := log.NewHTTPClient()
  550. opts = append(opts, openai.WithHTTPClient(httpClient))
  551. }
  552. if len(headers) > 0 {
  553. opts = append(opts, openai.WithHeaders(headers))
  554. }
  555. if baseURL != "" {
  556. opts = append(opts, openai.WithBaseURL(baseURL))
  557. }
  558. return openai.New(opts...)
  559. }
  560. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  561. opts := []openrouter.Option{
  562. openrouter.WithAPIKey(apiKey),
  563. }
  564. if c.cfg.Options.Debug {
  565. httpClient := log.NewHTTPClient()
  566. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  567. }
  568. if len(headers) > 0 {
  569. opts = append(opts, openrouter.WithHeaders(headers))
  570. }
  571. return openrouter.New(opts...)
  572. }
  573. func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  574. opts := []vercel.Option{
  575. vercel.WithAPIKey(apiKey),
  576. }
  577. if c.cfg.Options.Debug {
  578. httpClient := log.NewHTTPClient()
  579. opts = append(opts, vercel.WithHTTPClient(httpClient))
  580. }
  581. if len(headers) > 0 {
  582. opts = append(opts, vercel.WithHeaders(headers))
  583. }
  584. return vercel.New(opts...)
  585. }
  586. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
  587. opts := []openaicompat.Option{
  588. openaicompat.WithBaseURL(baseURL),
  589. openaicompat.WithAPIKey(apiKey),
  590. }
  591. // Set HTTP client based on provider and debug mode.
  592. var httpClient *http.Client
  593. if providerID == string(catwalk.InferenceProviderCopilot) {
  594. opts = append(opts, openaicompat.WithUseResponsesAPI())
  595. httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
  596. } else if c.cfg.Options.Debug {
  597. httpClient = log.NewHTTPClient()
  598. }
  599. if httpClient != nil {
  600. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  601. }
  602. if len(headers) > 0 {
  603. opts = append(opts, openaicompat.WithHeaders(headers))
  604. }
  605. for extraKey, extraValue := range extraBody {
  606. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  607. }
  608. return openaicompat.New(opts...)
  609. }
  610. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  611. opts := []azure.Option{
  612. azure.WithBaseURL(baseURL),
  613. azure.WithAPIKey(apiKey),
  614. azure.WithUseResponsesAPI(),
  615. }
  616. if c.cfg.Options.Debug {
  617. httpClient := log.NewHTTPClient()
  618. opts = append(opts, azure.WithHTTPClient(httpClient))
  619. }
  620. if options == nil {
  621. options = make(map[string]string)
  622. }
  623. if apiVersion, ok := options["apiVersion"]; ok {
  624. opts = append(opts, azure.WithAPIVersion(apiVersion))
  625. }
  626. if len(headers) > 0 {
  627. opts = append(opts, azure.WithHeaders(headers))
  628. }
  629. return azure.New(opts...)
  630. }
  631. func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
  632. var opts []bedrock.Option
  633. if c.cfg.Options.Debug {
  634. httpClient := log.NewHTTPClient()
  635. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  636. }
  637. if len(headers) > 0 {
  638. opts = append(opts, bedrock.WithHeaders(headers))
  639. }
  640. bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
  641. if bearerToken != "" {
  642. opts = append(opts, bedrock.WithAPIKey(bearerToken))
  643. }
  644. return bedrock.New(opts...)
  645. }
  646. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  647. opts := []google.Option{
  648. google.WithBaseURL(baseURL),
  649. google.WithGeminiAPIKey(apiKey),
  650. }
  651. if c.cfg.Options.Debug {
  652. httpClient := log.NewHTTPClient()
  653. opts = append(opts, google.WithHTTPClient(httpClient))
  654. }
  655. if len(headers) > 0 {
  656. opts = append(opts, google.WithHeaders(headers))
  657. }
  658. return google.New(opts...)
  659. }
  660. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  661. opts := []google.Option{}
  662. if c.cfg.Options.Debug {
  663. httpClient := log.NewHTTPClient()
  664. opts = append(opts, google.WithHTTPClient(httpClient))
  665. }
  666. if len(headers) > 0 {
  667. opts = append(opts, google.WithHeaders(headers))
  668. }
  669. project := options["project"]
  670. location := options["location"]
  671. opts = append(opts, google.WithVertex(project, location))
  672. return google.New(opts...)
  673. }
  674. func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
  675. opts := []hyper.Option{
  676. hyper.WithBaseURL(baseURL),
  677. hyper.WithAPIKey(apiKey),
  678. }
  679. if c.cfg.Options.Debug {
  680. httpClient := log.NewHTTPClient()
  681. opts = append(opts, hyper.WithHTTPClient(httpClient))
  682. }
  683. return hyper.New(opts...)
  684. }
  685. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  686. if model.Think {
  687. return true
  688. }
  689. if model.ProviderOptions == nil {
  690. return false
  691. }
  692. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  693. if err != nil {
  694. return false
  695. }
  696. if opts.Thinking != nil {
  697. return true
  698. }
  699. return false
  700. }
  701. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
  702. headers := maps.Clone(providerCfg.ExtraHeaders)
  703. if headers == nil {
  704. headers = make(map[string]string)
  705. }
  706. // handle special headers for anthropic
  707. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  708. if v, ok := headers["anthropic-beta"]; ok {
  709. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  710. } else {
  711. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  712. }
  713. }
  714. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  715. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  716. switch providerCfg.Type {
  717. case openai.Name:
  718. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  719. case anthropic.Name:
  720. return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
  721. case openrouter.Name:
  722. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  723. case vercel.Name:
  724. return c.buildVercelProvider(baseURL, apiKey, headers)
  725. case azure.Name:
  726. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  727. case bedrock.Name:
  728. return c.buildBedrockProvider(headers)
  729. case google.Name:
  730. return c.buildGoogleProvider(baseURL, apiKey, headers)
  731. case "google-vertex":
  732. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  733. case openaicompat.Name:
  734. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  735. if providerCfg.ExtraBody == nil {
  736. providerCfg.ExtraBody = map[string]any{}
  737. }
  738. providerCfg.ExtraBody["tool_stream"] = true
  739. }
  740. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
  741. case hyper.Name:
  742. return c.buildHyperProvider(baseURL, apiKey)
  743. default:
  744. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  745. }
  746. }
  747. func isExactoSupported(modelID string) bool {
  748. supportedModels := []string{
  749. "moonshotai/kimi-k2-0905",
  750. "deepseek/deepseek-v3.1-terminus",
  751. "z-ai/glm-4.6",
  752. "openai/gpt-oss-120b",
  753. "qwen/qwen3-coder",
  754. }
  755. return slices.Contains(supportedModels, modelID)
  756. }
  757. func (c *coordinator) Cancel(sessionID string) {
  758. c.currentAgent.Cancel(sessionID)
  759. }
  760. func (c *coordinator) CancelAll() {
  761. c.currentAgent.CancelAll()
  762. }
  763. func (c *coordinator) ClearQueue(sessionID string) {
  764. c.currentAgent.ClearQueue(sessionID)
  765. }
  766. func (c *coordinator) IsBusy() bool {
  767. return c.currentAgent.IsBusy()
  768. }
  769. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  770. return c.currentAgent.IsSessionBusy(sessionID)
  771. }
  772. func (c *coordinator) Model() Model {
  773. return c.currentAgent.Model()
  774. }
  775. func (c *coordinator) UpdateModels(ctx context.Context) error {
  776. // build the models again so we make sure we get the latest config
  777. large, small, err := c.buildAgentModels(ctx, false)
  778. if err != nil {
  779. return err
  780. }
  781. c.currentAgent.SetModels(large, small)
  782. agentCfg, ok := c.cfg.Agents[config.AgentCoder]
  783. if !ok {
  784. return errors.New("coder agent not configured")
  785. }
  786. tools, err := c.buildTools(ctx, agentCfg)
  787. if err != nil {
  788. return err
  789. }
  790. c.currentAgent.SetTools(tools)
  791. return nil
  792. }
  793. func (c *coordinator) QueuedPrompts(sessionID string) int {
  794. return c.currentAgent.QueuedPrompts(sessionID)
  795. }
  796. func (c *coordinator) QueuedPromptsList(sessionID string) []string {
  797. return c.currentAgent.QueuedPromptsList(sessionID)
  798. }
  799. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  800. providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  801. if !ok {
  802. return errors.New("model provider not configured")
  803. }
  804. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  805. }
  806. func (c *coordinator) isUnauthorized(err error) bool {
  807. var providerErr *fantasy.ProviderError
  808. return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
  809. }
  810. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  811. if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
  812. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  813. return err
  814. }
  815. if err := c.UpdateModels(ctx); err != nil {
  816. return err
  817. }
  818. return nil
  819. }
  820. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  821. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  822. if err != nil {
  823. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  824. return err
  825. }
  826. providerCfg.APIKey = newAPIKey
  827. c.cfg.Providers.Set(providerCfg.ID, providerCfg)
  828. if err := c.UpdateModels(ctx); err != nil {
  829. return err
  830. }
  831. return nil
  832. }
  833. // subAgentParams holds the parameters for running a sub-agent.
  834. type subAgentParams struct {
  835. Agent SessionAgent
  836. SessionID string
  837. AgentMessageID string
  838. ToolCallID string
  839. Prompt string
  840. SessionTitle string
  841. // SessionSetup is an optional callback invoked after session creation
  842. // but before agent execution, for custom session configuration.
  843. SessionSetup func(sessionID string)
  844. }
  845. // runSubAgent runs a sub-agent and handles session management and cost accumulation.
  846. // It creates a sub-session, runs the agent with the given prompt, and propagates
  847. // the cost to the parent session.
  848. func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
  849. // Create sub-session
  850. agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
  851. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
  852. if err != nil {
  853. return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
  854. }
  855. // Call session setup function if provided
  856. if params.SessionSetup != nil {
  857. params.SessionSetup(session.ID)
  858. }
  859. // Get model configuration
  860. model := params.Agent.Model()
  861. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  862. if model.ModelCfg.MaxTokens != 0 {
  863. maxTokens = model.ModelCfg.MaxTokens
  864. }
  865. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  866. if !ok {
  867. return fantasy.ToolResponse{}, errors.New("model provider not configured")
  868. }
  869. // Run the agent
  870. result, err := params.Agent.Run(ctx, SessionAgentCall{
  871. SessionID: session.ID,
  872. Prompt: params.Prompt,
  873. MaxOutputTokens: maxTokens,
  874. ProviderOptions: getProviderOptions(model, providerCfg),
  875. Temperature: model.ModelCfg.Temperature,
  876. TopP: model.ModelCfg.TopP,
  877. TopK: model.ModelCfg.TopK,
  878. FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
  879. PresencePenalty: model.ModelCfg.PresencePenalty,
  880. })
  881. if err != nil {
  882. return fantasy.NewTextErrorResponse("error generating response"), nil
  883. }
  884. // Update parent session cost
  885. if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
  886. return fantasy.ToolResponse{}, err
  887. }
  888. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  889. }
  890. // updateParentSessionCost accumulates the cost from a child session to its parent session.
  891. func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
  892. childSession, err := c.sessions.Get(ctx, childSessionID)
  893. if err != nil {
  894. return fmt.Errorf("get child session: %w", err)
  895. }
  896. parentSession, err := c.sessions.Get(ctx, parentSessionID)
  897. if err != nil {
  898. return fmt.Errorf("get parent session: %w", err)
  899. }
  900. parentSession.Cost += childSession.Cost
  901. if _, err := c.sessions.Save(ctx, parentSession); err != nil {
  902. return fmt.Errorf("save parent session: %w", err)
  903. }
  904. return nil
  905. }