coordinator.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "slices"
  15. "strings"
  16. "charm.land/catwalk/pkg/catwalk"
  17. "charm.land/fantasy"
  18. "github.com/charmbracelet/crush/internal/agent/hyper"
  19. "github.com/charmbracelet/crush/internal/agent/prompt"
  20. "github.com/charmbracelet/crush/internal/agent/tools"
  21. "github.com/charmbracelet/crush/internal/config"
  22. "github.com/charmbracelet/crush/internal/filetracker"
  23. "github.com/charmbracelet/crush/internal/history"
  24. "github.com/charmbracelet/crush/internal/log"
  25. "github.com/charmbracelet/crush/internal/lsp"
  26. "github.com/charmbracelet/crush/internal/message"
  27. "github.com/charmbracelet/crush/internal/oauth/copilot"
  28. "github.com/charmbracelet/crush/internal/permission"
  29. "github.com/charmbracelet/crush/internal/session"
  30. "golang.org/x/sync/errgroup"
  31. "charm.land/fantasy/providers/anthropic"
  32. "charm.land/fantasy/providers/azure"
  33. "charm.land/fantasy/providers/bedrock"
  34. "charm.land/fantasy/providers/google"
  35. "charm.land/fantasy/providers/openai"
  36. "charm.land/fantasy/providers/openaicompat"
  37. "charm.land/fantasy/providers/openrouter"
  38. "charm.land/fantasy/providers/vercel"
  39. openaisdk "github.com/openai/openai-go/v2/option"
  40. "github.com/qjebbs/go-jsons"
  41. )
  42. type Coordinator interface {
  43. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  44. // SetMainAgent(string)
  45. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  46. Cancel(sessionID string)
  47. CancelAll()
  48. IsSessionBusy(sessionID string) bool
  49. IsBusy() bool
  50. QueuedPrompts(sessionID string) int
  51. QueuedPromptsList(sessionID string) []string
  52. ClearQueue(sessionID string)
  53. Summarize(context.Context, string) error
  54. Model() Model
  55. UpdateModels(ctx context.Context) error
  56. }
// coordinator is the default Coordinator implementation. It currently wraps a
// single coder agent built by NewCoordinator.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents indexes built agents by their config agent name (e.g. config.AgentCoder).
	agents map[string]SessionAgent
	// readyWg tracks background setup started by buildAgent (system prompt and
	// tool construction); Run waits on it before dispatching a call.
	readyWg errgroup.Group
}
  69. func NewCoordinator(
  70. ctx context.Context,
  71. cfg *config.Config,
  72. sessions session.Service,
  73. messages message.Service,
  74. permissions permission.Service,
  75. history history.Service,
  76. filetracker filetracker.Service,
  77. lspManager *lsp.Manager,
  78. ) (Coordinator, error) {
  79. c := &coordinator{
  80. cfg: cfg,
  81. sessions: sessions,
  82. messages: messages,
  83. permissions: permissions,
  84. history: history,
  85. filetracker: filetracker,
  86. lspManager: lspManager,
  87. agents: make(map[string]SessionAgent),
  88. }
  89. agentCfg, ok := cfg.Agents[config.AgentCoder]
  90. if !ok {
  91. return nil, errors.New("coder agent not configured")
  92. }
  93. // TODO: make this dynamic when we support multiple agents
  94. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  95. if err != nil {
  96. return nil, err
  97. }
  98. agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
  99. if err != nil {
  100. return nil, err
  101. }
  102. c.currentAgent = agent
  103. c.agents[config.AgentCoder] = agent
  104. return c, nil
  105. }
  106. // Run implements Coordinator.
  107. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  108. if err := c.readyWg.Wait(); err != nil {
  109. return nil, err
  110. }
  111. // refresh models before each run
  112. if err := c.UpdateModels(ctx); err != nil {
  113. return nil, fmt.Errorf("failed to update models: %w", err)
  114. }
  115. model := c.currentAgent.Model()
  116. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  117. if model.ModelCfg.MaxTokens != 0 {
  118. maxTokens = model.ModelCfg.MaxTokens
  119. }
  120. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  121. // filter out image attachments
  122. filteredAttachments := make([]message.Attachment, 0, len(attachments))
  123. for _, att := range attachments {
  124. if att.IsText() {
  125. filteredAttachments = append(filteredAttachments, att)
  126. }
  127. }
  128. attachments = filteredAttachments
  129. }
  130. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  131. if !ok {
  132. return nil, errors.New("model provider not configured")
  133. }
  134. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  135. if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
  136. slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
  137. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  138. return nil, err
  139. }
  140. }
  141. run := func() (*fantasy.AgentResult, error) {
  142. return c.currentAgent.Run(ctx, SessionAgentCall{
  143. SessionID: sessionID,
  144. Prompt: prompt,
  145. Attachments: attachments,
  146. MaxOutputTokens: maxTokens,
  147. ProviderOptions: mergedOptions,
  148. Temperature: temp,
  149. TopP: topP,
  150. TopK: topK,
  151. FrequencyPenalty: freqPenalty,
  152. PresencePenalty: presPenalty,
  153. })
  154. }
  155. result, originalErr := run()
  156. if c.isUnauthorized(originalErr) {
  157. switch {
  158. case providerCfg.OAuthToken != nil:
  159. slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  160. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  161. return nil, originalErr
  162. }
  163. slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  164. return run()
  165. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  166. slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  167. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  168. return nil, originalErr
  169. }
  170. slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
  171. return run()
  172. }
  173. }
  174. return result, originalErr
  175. }
  176. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  177. options := fantasy.ProviderOptions{}
  178. cfgOpts := []byte("{}")
  179. providerCfgOpts := []byte("{}")
  180. catwalkOpts := []byte("{}")
  181. if model.ModelCfg.ProviderOptions != nil {
  182. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  183. if err == nil {
  184. cfgOpts = data
  185. }
  186. }
  187. if providerCfg.ProviderOptions != nil {
  188. data, err := json.Marshal(providerCfg.ProviderOptions)
  189. if err == nil {
  190. providerCfgOpts = data
  191. }
  192. }
  193. if model.CatwalkCfg.Options.ProviderOptions != nil {
  194. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  195. if err == nil {
  196. catwalkOpts = data
  197. }
  198. }
  199. readers := []io.Reader{
  200. bytes.NewReader(catwalkOpts),
  201. bytes.NewReader(providerCfgOpts),
  202. bytes.NewReader(cfgOpts),
  203. }
  204. got, err := jsons.Merge(readers)
  205. if err != nil {
  206. slog.Error("Could not merge call config", "err", err)
  207. return options
  208. }
  209. mergedOptions := make(map[string]any)
  210. err = json.Unmarshal([]byte(got), &mergedOptions)
  211. if err != nil {
  212. slog.Error("Could not create config for call", "err", err)
  213. return options
  214. }
  215. providerType := providerCfg.Type
  216. if providerType == "hyper" {
  217. if strings.Contains(model.CatwalkCfg.ID, "claude") {
  218. providerType = anthropic.Name
  219. } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
  220. providerType = openai.Name
  221. } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
  222. providerType = google.Name
  223. } else {
  224. providerType = openaicompat.Name
  225. }
  226. }
  227. switch providerType {
  228. case openai.Name, azure.Name:
  229. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  230. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  231. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  232. }
  233. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  234. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  235. mergedOptions["reasoning_summary"] = "auto"
  236. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  237. }
  238. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  239. if err == nil {
  240. options[openai.Name] = parsed
  241. }
  242. } else {
  243. parsed, err := openai.ParseOptions(mergedOptions)
  244. if err == nil {
  245. options[openai.Name] = parsed
  246. }
  247. }
  248. case anthropic.Name:
  249. _, hasThink := mergedOptions["thinking"]
  250. if !hasThink && model.ModelCfg.Think {
  251. mergedOptions["thinking"] = map[string]any{
  252. // TODO: kujtim see if we need to make this dynamic
  253. "budget_tokens": 2000,
  254. }
  255. }
  256. parsed, err := anthropic.ParseOptions(mergedOptions)
  257. if err == nil {
  258. options[anthropic.Name] = parsed
  259. }
  260. case openrouter.Name:
  261. _, hasReasoning := mergedOptions["reasoning"]
  262. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  263. mergedOptions["reasoning"] = map[string]any{
  264. "enabled": true,
  265. "effort": model.ModelCfg.ReasoningEffort,
  266. }
  267. }
  268. parsed, err := openrouter.ParseOptions(mergedOptions)
  269. if err == nil {
  270. options[openrouter.Name] = parsed
  271. }
  272. case vercel.Name:
  273. _, hasReasoning := mergedOptions["reasoning"]
  274. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  275. mergedOptions["reasoning"] = map[string]any{
  276. "enabled": true,
  277. "effort": model.ModelCfg.ReasoningEffort,
  278. }
  279. }
  280. parsed, err := vercel.ParseOptions(mergedOptions)
  281. if err == nil {
  282. options[vercel.Name] = parsed
  283. }
  284. case google.Name:
  285. _, hasReasoning := mergedOptions["thinking_config"]
  286. if !hasReasoning {
  287. mergedOptions["thinking_config"] = map[string]any{
  288. "thinking_budget": 2000,
  289. "include_thoughts": true,
  290. }
  291. }
  292. parsed, err := google.ParseOptions(mergedOptions)
  293. if err == nil {
  294. options[google.Name] = parsed
  295. }
  296. case openaicompat.Name:
  297. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  298. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  299. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  300. }
  301. parsed, err := openaicompat.ParseOptions(mergedOptions)
  302. if err == nil {
  303. options[openaicompat.Name] = parsed
  304. }
  305. }
  306. return options
  307. }
  308. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  309. modelOptions := getProviderOptions(model, cfg)
  310. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  311. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  312. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  313. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  314. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  315. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  316. }
  317. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
  318. large, small, err := c.buildAgentModels(ctx, isSubAgent)
  319. if err != nil {
  320. return nil, err
  321. }
  322. largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
  323. result := NewSessionAgent(SessionAgentOptions{
  324. large,
  325. small,
  326. largeProviderCfg.SystemPromptPrefix,
  327. "",
  328. isSubAgent,
  329. c.cfg.Options.DisableAutoSummarize,
  330. c.permissions.SkipRequests(),
  331. c.sessions,
  332. c.messages,
  333. nil,
  334. })
  335. c.readyWg.Go(func() error {
  336. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
  337. if err != nil {
  338. return err
  339. }
  340. result.SetSystemPrompt(systemPrompt)
  341. return nil
  342. })
  343. c.readyWg.Go(func() error {
  344. tools, err := c.buildTools(ctx, agent)
  345. if err != nil {
  346. return err
  347. }
  348. result.SetTools(tools)
  349. return nil
  350. })
  351. return result, nil
  352. }
  353. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  354. var allTools []fantasy.AgentTool
  355. if slices.Contains(agent.AllowedTools, AgentToolName) {
  356. agentTool, err := c.agentTool(ctx)
  357. if err != nil {
  358. return nil, err
  359. }
  360. allTools = append(allTools, agentTool)
  361. }
  362. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  363. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  364. if err != nil {
  365. return nil, err
  366. }
  367. allTools = append(allTools, agenticFetchTool)
  368. }
  369. // Get the model name for the agent
  370. modelName := ""
  371. if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
  372. if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  373. modelName = model.Name
  374. }
  375. }
  376. allTools = append(allTools,
  377. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
  378. tools.NewJobOutputTool(),
  379. tools.NewJobKillTool(),
  380. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  381. tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  382. tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  383. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  384. tools.NewGlobTool(c.cfg.WorkingDir()),
  385. tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
  386. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
  387. tools.NewSourcegraphTool(nil),
  388. tools.NewTodosTool(c.sessions),
  389. tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
  390. tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
  391. )
  392. // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
  393. if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
  394. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
  395. }
  396. if len(c.cfg.MCP) > 0 {
  397. allTools = append(
  398. allTools,
  399. tools.NewListMCPResourcesTool(c.cfg, c.permissions),
  400. tools.NewReadMCPResourceTool(c.cfg, c.permissions),
  401. )
  402. }
  403. var filteredTools []fantasy.AgentTool
  404. for _, tool := range allTools {
  405. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  406. filteredTools = append(filteredTools, tool)
  407. }
  408. }
  409. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
  410. if agent.AllowedMCP == nil {
  411. // No MCP restrictions
  412. filteredTools = append(filteredTools, tool)
  413. continue
  414. }
  415. if len(agent.AllowedMCP) == 0 {
  416. // No MCPs allowed
  417. slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  418. break
  419. }
  420. for mcp, tools := range agent.AllowedMCP {
  421. if mcp != tool.MCP() {
  422. continue
  423. }
  424. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  425. filteredTools = append(filteredTools, tool)
  426. }
  427. }
  428. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  429. }
  430. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  431. return strings.Compare(a.Info().Name, b.Info().Name)
  432. })
  433. return filteredTools, nil
  434. }
  435. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  436. func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
  437. largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
  438. if !ok {
  439. return Model{}, Model{}, errors.New("large model not selected")
  440. }
  441. smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
  442. if !ok {
  443. return Model{}, Model{}, errors.New("small model not selected")
  444. }
  445. largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
  446. if !ok {
  447. return Model{}, Model{}, errors.New("large model provider not configured")
  448. }
  449. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
  450. if err != nil {
  451. return Model{}, Model{}, err
  452. }
  453. smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
  454. if !ok {
  455. return Model{}, Model{}, errors.New("large model provider not configured")
  456. }
  457. smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, true)
  458. if err != nil {
  459. return Model{}, Model{}, err
  460. }
  461. var largeCatwalkModel *catwalk.Model
  462. var smallCatwalkModel *catwalk.Model
  463. for _, m := range largeProviderCfg.Models {
  464. if m.ID == largeModelCfg.Model {
  465. largeCatwalkModel = &m
  466. }
  467. }
  468. for _, m := range smallProviderCfg.Models {
  469. if m.ID == smallModelCfg.Model {
  470. smallCatwalkModel = &m
  471. }
  472. }
  473. if largeCatwalkModel == nil {
  474. return Model{}, Model{}, errors.New("large model not found in provider config")
  475. }
  476. if smallCatwalkModel == nil {
  477. return Model{}, Model{}, errors.New("small model not found in provider config")
  478. }
  479. largeModelID := largeModelCfg.Model
  480. smallModelID := smallModelCfg.Model
  481. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  482. largeModelID += ":exacto"
  483. }
  484. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  485. smallModelID += ":exacto"
  486. }
  487. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  488. if err != nil {
  489. return Model{}, Model{}, err
  490. }
  491. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  492. if err != nil {
  493. return Model{}, Model{}, err
  494. }
  495. return Model{
  496. Model: largeModel,
  497. CatwalkCfg: *largeCatwalkModel,
  498. ModelCfg: largeModelCfg,
  499. }, Model{
  500. Model: smallModel,
  501. CatwalkCfg: *smallCatwalkModel,
  502. ModelCfg: smallModelCfg,
  503. }, nil
  504. }
  505. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  506. var opts []anthropic.Option
  507. if strings.HasPrefix(apiKey, "Bearer ") {
  508. // NOTE: Prevent the SDK from picking up the API key from env.
  509. os.Setenv("ANTHROPIC_API_KEY", "")
  510. headers["Authorization"] = apiKey
  511. } else if apiKey != "" {
  512. // X-Api-Key header
  513. opts = append(opts, anthropic.WithAPIKey(apiKey))
  514. }
  515. if len(headers) > 0 {
  516. opts = append(opts, anthropic.WithHeaders(headers))
  517. }
  518. if baseURL != "" {
  519. opts = append(opts, anthropic.WithBaseURL(baseURL))
  520. }
  521. if c.cfg.Options.Debug {
  522. httpClient := log.NewHTTPClient()
  523. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  524. }
  525. return anthropic.New(opts...)
  526. }
  527. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  528. opts := []openai.Option{
  529. openai.WithAPIKey(apiKey),
  530. openai.WithUseResponsesAPI(),
  531. }
  532. if c.cfg.Options.Debug {
  533. httpClient := log.NewHTTPClient()
  534. opts = append(opts, openai.WithHTTPClient(httpClient))
  535. }
  536. if len(headers) > 0 {
  537. opts = append(opts, openai.WithHeaders(headers))
  538. }
  539. if baseURL != "" {
  540. opts = append(opts, openai.WithBaseURL(baseURL))
  541. }
  542. return openai.New(opts...)
  543. }
  544. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  545. opts := []openrouter.Option{
  546. openrouter.WithAPIKey(apiKey),
  547. }
  548. if c.cfg.Options.Debug {
  549. httpClient := log.NewHTTPClient()
  550. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  551. }
  552. if len(headers) > 0 {
  553. opts = append(opts, openrouter.WithHeaders(headers))
  554. }
  555. return openrouter.New(opts...)
  556. }
  557. func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  558. opts := []vercel.Option{
  559. vercel.WithAPIKey(apiKey),
  560. }
  561. if c.cfg.Options.Debug {
  562. httpClient := log.NewHTTPClient()
  563. opts = append(opts, vercel.WithHTTPClient(httpClient))
  564. }
  565. if len(headers) > 0 {
  566. opts = append(opts, vercel.WithHeaders(headers))
  567. }
  568. return vercel.New(opts...)
  569. }
  570. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
  571. opts := []openaicompat.Option{
  572. openaicompat.WithBaseURL(baseURL),
  573. openaicompat.WithAPIKey(apiKey),
  574. }
  575. // Set HTTP client based on provider and debug mode.
  576. var httpClient *http.Client
  577. if providerID == string(catwalk.InferenceProviderCopilot) {
  578. opts = append(opts, openaicompat.WithUseResponsesAPI())
  579. httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
  580. } else if c.cfg.Options.Debug {
  581. httpClient = log.NewHTTPClient()
  582. }
  583. if httpClient != nil {
  584. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  585. }
  586. if len(headers) > 0 {
  587. opts = append(opts, openaicompat.WithHeaders(headers))
  588. }
  589. for extraKey, extraValue := range extraBody {
  590. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  591. }
  592. return openaicompat.New(opts...)
  593. }
  594. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  595. opts := []azure.Option{
  596. azure.WithBaseURL(baseURL),
  597. azure.WithAPIKey(apiKey),
  598. azure.WithUseResponsesAPI(),
  599. }
  600. if c.cfg.Options.Debug {
  601. httpClient := log.NewHTTPClient()
  602. opts = append(opts, azure.WithHTTPClient(httpClient))
  603. }
  604. if options == nil {
  605. options = make(map[string]string)
  606. }
  607. if apiVersion, ok := options["apiVersion"]; ok {
  608. opts = append(opts, azure.WithAPIVersion(apiVersion))
  609. }
  610. if len(headers) > 0 {
  611. opts = append(opts, azure.WithHeaders(headers))
  612. }
  613. return azure.New(opts...)
  614. }
  615. func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
  616. var opts []bedrock.Option
  617. if c.cfg.Options.Debug {
  618. httpClient := log.NewHTTPClient()
  619. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  620. }
  621. if len(headers) > 0 {
  622. opts = append(opts, bedrock.WithHeaders(headers))
  623. }
  624. bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
  625. if bearerToken != "" {
  626. opts = append(opts, bedrock.WithAPIKey(bearerToken))
  627. }
  628. return bedrock.New(opts...)
  629. }
  630. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  631. opts := []google.Option{
  632. google.WithBaseURL(baseURL),
  633. google.WithGeminiAPIKey(apiKey),
  634. }
  635. if c.cfg.Options.Debug {
  636. httpClient := log.NewHTTPClient()
  637. opts = append(opts, google.WithHTTPClient(httpClient))
  638. }
  639. if len(headers) > 0 {
  640. opts = append(opts, google.WithHeaders(headers))
  641. }
  642. return google.New(opts...)
  643. }
  644. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  645. opts := []google.Option{}
  646. if c.cfg.Options.Debug {
  647. httpClient := log.NewHTTPClient()
  648. opts = append(opts, google.WithHTTPClient(httpClient))
  649. }
  650. if len(headers) > 0 {
  651. opts = append(opts, google.WithHeaders(headers))
  652. }
  653. project := options["project"]
  654. location := options["location"]
  655. opts = append(opts, google.WithVertex(project, location))
  656. return google.New(opts...)
  657. }
  658. func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
  659. opts := []hyper.Option{
  660. hyper.WithBaseURL(baseURL),
  661. hyper.WithAPIKey(apiKey),
  662. }
  663. if c.cfg.Options.Debug {
  664. httpClient := log.NewHTTPClient()
  665. opts = append(opts, hyper.WithHTTPClient(httpClient))
  666. }
  667. return hyper.New(opts...)
  668. }
  669. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  670. if model.Think {
  671. return true
  672. }
  673. if model.ProviderOptions == nil {
  674. return false
  675. }
  676. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  677. if err != nil {
  678. return false
  679. }
  680. if opts.Thinking != nil {
  681. return true
  682. }
  683. return false
  684. }
  685. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
  686. headers := maps.Clone(providerCfg.ExtraHeaders)
  687. if headers == nil {
  688. headers = make(map[string]string)
  689. }
  690. // handle special headers for anthropic
  691. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  692. if v, ok := headers["anthropic-beta"]; ok {
  693. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  694. } else {
  695. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  696. }
  697. }
  698. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  699. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  700. switch providerCfg.Type {
  701. case openai.Name:
  702. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  703. case anthropic.Name:
  704. return c.buildAnthropicProvider(baseURL, apiKey, headers)
  705. case openrouter.Name:
  706. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  707. case vercel.Name:
  708. return c.buildVercelProvider(baseURL, apiKey, headers)
  709. case azure.Name:
  710. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  711. case bedrock.Name:
  712. return c.buildBedrockProvider(headers)
  713. case google.Name:
  714. return c.buildGoogleProvider(baseURL, apiKey, headers)
  715. case "google-vertex":
  716. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  717. case openaicompat.Name:
  718. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  719. if providerCfg.ExtraBody == nil {
  720. providerCfg.ExtraBody = map[string]any{}
  721. }
  722. providerCfg.ExtraBody["tool_stream"] = true
  723. }
  724. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
  725. case hyper.Name:
  726. return c.buildHyperProvider(baseURL, apiKey)
  727. default:
  728. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  729. }
  730. }
  731. func isExactoSupported(modelID string) bool {
  732. supportedModels := []string{
  733. "moonshotai/kimi-k2-0905",
  734. "deepseek/deepseek-v3.1-terminus",
  735. "z-ai/glm-4.6",
  736. "openai/gpt-oss-120b",
  737. "qwen/qwen3-coder",
  738. }
  739. return slices.Contains(supportedModels, modelID)
  740. }
  741. func (c *coordinator) Cancel(sessionID string) {
  742. c.currentAgent.Cancel(sessionID)
  743. }
  744. func (c *coordinator) CancelAll() {
  745. c.currentAgent.CancelAll()
  746. }
  747. func (c *coordinator) ClearQueue(sessionID string) {
  748. c.currentAgent.ClearQueue(sessionID)
  749. }
  750. func (c *coordinator) IsBusy() bool {
  751. return c.currentAgent.IsBusy()
  752. }
  753. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  754. return c.currentAgent.IsSessionBusy(sessionID)
  755. }
  756. func (c *coordinator) Model() Model {
  757. return c.currentAgent.Model()
  758. }
  759. func (c *coordinator) UpdateModels(ctx context.Context) error {
  760. // build the models again so we make sure we get the latest config
  761. large, small, err := c.buildAgentModels(ctx, false)
  762. if err != nil {
  763. return err
  764. }
  765. c.currentAgent.SetModels(large, small)
  766. agentCfg, ok := c.cfg.Agents[config.AgentCoder]
  767. if !ok {
  768. return errors.New("coder agent not configured")
  769. }
  770. tools, err := c.buildTools(ctx, agentCfg)
  771. if err != nil {
  772. return err
  773. }
  774. c.currentAgent.SetTools(tools)
  775. return nil
  776. }
  777. func (c *coordinator) QueuedPrompts(sessionID string) int {
  778. return c.currentAgent.QueuedPrompts(sessionID)
  779. }
  780. func (c *coordinator) QueuedPromptsList(sessionID string) []string {
  781. return c.currentAgent.QueuedPromptsList(sessionID)
  782. }
  783. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  784. providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  785. if !ok {
  786. return errors.New("model provider not configured")
  787. }
  788. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  789. }
  790. func (c *coordinator) isUnauthorized(err error) bool {
  791. var providerErr *fantasy.ProviderError
  792. return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
  793. }
  794. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  795. if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
  796. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  797. return err
  798. }
  799. if err := c.UpdateModels(ctx); err != nil {
  800. return err
  801. }
  802. return nil
  803. }
  804. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  805. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  806. if err != nil {
  807. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  808. return err
  809. }
  810. providerCfg.APIKey = newAPIKey
  811. c.cfg.Providers.Set(providerCfg.ID, providerCfg)
  812. if err := c.UpdateModels(ctx); err != nil {
  813. return err
  814. }
  815. return nil
  816. }