coordinator.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "slices"
  15. "strings"
  16. "charm.land/fantasy"
  17. "github.com/charmbracelet/catwalk/pkg/catwalk"
  18. "github.com/charmbracelet/crush/internal/agent/prompt"
  19. "github.com/charmbracelet/crush/internal/agent/tools"
  20. "github.com/charmbracelet/crush/internal/config"
  21. "github.com/charmbracelet/crush/internal/csync"
  22. "github.com/charmbracelet/crush/internal/history"
  23. "github.com/charmbracelet/crush/internal/log"
  24. "github.com/charmbracelet/crush/internal/lsp"
  25. "github.com/charmbracelet/crush/internal/message"
  26. "github.com/charmbracelet/crush/internal/permission"
  27. "github.com/charmbracelet/crush/internal/session"
  28. "golang.org/x/sync/errgroup"
  29. "charm.land/fantasy/providers/anthropic"
  30. "charm.land/fantasy/providers/azure"
  31. "charm.land/fantasy/providers/bedrock"
  32. "charm.land/fantasy/providers/google"
  33. "charm.land/fantasy/providers/openai"
  34. "charm.land/fantasy/providers/openaicompat"
  35. "charm.land/fantasy/providers/openrouter"
  36. openaisdk "github.com/openai/openai-go/v2/option"
  37. "github.com/qjebbs/go-jsons"
  38. )
  39. type Coordinator interface {
  40. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  41. // SetMainAgent(string)
  42. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  43. Cancel(sessionID string)
  44. CancelAll()
  45. IsSessionBusy(sessionID string) bool
  46. IsBusy() bool
  47. QueuedPrompts(sessionID string) int
  48. QueuedPromptsList(sessionID string) []string
  49. ClearQueue(sessionID string)
  50. Summarize(context.Context, string) error
  51. Model() Model
  52. UpdateModels(ctx context.Context) error
  53. }
  54. type coordinator struct {
  55. cfg *config.Config
  56. sessions session.Service
  57. messages message.Service
  58. permissions permission.Service
  59. history history.Service
  60. lspClients *csync.Map[string, *lsp.Client]
  61. currentAgent SessionAgent
  62. agents map[string]SessionAgent
  63. readyWg errgroup.Group
  64. }
  65. func NewCoordinator(
  66. ctx context.Context,
  67. cfg *config.Config,
  68. sessions session.Service,
  69. messages message.Service,
  70. permissions permission.Service,
  71. history history.Service,
  72. lspClients *csync.Map[string, *lsp.Client],
  73. ) (Coordinator, error) {
  74. c := &coordinator{
  75. cfg: cfg,
  76. sessions: sessions,
  77. messages: messages,
  78. permissions: permissions,
  79. history: history,
  80. lspClients: lspClients,
  81. agents: make(map[string]SessionAgent),
  82. }
  83. agentCfg, ok := cfg.Agents[config.AgentCoder]
  84. if !ok {
  85. return nil, errors.New("coder agent not configured")
  86. }
  87. // TODO: make this dynamic when we support multiple agents
  88. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  89. if err != nil {
  90. return nil, err
  91. }
  92. agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
  93. if err != nil {
  94. return nil, err
  95. }
  96. c.currentAgent = agent
  97. c.agents[config.AgentCoder] = agent
  98. return c, nil
  99. }
  100. // Run implements Coordinator.
  101. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  102. if err := c.readyWg.Wait(); err != nil {
  103. return nil, err
  104. }
  105. model := c.currentAgent.Model()
  106. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  107. if model.ModelCfg.MaxTokens != 0 {
  108. maxTokens = model.ModelCfg.MaxTokens
  109. }
  110. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  111. attachments = nil
  112. }
  113. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  114. if !ok {
  115. return nil, errors.New("model provider not configured")
  116. }
  117. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  118. run := func() (*fantasy.AgentResult, error) {
  119. return c.currentAgent.Run(ctx, SessionAgentCall{
  120. SessionID: sessionID,
  121. Prompt: prompt,
  122. Attachments: attachments,
  123. MaxOutputTokens: maxTokens,
  124. ProviderOptions: mergedOptions,
  125. Temperature: temp,
  126. TopP: topP,
  127. TopK: topK,
  128. FrequencyPenalty: freqPenalty,
  129. PresencePenalty: presPenalty,
  130. })
  131. }
  132. result, originalErr := run()
  133. if c.isUnauthorized(originalErr) {
  134. switch {
  135. case providerCfg.OAuthToken != nil:
  136. slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  137. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  138. return nil, originalErr
  139. }
  140. slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  141. return run()
  142. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  143. slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  144. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  145. return nil, originalErr
  146. }
  147. slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
  148. return run()
  149. }
  150. }
  151. return result, originalErr
  152. }
  153. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  154. options := fantasy.ProviderOptions{}
  155. cfgOpts := []byte("{}")
  156. providerCfgOpts := []byte("{}")
  157. catwalkOpts := []byte("{}")
  158. if model.ModelCfg.ProviderOptions != nil {
  159. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  160. if err == nil {
  161. cfgOpts = data
  162. }
  163. }
  164. if providerCfg.ProviderOptions != nil {
  165. data, err := json.Marshal(providerCfg.ProviderOptions)
  166. if err == nil {
  167. providerCfgOpts = data
  168. }
  169. }
  170. if model.CatwalkCfg.Options.ProviderOptions != nil {
  171. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  172. if err == nil {
  173. catwalkOpts = data
  174. }
  175. }
  176. readers := []io.Reader{
  177. bytes.NewReader(catwalkOpts),
  178. bytes.NewReader(providerCfgOpts),
  179. bytes.NewReader(cfgOpts),
  180. }
  181. got, err := jsons.Merge(readers)
  182. if err != nil {
  183. slog.Error("Could not merge call config", "err", err)
  184. return options
  185. }
  186. mergedOptions := make(map[string]any)
  187. err = json.Unmarshal([]byte(got), &mergedOptions)
  188. if err != nil {
  189. slog.Error("Could not create config for call", "err", err)
  190. return options
  191. }
  192. switch providerCfg.Type {
  193. case openai.Name, azure.Name:
  194. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  195. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  196. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  197. }
  198. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  199. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  200. mergedOptions["reasoning_summary"] = "auto"
  201. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  202. }
  203. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  204. if err == nil {
  205. options[openai.Name] = parsed
  206. }
  207. } else {
  208. parsed, err := openai.ParseOptions(mergedOptions)
  209. if err == nil {
  210. options[openai.Name] = parsed
  211. }
  212. }
  213. case anthropic.Name:
  214. _, hasThink := mergedOptions["thinking"]
  215. if !hasThink && model.ModelCfg.Think {
  216. mergedOptions["thinking"] = map[string]any{
  217. // TODO: kujtim see if we need to make this dynamic
  218. "budget_tokens": 2000,
  219. }
  220. }
  221. parsed, err := anthropic.ParseOptions(mergedOptions)
  222. if err == nil {
  223. options[anthropic.Name] = parsed
  224. }
  225. case openrouter.Name:
  226. _, hasReasoning := mergedOptions["reasoning"]
  227. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  228. mergedOptions["reasoning"] = map[string]any{
  229. "enabled": true,
  230. "effort": model.ModelCfg.ReasoningEffort,
  231. }
  232. }
  233. parsed, err := openrouter.ParseOptions(mergedOptions)
  234. if err == nil {
  235. options[openrouter.Name] = parsed
  236. }
  237. case google.Name:
  238. _, hasReasoning := mergedOptions["thinking_config"]
  239. if !hasReasoning {
  240. mergedOptions["thinking_config"] = map[string]any{
  241. "thinking_budget": 2000,
  242. "include_thoughts": true,
  243. }
  244. }
  245. parsed, err := google.ParseOptions(mergedOptions)
  246. if err == nil {
  247. options[google.Name] = parsed
  248. }
  249. case openaicompat.Name:
  250. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  251. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  252. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  253. }
  254. parsed, err := openaicompat.ParseOptions(mergedOptions)
  255. if err == nil {
  256. options[openaicompat.Name] = parsed
  257. }
  258. }
  259. return options
  260. }
  261. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  262. modelOptions := getProviderOptions(model, cfg)
  263. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  264. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  265. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  266. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  267. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  268. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  269. }
  270. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
  271. large, small, err := c.buildAgentModels(ctx)
  272. if err != nil {
  273. return nil, err
  274. }
  275. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
  276. if err != nil {
  277. return nil, err
  278. }
  279. largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
  280. result := NewSessionAgent(SessionAgentOptions{
  281. large,
  282. small,
  283. largeProviderCfg.SystemPromptPrefix,
  284. systemPrompt,
  285. isSubAgent,
  286. c.cfg.Options.DisableAutoSummarize,
  287. c.permissions.SkipRequests(),
  288. c.sessions,
  289. c.messages,
  290. nil,
  291. })
  292. c.readyWg.Go(func() error {
  293. tools, err := c.buildTools(ctx, agent)
  294. if err != nil {
  295. return err
  296. }
  297. result.SetTools(tools)
  298. return nil
  299. })
  300. return result, nil
  301. }
  302. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  303. var allTools []fantasy.AgentTool
  304. if slices.Contains(agent.AllowedTools, AgentToolName) {
  305. agentTool, err := c.agentTool(ctx)
  306. if err != nil {
  307. return nil, err
  308. }
  309. allTools = append(allTools, agentTool)
  310. }
  311. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  312. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  313. if err != nil {
  314. return nil, err
  315. }
  316. allTools = append(allTools, agenticFetchTool)
  317. }
  318. // Get the model name for the agent
  319. modelName := ""
  320. if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
  321. if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  322. modelName = model.Name
  323. }
  324. }
  325. allTools = append(allTools,
  326. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
  327. tools.NewJobOutputTool(),
  328. tools.NewJobKillTool(),
  329. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  330. tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  331. tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  332. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  333. tools.NewGlobTool(c.cfg.WorkingDir()),
  334. tools.NewGrepTool(c.cfg.WorkingDir()),
  335. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
  336. tools.NewSourcegraphTool(nil),
  337. tools.NewTodosTool(c.sessions),
  338. tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
  339. tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  340. )
  341. if len(c.cfg.LSP) > 0 {
  342. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
  343. }
  344. var filteredTools []fantasy.AgentTool
  345. for _, tool := range allTools {
  346. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  347. filteredTools = append(filteredTools, tool)
  348. }
  349. }
  350. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
  351. // Check MCP-specific disabled tools.
  352. if mcpCfg, ok := c.cfg.MCP[tool.MCP()]; ok {
  353. if slices.Contains(mcpCfg.DisabledTools, tool.MCPToolName()) {
  354. continue
  355. }
  356. }
  357. if agent.AllowedMCP == nil {
  358. // No MCP restrictions
  359. filteredTools = append(filteredTools, tool)
  360. continue
  361. }
  362. if len(agent.AllowedMCP) == 0 {
  363. // No MCPs allowed
  364. slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  365. break
  366. }
  367. for mcp, tools := range agent.AllowedMCP {
  368. if mcp != tool.MCP() {
  369. continue
  370. }
  371. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  372. filteredTools = append(filteredTools, tool)
  373. }
  374. }
  375. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  376. }
  377. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  378. return strings.Compare(a.Info().Name, b.Info().Name)
  379. })
  380. return filteredTools, nil
  381. }
  382. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  383. func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
  384. largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
  385. if !ok {
  386. return Model{}, Model{}, errors.New("large model not selected")
  387. }
  388. smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
  389. if !ok {
  390. return Model{}, Model{}, errors.New("small model not selected")
  391. }
  392. largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
  393. if !ok {
  394. return Model{}, Model{}, errors.New("large model provider not configured")
  395. }
  396. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
  397. if err != nil {
  398. return Model{}, Model{}, err
  399. }
  400. smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
  401. if !ok {
  402. return Model{}, Model{}, errors.New("large model provider not configured")
  403. }
  404. smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
  405. if err != nil {
  406. return Model{}, Model{}, err
  407. }
  408. var largeCatwalkModel *catwalk.Model
  409. var smallCatwalkModel *catwalk.Model
  410. for _, m := range largeProviderCfg.Models {
  411. if m.ID == largeModelCfg.Model {
  412. largeCatwalkModel = &m
  413. }
  414. }
  415. for _, m := range smallProviderCfg.Models {
  416. if m.ID == smallModelCfg.Model {
  417. smallCatwalkModel = &m
  418. }
  419. }
  420. if largeCatwalkModel == nil {
  421. return Model{}, Model{}, errors.New("large model not found in provider config")
  422. }
  423. if smallCatwalkModel == nil {
  424. return Model{}, Model{}, errors.New("snall model not found in provider config")
  425. }
  426. largeModelID := largeModelCfg.Model
  427. smallModelID := smallModelCfg.Model
  428. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  429. largeModelID += ":exacto"
  430. }
  431. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  432. smallModelID += ":exacto"
  433. }
  434. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  435. if err != nil {
  436. return Model{}, Model{}, err
  437. }
  438. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  439. if err != nil {
  440. return Model{}, Model{}, err
  441. }
  442. return Model{
  443. Model: largeModel,
  444. CatwalkCfg: *largeCatwalkModel,
  445. ModelCfg: largeModelCfg,
  446. }, Model{
  447. Model: smallModel,
  448. CatwalkCfg: *smallCatwalkModel,
  449. ModelCfg: smallModelCfg,
  450. }, nil
  451. }
  452. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  453. var opts []anthropic.Option
  454. if strings.HasPrefix(apiKey, "Bearer ") {
  455. // NOTE: Prevent the SDK from picking up the API key from env.
  456. os.Setenv("ANTHROPIC_API_KEY", "")
  457. headers["Authorization"] = apiKey
  458. } else if apiKey != "" {
  459. // X-Api-Key header
  460. opts = append(opts, anthropic.WithAPIKey(apiKey))
  461. }
  462. if len(headers) > 0 {
  463. opts = append(opts, anthropic.WithHeaders(headers))
  464. }
  465. if baseURL != "" {
  466. opts = append(opts, anthropic.WithBaseURL(baseURL))
  467. }
  468. if c.cfg.Options.Debug {
  469. httpClient := log.NewHTTPClient()
  470. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  471. }
  472. return anthropic.New(opts...)
  473. }
  474. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  475. opts := []openai.Option{
  476. openai.WithAPIKey(apiKey),
  477. openai.WithUseResponsesAPI(),
  478. }
  479. if c.cfg.Options.Debug {
  480. httpClient := log.NewHTTPClient()
  481. opts = append(opts, openai.WithHTTPClient(httpClient))
  482. }
  483. if len(headers) > 0 {
  484. opts = append(opts, openai.WithHeaders(headers))
  485. }
  486. if baseURL != "" {
  487. opts = append(opts, openai.WithBaseURL(baseURL))
  488. }
  489. return openai.New(opts...)
  490. }
  491. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  492. opts := []openrouter.Option{
  493. openrouter.WithAPIKey(apiKey),
  494. }
  495. if c.cfg.Options.Debug {
  496. httpClient := log.NewHTTPClient()
  497. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  498. }
  499. if len(headers) > 0 {
  500. opts = append(opts, openrouter.WithHeaders(headers))
  501. }
  502. return openrouter.New(opts...)
  503. }
  504. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
  505. opts := []openaicompat.Option{
  506. openaicompat.WithBaseURL(baseURL),
  507. openaicompat.WithAPIKey(apiKey),
  508. }
  509. if c.cfg.Options.Debug {
  510. httpClient := log.NewHTTPClient()
  511. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  512. }
  513. if len(headers) > 0 {
  514. opts = append(opts, openaicompat.WithHeaders(headers))
  515. }
  516. for extraKey, extraValue := range extraBody {
  517. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  518. }
  519. return openaicompat.New(opts...)
  520. }
  521. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  522. opts := []azure.Option{
  523. azure.WithBaseURL(baseURL),
  524. azure.WithAPIKey(apiKey),
  525. azure.WithUseResponsesAPI(),
  526. }
  527. if c.cfg.Options.Debug {
  528. httpClient := log.NewHTTPClient()
  529. opts = append(opts, azure.WithHTTPClient(httpClient))
  530. }
  531. if options == nil {
  532. options = make(map[string]string)
  533. }
  534. if apiVersion, ok := options["apiVersion"]; ok {
  535. opts = append(opts, azure.WithAPIVersion(apiVersion))
  536. }
  537. if len(headers) > 0 {
  538. opts = append(opts, azure.WithHeaders(headers))
  539. }
  540. return azure.New(opts...)
  541. }
  542. func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
  543. var opts []bedrock.Option
  544. if c.cfg.Options.Debug {
  545. httpClient := log.NewHTTPClient()
  546. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  547. }
  548. if len(headers) > 0 {
  549. opts = append(opts, bedrock.WithHeaders(headers))
  550. }
  551. bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
  552. if bearerToken != "" {
  553. opts = append(opts, bedrock.WithAPIKey(bearerToken))
  554. }
  555. return bedrock.New(opts...)
  556. }
  557. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  558. opts := []google.Option{
  559. google.WithBaseURL(baseURL),
  560. google.WithGeminiAPIKey(apiKey),
  561. }
  562. if c.cfg.Options.Debug {
  563. httpClient := log.NewHTTPClient()
  564. opts = append(opts, google.WithHTTPClient(httpClient))
  565. }
  566. if len(headers) > 0 {
  567. opts = append(opts, google.WithHeaders(headers))
  568. }
  569. return google.New(opts...)
  570. }
  571. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  572. opts := []google.Option{}
  573. if c.cfg.Options.Debug {
  574. httpClient := log.NewHTTPClient()
  575. opts = append(opts, google.WithHTTPClient(httpClient))
  576. }
  577. if len(headers) > 0 {
  578. opts = append(opts, google.WithHeaders(headers))
  579. }
  580. project := options["project"]
  581. location := options["location"]
  582. opts = append(opts, google.WithVertex(project, location))
  583. return google.New(opts...)
  584. }
  585. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  586. if model.Think {
  587. return true
  588. }
  589. if model.ProviderOptions == nil {
  590. return false
  591. }
  592. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  593. if err != nil {
  594. return false
  595. }
  596. if opts.Thinking != nil {
  597. return true
  598. }
  599. return false
  600. }
  601. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
  602. headers := maps.Clone(providerCfg.ExtraHeaders)
  603. if headers == nil {
  604. headers = make(map[string]string)
  605. }
  606. // handle special headers for anthropic
  607. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  608. if v, ok := headers["anthropic-beta"]; ok {
  609. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  610. } else {
  611. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  612. }
  613. }
  614. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  615. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  616. switch providerCfg.Type {
  617. case openai.Name:
  618. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  619. case anthropic.Name:
  620. return c.buildAnthropicProvider(baseURL, apiKey, headers)
  621. case openrouter.Name:
  622. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  623. case azure.Name:
  624. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  625. case bedrock.Name:
  626. return c.buildBedrockProvider(headers)
  627. case google.Name:
  628. return c.buildGoogleProvider(baseURL, apiKey, headers)
  629. case "google-vertex":
  630. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  631. case openaicompat.Name:
  632. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  633. if providerCfg.ExtraBody == nil {
  634. providerCfg.ExtraBody = map[string]any{}
  635. }
  636. providerCfg.ExtraBody["tool_stream"] = true
  637. }
  638. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
  639. default:
  640. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  641. }
  642. }
  // isExactoSupported reports whether modelID is one of the model IDs that is
  // requested with the ":exacto" suffix. The match is exact and
  // case-sensitive.
  func isExactoSupported(modelID string) bool {
  	exactoModels := [...]string{
  		"moonshotai/kimi-k2-0905",
  		"deepseek/deepseek-v3.1-terminus",
  		"z-ai/glm-4.6",
  		"openai/gpt-oss-120b",
  		"qwen/qwen3-coder",
  	}
  	return slices.Index(exactoModels[:], modelID) >= 0
  }
  653. func (c *coordinator) Cancel(sessionID string) {
  654. c.currentAgent.Cancel(sessionID)
  655. }
  656. func (c *coordinator) CancelAll() {
  657. c.currentAgent.CancelAll()
  658. }
  659. func (c *coordinator) ClearQueue(sessionID string) {
  660. c.currentAgent.ClearQueue(sessionID)
  661. }
  662. func (c *coordinator) IsBusy() bool {
  663. return c.currentAgent.IsBusy()
  664. }
  665. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  666. return c.currentAgent.IsSessionBusy(sessionID)
  667. }
  668. func (c *coordinator) Model() Model {
  669. return c.currentAgent.Model()
  670. }
  671. func (c *coordinator) UpdateModels(ctx context.Context) error {
  672. // build the models again so we make sure we get the latest config
  673. large, small, err := c.buildAgentModels(ctx)
  674. if err != nil {
  675. return err
  676. }
  677. c.currentAgent.SetModels(large, small)
  678. agentCfg, ok := c.cfg.Agents[config.AgentCoder]
  679. if !ok {
  680. return errors.New("coder agent not configured")
  681. }
  682. tools, err := c.buildTools(ctx, agentCfg)
  683. if err != nil {
  684. return err
  685. }
  686. c.currentAgent.SetTools(tools)
  687. return nil
  688. }
  689. func (c *coordinator) QueuedPrompts(sessionID string) int {
  690. return c.currentAgent.QueuedPrompts(sessionID)
  691. }
  692. func (c *coordinator) QueuedPromptsList(sessionID string) []string {
  693. return c.currentAgent.QueuedPromptsList(sessionID)
  694. }
  695. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  696. providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  697. if !ok {
  698. return errors.New("model provider not configured")
  699. }
  700. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  701. }
  702. func (c *coordinator) isUnauthorized(err error) bool {
  703. var providerErr *fantasy.ProviderError
  704. return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
  705. }
  706. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  707. if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
  708. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  709. return err
  710. }
  711. if err := c.UpdateModels(ctx); err != nil {
  712. return err
  713. }
  714. return nil
  715. }
  716. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  717. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  718. if err != nil {
  719. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  720. return err
  721. }
  722. providerCfg.APIKey = newAPIKey
  723. c.cfg.Providers.Set(providerCfg.ID, providerCfg)
  724. if err := c.UpdateModels(ctx); err != nil {
  725. return err
  726. }
  727. return nil
  728. }