2
0

coordinator.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. package agent
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "slices"
  15. "strings"
  16. "charm.land/fantasy"
  17. "github.com/charmbracelet/catwalk/pkg/catwalk"
  18. "github.com/charmbracelet/crush/internal/agent/prompt"
  19. "github.com/charmbracelet/crush/internal/agent/tools"
  20. "github.com/charmbracelet/crush/internal/config"
  21. "github.com/charmbracelet/crush/internal/csync"
  22. "github.com/charmbracelet/crush/internal/history"
  23. "github.com/charmbracelet/crush/internal/log"
  24. "github.com/charmbracelet/crush/internal/lsp"
  25. "github.com/charmbracelet/crush/internal/message"
  26. "github.com/charmbracelet/crush/internal/permission"
  27. "github.com/charmbracelet/crush/internal/session"
  28. "golang.org/x/sync/errgroup"
  29. "charm.land/fantasy/providers/anthropic"
  30. "charm.land/fantasy/providers/azure"
  31. "charm.land/fantasy/providers/bedrock"
  32. "charm.land/fantasy/providers/google"
  33. "charm.land/fantasy/providers/openai"
  34. "charm.land/fantasy/providers/openaicompat"
  35. "charm.land/fantasy/providers/openrouter"
  36. openaisdk "github.com/openai/openai-go/v2/option"
  37. "github.com/qjebbs/go-jsons"
  38. )
  39. type Coordinator interface {
  40. // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  41. // SetMainAgent(string)
  42. Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  43. Cancel(sessionID string)
  44. CancelAll()
  45. IsSessionBusy(sessionID string) bool
  46. IsBusy() bool
  47. QueuedPrompts(sessionID string) int
  48. ClearQueue(sessionID string)
  49. Summarize(context.Context, string) error
  50. Model() Model
  51. UpdateModels(ctx context.Context) error
  52. }
  53. type coordinator struct {
  54. cfg *config.Config
  55. sessions session.Service
  56. messages message.Service
  57. permissions permission.Service
  58. history history.Service
  59. lspClients *csync.Map[string, *lsp.Client]
  60. currentAgent SessionAgent
  61. agents map[string]SessionAgent
  62. readyWg errgroup.Group
  63. }
  64. func NewCoordinator(
  65. ctx context.Context,
  66. cfg *config.Config,
  67. sessions session.Service,
  68. messages message.Service,
  69. permissions permission.Service,
  70. history history.Service,
  71. lspClients *csync.Map[string, *lsp.Client],
  72. ) (Coordinator, error) {
  73. c := &coordinator{
  74. cfg: cfg,
  75. sessions: sessions,
  76. messages: messages,
  77. permissions: permissions,
  78. history: history,
  79. lspClients: lspClients,
  80. agents: make(map[string]SessionAgent),
  81. }
  82. agentCfg, ok := cfg.Agents[config.AgentCoder]
  83. if !ok {
  84. return nil, errors.New("coder agent not configured")
  85. }
  86. // TODO: make this dynamic when we support multiple agents
  87. prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  88. if err != nil {
  89. return nil, err
  90. }
  91. agent, err := c.buildAgent(ctx, prompt, agentCfg)
  92. if err != nil {
  93. return nil, err
  94. }
  95. c.currentAgent = agent
  96. c.agents[config.AgentCoder] = agent
  97. return c, nil
  98. }
  99. // Run implements Coordinator.
  100. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
  101. if err := c.readyWg.Wait(); err != nil {
  102. return nil, err
  103. }
  104. model := c.currentAgent.Model()
  105. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  106. if model.ModelCfg.MaxTokens != 0 {
  107. maxTokens = model.ModelCfg.MaxTokens
  108. }
  109. if !model.CatwalkCfg.SupportsImages && attachments != nil {
  110. attachments = nil
  111. }
  112. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  113. if !ok {
  114. return nil, errors.New("model provider not configured")
  115. }
  116. mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
  117. run := func() (*fantasy.AgentResult, error) {
  118. return c.currentAgent.Run(ctx, SessionAgentCall{
  119. SessionID: sessionID,
  120. Prompt: prompt,
  121. Attachments: attachments,
  122. MaxOutputTokens: maxTokens,
  123. ProviderOptions: mergedOptions,
  124. Temperature: temp,
  125. TopP: topP,
  126. TopK: topK,
  127. FrequencyPenalty: freqPenalty,
  128. PresencePenalty: presPenalty,
  129. })
  130. }
  131. result, originalErr := run()
  132. if c.isUnauthorized(originalErr) {
  133. switch {
  134. case providerCfg.OAuthToken != nil:
  135. slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
  136. if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
  137. return nil, originalErr
  138. }
  139. slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
  140. return run()
  141. case strings.Contains(providerCfg.APIKeyTemplate, "$"):
  142. slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
  143. if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
  144. return nil, originalErr
  145. }
  146. slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
  147. return run()
  148. }
  149. }
  150. return result, originalErr
  151. }
  152. func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
  153. options := fantasy.ProviderOptions{}
  154. cfgOpts := []byte("{}")
  155. providerCfgOpts := []byte("{}")
  156. catwalkOpts := []byte("{}")
  157. if model.ModelCfg.ProviderOptions != nil {
  158. data, err := json.Marshal(model.ModelCfg.ProviderOptions)
  159. if err == nil {
  160. cfgOpts = data
  161. }
  162. }
  163. if providerCfg.ProviderOptions != nil {
  164. data, err := json.Marshal(providerCfg.ProviderOptions)
  165. if err == nil {
  166. providerCfgOpts = data
  167. }
  168. }
  169. if model.CatwalkCfg.Options.ProviderOptions != nil {
  170. data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
  171. if err == nil {
  172. catwalkOpts = data
  173. }
  174. }
  175. readers := []io.Reader{
  176. bytes.NewReader(catwalkOpts),
  177. bytes.NewReader(providerCfgOpts),
  178. bytes.NewReader(cfgOpts),
  179. }
  180. got, err := jsons.Merge(readers)
  181. if err != nil {
  182. slog.Error("Could not merge call config", "err", err)
  183. return options
  184. }
  185. mergedOptions := make(map[string]any)
  186. err = json.Unmarshal([]byte(got), &mergedOptions)
  187. if err != nil {
  188. slog.Error("Could not create config for call", "err", err)
  189. return options
  190. }
  191. switch providerCfg.Type {
  192. case openai.Name, azure.Name:
  193. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  194. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  195. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  196. }
  197. if openai.IsResponsesModel(model.CatwalkCfg.ID) {
  198. if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
  199. mergedOptions["reasoning_summary"] = "auto"
  200. mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
  201. }
  202. parsed, err := openai.ParseResponsesOptions(mergedOptions)
  203. if err == nil {
  204. options[openai.Name] = parsed
  205. }
  206. } else {
  207. parsed, err := openai.ParseOptions(mergedOptions)
  208. if err == nil {
  209. options[openai.Name] = parsed
  210. }
  211. }
  212. case anthropic.Name:
  213. _, hasThink := mergedOptions["thinking"]
  214. if !hasThink && model.ModelCfg.Think {
  215. mergedOptions["thinking"] = map[string]any{
  216. // TODO: kujtim see if we need to make this dynamic
  217. "budget_tokens": 2000,
  218. }
  219. }
  220. parsed, err := anthropic.ParseOptions(mergedOptions)
  221. if err == nil {
  222. options[anthropic.Name] = parsed
  223. }
  224. case openrouter.Name:
  225. _, hasReasoning := mergedOptions["reasoning"]
  226. if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
  227. mergedOptions["reasoning"] = map[string]any{
  228. "enabled": true,
  229. "effort": model.ModelCfg.ReasoningEffort,
  230. }
  231. }
  232. parsed, err := openrouter.ParseOptions(mergedOptions)
  233. if err == nil {
  234. options[openrouter.Name] = parsed
  235. }
  236. case google.Name:
  237. _, hasReasoning := mergedOptions["thinking_config"]
  238. if !hasReasoning {
  239. mergedOptions["thinking_config"] = map[string]any{
  240. "thinking_budget": 2000,
  241. "include_thoughts": true,
  242. }
  243. }
  244. parsed, err := google.ParseOptions(mergedOptions)
  245. if err == nil {
  246. options[google.Name] = parsed
  247. }
  248. case openaicompat.Name:
  249. _, hasReasoningEffort := mergedOptions["reasoning_effort"]
  250. if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
  251. mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
  252. }
  253. parsed, err := openaicompat.ParseOptions(mergedOptions)
  254. if err == nil {
  255. options[openaicompat.Name] = parsed
  256. }
  257. }
  258. return options
  259. }
  260. func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
  261. modelOptions := getProviderOptions(model, cfg)
  262. temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
  263. topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
  264. topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
  265. freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
  266. presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
  267. return modelOptions, temp, topP, topK, freqPenalty, presPenalty
  268. }
  269. func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
  270. large, small, err := c.buildAgentModels(ctx)
  271. if err != nil {
  272. return nil, err
  273. }
  274. systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
  275. if err != nil {
  276. return nil, err
  277. }
  278. largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
  279. result := NewSessionAgent(SessionAgentOptions{
  280. large,
  281. small,
  282. largeProviderCfg.SystemPromptPrefix,
  283. systemPrompt,
  284. c.cfg.Options.DisableAutoSummarize,
  285. c.permissions.SkipRequests(),
  286. c.sessions,
  287. c.messages,
  288. nil,
  289. })
  290. c.readyWg.Go(func() error {
  291. tools, err := c.buildTools(ctx, agent)
  292. if err != nil {
  293. return err
  294. }
  295. result.SetTools(tools)
  296. return nil
  297. })
  298. return result, nil
  299. }
  300. func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
  301. var allTools []fantasy.AgentTool
  302. if slices.Contains(agent.AllowedTools, AgentToolName) {
  303. agentTool, err := c.agentTool(ctx)
  304. if err != nil {
  305. return nil, err
  306. }
  307. allTools = append(allTools, agentTool)
  308. }
  309. if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
  310. agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
  311. if err != nil {
  312. return nil, err
  313. }
  314. allTools = append(allTools, agenticFetchTool)
  315. }
  316. // Get the model name for the agent
  317. modelName := ""
  318. if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
  319. if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
  320. modelName = model.Name
  321. }
  322. }
  323. allTools = append(allTools,
  324. tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
  325. tools.NewJobOutputTool(),
  326. tools.NewJobKillTool(),
  327. tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
  328. tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  329. tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  330. tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
  331. tools.NewGlobTool(c.cfg.WorkingDir()),
  332. tools.NewGrepTool(c.cfg.WorkingDir()),
  333. tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
  334. tools.NewSourcegraphTool(nil),
  335. tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
  336. tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
  337. )
  338. if len(c.cfg.LSP) > 0 {
  339. allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
  340. }
  341. var filteredTools []fantasy.AgentTool
  342. for _, tool := range allTools {
  343. if slices.Contains(agent.AllowedTools, tool.Info().Name) {
  344. filteredTools = append(filteredTools, tool)
  345. }
  346. }
  347. for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
  348. // Check MCP-specific disabled tools.
  349. if mcpCfg, ok := c.cfg.MCP[tool.MCP()]; ok {
  350. if slices.Contains(mcpCfg.DisabledTools, tool.MCPToolName()) {
  351. continue
  352. }
  353. }
  354. if agent.AllowedMCP == nil {
  355. // No MCP restrictions
  356. filteredTools = append(filteredTools, tool)
  357. continue
  358. }
  359. if len(agent.AllowedMCP) == 0 {
  360. // No MCPs allowed
  361. slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
  362. break
  363. }
  364. for mcp, tools := range agent.AllowedMCP {
  365. if mcp != tool.MCP() {
  366. continue
  367. }
  368. if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
  369. filteredTools = append(filteredTools, tool)
  370. }
  371. }
  372. slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
  373. }
  374. slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
  375. return strings.Compare(a.Info().Name, b.Info().Name)
  376. })
  377. return filteredTools, nil
  378. }
  379. // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
  380. func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
  381. largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
  382. if !ok {
  383. return Model{}, Model{}, errors.New("large model not selected")
  384. }
  385. smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
  386. if !ok {
  387. return Model{}, Model{}, errors.New("small model not selected")
  388. }
  389. largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
  390. if !ok {
  391. return Model{}, Model{}, errors.New("large model provider not configured")
  392. }
  393. largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
  394. if err != nil {
  395. return Model{}, Model{}, err
  396. }
  397. smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
  398. if !ok {
  399. return Model{}, Model{}, errors.New("large model provider not configured")
  400. }
  401. smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
  402. if err != nil {
  403. return Model{}, Model{}, err
  404. }
  405. var largeCatwalkModel *catwalk.Model
  406. var smallCatwalkModel *catwalk.Model
  407. for _, m := range largeProviderCfg.Models {
  408. if m.ID == largeModelCfg.Model {
  409. largeCatwalkModel = &m
  410. }
  411. }
  412. for _, m := range smallProviderCfg.Models {
  413. if m.ID == smallModelCfg.Model {
  414. smallCatwalkModel = &m
  415. }
  416. }
  417. if largeCatwalkModel == nil {
  418. return Model{}, Model{}, errors.New("large model not found in provider config")
  419. }
  420. if smallCatwalkModel == nil {
  421. return Model{}, Model{}, errors.New("snall model not found in provider config")
  422. }
  423. largeModelID := largeModelCfg.Model
  424. smallModelID := smallModelCfg.Model
  425. if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
  426. largeModelID += ":exacto"
  427. }
  428. if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
  429. smallModelID += ":exacto"
  430. }
  431. largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
  432. if err != nil {
  433. return Model{}, Model{}, err
  434. }
  435. smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
  436. if err != nil {
  437. return Model{}, Model{}, err
  438. }
  439. return Model{
  440. Model: largeModel,
  441. CatwalkCfg: *largeCatwalkModel,
  442. ModelCfg: largeModelCfg,
  443. }, Model{
  444. Model: smallModel,
  445. CatwalkCfg: *smallCatwalkModel,
  446. ModelCfg: smallModelCfg,
  447. }, nil
  448. }
  449. func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  450. var opts []anthropic.Option
  451. if strings.HasPrefix(apiKey, "Bearer ") {
  452. // NOTE: Prevent the SDK from picking up the API key from env.
  453. os.Setenv("ANTHROPIC_API_KEY", "")
  454. headers["Authorization"] = apiKey
  455. } else if apiKey != "" {
  456. // X-Api-Key header
  457. opts = append(opts, anthropic.WithAPIKey(apiKey))
  458. }
  459. if len(headers) > 0 {
  460. opts = append(opts, anthropic.WithHeaders(headers))
  461. }
  462. if baseURL != "" {
  463. opts = append(opts, anthropic.WithBaseURL(baseURL))
  464. }
  465. if c.cfg.Options.Debug {
  466. httpClient := log.NewHTTPClient()
  467. opts = append(opts, anthropic.WithHTTPClient(httpClient))
  468. }
  469. return anthropic.New(opts...)
  470. }
  471. func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  472. opts := []openai.Option{
  473. openai.WithAPIKey(apiKey),
  474. openai.WithUseResponsesAPI(),
  475. }
  476. if c.cfg.Options.Debug {
  477. httpClient := log.NewHTTPClient()
  478. opts = append(opts, openai.WithHTTPClient(httpClient))
  479. }
  480. if len(headers) > 0 {
  481. opts = append(opts, openai.WithHeaders(headers))
  482. }
  483. if baseURL != "" {
  484. opts = append(opts, openai.WithBaseURL(baseURL))
  485. }
  486. return openai.New(opts...)
  487. }
  488. func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  489. opts := []openrouter.Option{
  490. openrouter.WithAPIKey(apiKey),
  491. }
  492. if c.cfg.Options.Debug {
  493. httpClient := log.NewHTTPClient()
  494. opts = append(opts, openrouter.WithHTTPClient(httpClient))
  495. }
  496. if len(headers) > 0 {
  497. opts = append(opts, openrouter.WithHeaders(headers))
  498. }
  499. return openrouter.New(opts...)
  500. }
  501. func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
  502. opts := []openaicompat.Option{
  503. openaicompat.WithBaseURL(baseURL),
  504. openaicompat.WithAPIKey(apiKey),
  505. }
  506. if c.cfg.Options.Debug {
  507. httpClient := log.NewHTTPClient()
  508. opts = append(opts, openaicompat.WithHTTPClient(httpClient))
  509. }
  510. if len(headers) > 0 {
  511. opts = append(opts, openaicompat.WithHeaders(headers))
  512. }
  513. for extraKey, extraValue := range extraBody {
  514. opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
  515. }
  516. return openaicompat.New(opts...)
  517. }
  518. func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  519. opts := []azure.Option{
  520. azure.WithBaseURL(baseURL),
  521. azure.WithAPIKey(apiKey),
  522. azure.WithUseResponsesAPI(),
  523. }
  524. if c.cfg.Options.Debug {
  525. httpClient := log.NewHTTPClient()
  526. opts = append(opts, azure.WithHTTPClient(httpClient))
  527. }
  528. if options == nil {
  529. options = make(map[string]string)
  530. }
  531. if apiVersion, ok := options["apiVersion"]; ok {
  532. opts = append(opts, azure.WithAPIVersion(apiVersion))
  533. }
  534. if len(headers) > 0 {
  535. opts = append(opts, azure.WithHeaders(headers))
  536. }
  537. return azure.New(opts...)
  538. }
  539. func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
  540. var opts []bedrock.Option
  541. if c.cfg.Options.Debug {
  542. httpClient := log.NewHTTPClient()
  543. opts = append(opts, bedrock.WithHTTPClient(httpClient))
  544. }
  545. if len(headers) > 0 {
  546. opts = append(opts, bedrock.WithHeaders(headers))
  547. }
  548. bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
  549. if bearerToken != "" {
  550. opts = append(opts, bedrock.WithAPIKey(bearerToken))
  551. }
  552. return bedrock.New(opts...)
  553. }
  554. func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
  555. opts := []google.Option{
  556. google.WithBaseURL(baseURL),
  557. google.WithGeminiAPIKey(apiKey),
  558. }
  559. if c.cfg.Options.Debug {
  560. httpClient := log.NewHTTPClient()
  561. opts = append(opts, google.WithHTTPClient(httpClient))
  562. }
  563. if len(headers) > 0 {
  564. opts = append(opts, google.WithHeaders(headers))
  565. }
  566. return google.New(opts...)
  567. }
  568. func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
  569. opts := []google.Option{}
  570. if c.cfg.Options.Debug {
  571. httpClient := log.NewHTTPClient()
  572. opts = append(opts, google.WithHTTPClient(httpClient))
  573. }
  574. if len(headers) > 0 {
  575. opts = append(opts, google.WithHeaders(headers))
  576. }
  577. project := options["project"]
  578. location := options["location"]
  579. opts = append(opts, google.WithVertex(project, location))
  580. return google.New(opts...)
  581. }
  582. func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
  583. if model.Think {
  584. return true
  585. }
  586. if model.ProviderOptions == nil {
  587. return false
  588. }
  589. opts, err := anthropic.ParseOptions(model.ProviderOptions)
  590. if err != nil {
  591. return false
  592. }
  593. if opts.Thinking != nil {
  594. return true
  595. }
  596. return false
  597. }
  598. func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
  599. headers := maps.Clone(providerCfg.ExtraHeaders)
  600. if headers == nil {
  601. headers = make(map[string]string)
  602. }
  603. // handle special headers for anthropic
  604. if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
  605. if v, ok := headers["anthropic-beta"]; ok {
  606. headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
  607. } else {
  608. headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
  609. }
  610. }
  611. apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
  612. baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
  613. switch providerCfg.Type {
  614. case openai.Name:
  615. return c.buildOpenaiProvider(baseURL, apiKey, headers)
  616. case anthropic.Name:
  617. return c.buildAnthropicProvider(baseURL, apiKey, headers)
  618. case openrouter.Name:
  619. return c.buildOpenrouterProvider(baseURL, apiKey, headers)
  620. case azure.Name:
  621. return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
  622. case bedrock.Name:
  623. return c.buildBedrockProvider(headers)
  624. case google.Name:
  625. return c.buildGoogleProvider(baseURL, apiKey, headers)
  626. case "google-vertex":
  627. return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
  628. case openaicompat.Name:
  629. if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
  630. if providerCfg.ExtraBody == nil {
  631. providerCfg.ExtraBody = map[string]any{}
  632. }
  633. providerCfg.ExtraBody["tool_stream"] = true
  634. }
  635. return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
  636. default:
  637. return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
  638. }
  639. }
  640. func isExactoSupported(modelID string) bool {
  641. supportedModels := []string{
  642. "moonshotai/kimi-k2-0905",
  643. "deepseek/deepseek-v3.1-terminus",
  644. "z-ai/glm-4.6",
  645. "openai/gpt-oss-120b",
  646. "qwen/qwen3-coder",
  647. }
  648. return slices.Contains(supportedModels, modelID)
  649. }
  650. func (c *coordinator) Cancel(sessionID string) {
  651. c.currentAgent.Cancel(sessionID)
  652. }
  653. func (c *coordinator) CancelAll() {
  654. c.currentAgent.CancelAll()
  655. }
  656. func (c *coordinator) ClearQueue(sessionID string) {
  657. c.currentAgent.ClearQueue(sessionID)
  658. }
  659. func (c *coordinator) IsBusy() bool {
  660. return c.currentAgent.IsBusy()
  661. }
  662. func (c *coordinator) IsSessionBusy(sessionID string) bool {
  663. return c.currentAgent.IsSessionBusy(sessionID)
  664. }
  665. func (c *coordinator) Model() Model {
  666. return c.currentAgent.Model()
  667. }
  668. func (c *coordinator) UpdateModels(ctx context.Context) error {
  669. // build the models again so we make sure we get the latest config
  670. large, small, err := c.buildAgentModels(ctx)
  671. if err != nil {
  672. return err
  673. }
  674. c.currentAgent.SetModels(large, small)
  675. agentCfg, ok := c.cfg.Agents[config.AgentCoder]
  676. if !ok {
  677. return errors.New("coder agent not configured")
  678. }
  679. tools, err := c.buildTools(ctx, agentCfg)
  680. if err != nil {
  681. return err
  682. }
  683. c.currentAgent.SetTools(tools)
  684. return nil
  685. }
  686. func (c *coordinator) QueuedPrompts(sessionID string) int {
  687. return c.currentAgent.QueuedPrompts(sessionID)
  688. }
  689. func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
  690. providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
  691. if !ok {
  692. return errors.New("model provider not configured")
  693. }
  694. return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
  695. }
  696. func (c *coordinator) isUnauthorized(err error) bool {
  697. var providerErr *fantasy.ProviderError
  698. return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
  699. }
  700. func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
  701. if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
  702. slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
  703. return err
  704. }
  705. if err := c.UpdateModels(ctx); err != nil {
  706. return err
  707. }
  708. return nil
  709. }
  710. func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
  711. newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
  712. if err != nil {
  713. slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
  714. return err
  715. }
  716. providerCfg.APIKey = newAPIKey
  717. c.cfg.Providers.Set(providerCfg.ID, providerCfg)
  718. if err := c.UpdateModels(ctx); err != nil {
  719. return err
  720. }
  721. return nil
  722. }