config.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. package config
  2. import (
  3. "context"
  4. "fmt"
  5. "log/slog"
  6. "net/http"
  7. "net/url"
  8. "os"
  9. "slices"
  10. "strings"
  11. "time"
  12. "github.com/charmbracelet/catwalk/pkg/catwalk"
  13. "github.com/charmbracelet/crush/internal/csync"
  14. "github.com/charmbracelet/crush/internal/env"
  15. "github.com/tidwall/sjson"
  16. )
const (
	// appName is the application's name ("crush").
	appName = "crush"
	// defaultDataDirectory matches the documented default of
	// Options.DataDirectory (jsonschema `default=.crush`).
	defaultDataDirectory = ".crush"
)

// defaultContextPaths lists context files and directories (entries ending in
// "/") known to this package. Several case variants of the same name are
// listed (e.g. CRUSH.md / crush.md / Crush.md); presumably this is so each is
// found on case-sensitive filesystems, and presumably the list is the fallback
// when Options.ContextPaths is unset. NOTE(review): neither use is visible in
// this file — confirm at the call site.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
	"AGENTS.md",
	"agents.md",
	"Agents.md",
}

// SelectedModelType names a model slot; Config.Models is keyed by it.
type SelectedModelType string

// The model slots currently supported ("large" and "small").
const (
	SelectedModelTypeLarge SelectedModelType = "large"
	SelectedModelTypeSmall SelectedModelType = "small"
)

// IDs of the built-in agents; SetupAgents uses them as keys of Config.Agents.
const (
	AgentCoder string = "coder"
	AgentTask string = "task"
)
// SelectedModel is the model chosen for one model slot (the values of
// Config.Models), together with optional per-model overrides. Pointer-typed
// fields distinguish "not set" (nil, omitted from JSON) from an explicit zero.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model" jsonschema:"required,description=The model ID as used by the provider API,example=gpt-4o"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider" jsonschema:"required,description=The model provider ID that matches a key in the providers config,example=openai"`
	// Only used by models that use the openai provider and need this set.
	// The schema restricts it to "low", "medium" or "high".
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"description=Reasoning effort level for OpenAI models that support it,enum=low,enum=medium,enum=high"`
	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty" jsonschema:"description=Enable thinking mode for Anthropic models that support reasoning"`
	// The fields below override the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"description=Maximum number of tokens for model responses,minimum=1,maximum=200000,example=4096"`
	Temperature *float64 `json:"temperature,omitempty" jsonschema:"description=Sampling temperature,minimum=0,maximum=1,example=0.7"`
	TopP *float64 `json:"top_p,omitempty" jsonschema:"description=Top-p (nucleus) sampling parameter,minimum=0,maximum=1,example=0.9"`
	TopK *int64 `json:"top_k,omitempty" jsonschema:"description=Top-k sampling parameter"`
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty" jsonschema:"description=Frequency penalty to reduce repetition"`
	PresencePenalty *float64 `json:"presence_penalty,omitempty" jsonschema:"description=Presence penalty to increase topic diversity"`
	// Override provider specific options.
	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"`
}
// ProviderConfig describes one AI provider entry (the values stored in
// Config.Providers).
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty" jsonschema:"description=Human-readable name for the provider,example=OpenAI"`
	// The provider's API endpoint. TestConnection passes it through a
	// VariableResolver and falls back to a per-type default when it is empty.
	BaseURL string `json:"base_url,omitempty" jsonschema:"description=Base URL for the provider's API,format=uri,example=https://api.openai.com/v1"`
	// The provider type, e.g. "openai", "anthropic", etc. The schema documents
	// "openai" as the default when empty.
	Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
	// The provider's API key. It may be a variable reference such as
	// $OPENAI_API_KEY (schema example); TestConnection resolves it through a
	// VariableResolver before use.
	APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
	// Marks the provider as disabled; EnabledProviders skips such entries.
	Disable bool `json:"disable,omitempty" jsonschema:"description=Whether this provider is disabled,default=false"`
	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty" jsonschema:"description=Custom prefix to add to system prompts for this provider"`
	// Extra headers to send with each request to the provider. TestConnection
	// applies them after its own auth headers, so they can override those.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"`
	// Extra fields to include in request bodies; per the schema this only
	// works with openai-compatible providers.
	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies, only works with openai-compatible providers"`
	// Additional provider-specific options for this provider.
	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for this provider"`
	// Used to pass extra parameters to the provider. Never serialized (json:"-").
	ExtraParams map[string]string `json:"-"`
	// The provider models; GetModel searches this list by model ID.
	Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"`
}
// MCPType is the transport used to reach an MCP server.
type MCPType string

// Supported MCP transports (the schema's enum for MCPConfig.Type).
const (
	MCPStdio MCPType = "stdio"
	MCPSSE MCPType = "sse"
	MCPHttp MCPType = "http"
)

// MCPConfig configures a single Model Context Protocol server. Per the schema
// descriptions, Command/Args/Env apply to stdio servers and URL/Headers to
// HTTP or SSE servers.
type MCPConfig struct {
	Command string `json:"command,omitempty" jsonschema:"description=Command to execute for stdio MCP servers,example=npx"`
	// Env values are expanded by ResolvedEnv.
	Env map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set for the MCP server"`
	Args []string `json:"args,omitempty" jsonschema:"description=Arguments to pass to the MCP server command"`
	Type MCPType `json:"type" jsonschema:"required,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"`
	URL string `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
	Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`
	// Timeout is in seconds; the schema documents a default of 15.
	Timeout int `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for MCP server connections,default=15,example=30,example=60,example=120"`
	// TODO: maybe make it possible to get the value from the env
	// NOTE(review): ResolvedHeaders already expands these values through a
	// shell variable resolver, so the TODO above may be stale — confirm.
	Headers map[string]string `json:"headers,omitempty" jsonschema:"description=HTTP headers for HTTP/SSE MCP servers"`
}
// LSPConfig configures a language server that is launched from Command/Args
// with Env applied (values expanded by ResolvedEnv).
type LSPConfig struct {
	Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this LSP server is disabled,default=false"`
	// NOTE(review): the schema marks Command as required, but the JSON tag has
	// omitempty — confirm that an empty Command is rejected elsewhere.
	Command string `json:"command,omitempty" jsonschema:"required,description=Command to execute for the LSP server,example=gopls"`
	Args []string `json:"args,omitempty" jsonschema:"description=Arguments to pass to the LSP server command"`
	Env map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set to the LSP server command"`
	// File types this server handles (e.g. "go", "ts").
	FileTypes []string `json:"filetypes,omitempty" jsonschema:"description=File types this LSP server handles,example=go,example=mod,example=rs,example=c,example=js,example=ts"`
	// Files or directories whose presence marks a project root.
	RootMarkers []string `json:"root_markers,omitempty" jsonschema:"description=Files or directories that indicate the project root,example=go.mod,example=package.json,example=Cargo.toml"`
	InitOptions map[string]any `json:"init_options,omitempty" jsonschema:"description=Initialization options passed to the LSP server during initialize request"`
	Options map[string]any `json:"options,omitempty" jsonschema:"description=LSP server-specific settings passed during initialization"`
}
// TUIOptions holds settings for the terminal user interface. Further TUI
// settings (e.g. themes) belong here as well.
type TUIOptions struct {
	// CompactMode is toggled at runtime by Config.SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty" jsonschema:"description=Enable compact mode for the TUI interface,default=false"`
	// DiffMode selects how diffs are shown; the schema allows "unified" or "split".
	DiffMode string `json:"diff_mode,omitempty" jsonschema:"description=Diff mode for the TUI interface,enum=unified,enum=split"`
	// Completions configures the completions popup; omitted from JSON when zero.
	Completions Completions `json:"completions,omitzero" jsonschema:"description=Completions UI options"`
}
  128. // Completions defines options for the completions UI.
  129. type Completions struct {
  130. MaxDepth *int `json:"max_depth,omitempty" jsonschema:"description=Maximum depth for the ls tool,default=0,example=10"`
  131. MaxItems *int `json:"max_items,omitempty" jsonschema:"description=Maximum number of items to return for the ls tool,default=1000,example=100"`
  132. }
  133. func (c Completions) Limits() (depth, items int) {
  134. return ptrValOr(c.MaxDepth, 0), ptrValOr(c.MaxItems, 0)
  135. }
// Permissions controls which tool invocations need an explicit user approval.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"description=List of tools that don't require permission prompts,example=bash,example=view"` // Tools that don't require permission prompts
	SkipRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
}

// Attribution controls the credit lines added to commit messages (and, per the
// schema, to issues and PRs).
// NOTE(review): the schema documents default=true for both fields, but the Go
// zero value is false and omitempty drops false when marshalling — confirm
// where the true defaults are applied.
type Attribution struct {
	CoAuthoredBy bool `json:"co_authored_by,omitempty" jsonschema:"description=Add Co-Authored-By trailer to commit messages,default=true"`
	GeneratedWith bool `json:"generated_with,omitempty" jsonschema:"description=Add Generated with Crush line to commit messages and issues and PRs,default=true"`
}

// Options holds general application settings (Config.Options).
type Options struct {
	// ContextPaths is used by SetupAgents as the context paths of the built-in agents.
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
	// TUI may be nil when no TUI options were configured.
	TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
	Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
	DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
	DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
	// DisabledTools lists tool names removed from the agents' allowed tools.
	// The tag has no omitempty, so a nil slice marshals as null.
	DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"`
	DisableProviderAutoUpdate bool `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"`
	Attribution *Attribution `json:"attribution,omitempty" jsonschema:"description=Attribution settings for generated content"`
	DisableMetrics bool `json:"disable_metrics,omitempty" jsonschema:"description=Disable sending metrics,default=false"`
}
  156. type MCPs map[string]MCPConfig
  157. type MCP struct {
  158. Name string `json:"name"`
  159. MCP MCPConfig `json:"mcp"`
  160. }
  161. func (m MCPs) Sorted() []MCP {
  162. sorted := make([]MCP, 0, len(m))
  163. for k, v := range m {
  164. sorted = append(sorted, MCP{
  165. Name: k,
  166. MCP: v,
  167. })
  168. }
  169. slices.SortFunc(sorted, func(a, b MCP) int {
  170. return strings.Compare(a.Name, b.Name)
  171. })
  172. return sorted
  173. }
  174. type LSPs map[string]LSPConfig
  175. type LSP struct {
  176. Name string `json:"name"`
  177. LSP LSPConfig `json:"lsp"`
  178. }
  179. func (l LSPs) Sorted() []LSP {
  180. sorted := make([]LSP, 0, len(l))
  181. for k, v := range l {
  182. sorted = append(sorted, LSP{
  183. Name: k,
  184. LSP: v,
  185. })
  186. }
  187. slices.SortFunc(sorted, func(a, b LSP) int {
  188. return strings.Compare(a.Name, b.Name)
  189. })
  190. return sorted
  191. }
  192. func (l LSPConfig) ResolvedEnv() []string {
  193. return resolveEnvs(l.Env)
  194. }
  195. func (m MCPConfig) ResolvedEnv() []string {
  196. return resolveEnvs(m.Env)
  197. }
  198. func (m MCPConfig) ResolvedHeaders() map[string]string {
  199. resolver := NewShellVariableResolver(env.New())
  200. for e, v := range m.Headers {
  201. var err error
  202. m.Headers[e], err = resolver.ResolveValue(v)
  203. if err != nil {
  204. slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
  205. continue
  206. }
  207. }
  208. return m.Headers
  209. }
// Agent describes one agent configuration. SetupAgents builds the built-in
// agents (AgentCoder and AgentTask); the struct is not serialized as part of
// Config (Config.Agents is json:"-").
type Agent struct {
	// ID identifies the agent; the built-in agents use AgentCoder/AgentTask.
	// NOTE(review): a comment in this struct described "the id of the system
	// prompt used by the agent" but sat above Disabled; it most plausibly
	// refers to this field — confirm against the prompt-selection code.
	ID string `json:"id,omitempty"`
	Name string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`
	// Model selects which model slot (large or small) the agent uses.
	Model SelectedModelType `json:"model" jsonschema:"required,description=The model type to use for this agent,enum=large,enum=small,default=large"`
	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`
	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the task agent a non-nil empty map with
	// the comment "NO MCPs", which suggests that only a nil map means "all
	// MCPs" — confirm how the consumer distinguishes nil from empty.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
  228. type Tools struct {
  229. Ls ToolLs `json:"ls,omitzero"`
  230. }
  231. type ToolLs struct {
  232. MaxDepth *int `json:"max_depth,omitempty" jsonschema:"description=Maximum depth for the ls tool,default=0,example=10"`
  233. MaxItems *int `json:"max_items,omitempty" jsonschema:"description=Maximum number of items to return for the ls tool,default=1000,example=100"`
  234. }
  235. func (t ToolLs) Limits() (depth, items int) {
  236. return ptrValOr(t.MaxDepth, 0), ptrValOr(t.MaxItems, 0)
  237. }
// Config holds the configuration for crush.
type Config struct {
	// Schema is the config file's "$schema" reference.
	Schema string `json:"$schema,omitempty"`
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty" jsonschema:"description=Model configurations for different model types,example={\"large\":{\"model\":\"gpt-4o\",\"provider\":\"openai\"}}"`
	// The providers that are configured, keyed by provider ID.
	// csync.Map is presumably a concurrency-safe map — verify in internal/csync.
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"`
	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty" jsonschema:"description=Model Context Protocol server configurations"`
	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty" jsonschema:"description=Language Server Protocol configurations"`
	// General options; may be nil (SetCompactMode allocates it on demand).
	Options *Options `json:"options,omitempty" jsonschema:"description=General application options"`
	Permissions *Permissions `json:"permissions,omitempty" jsonschema:"description=Permission settings for tool usage"`
	Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`
	// Agents is not serialized; SetupAgents populates it.
	Agents map[string]Agent `json:"-"`
	// Internal
	workingDir string `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// resolver backs Resolve and Resolver; it may be nil.
	resolver VariableResolver
	// dataConfigDir is, despite its name, the path of the JSON config file:
	// SetConfigField passes it directly to os.ReadFile and os.WriteFile.
	dataConfigDir string `json:"-"`
	// knownProviders is the provider catalog SetProviderAPIKey falls back to
	// when the provider is not yet present in Providers.
	knownProviders []catwalk.Provider `json:"-"`
}
  258. func (c *Config) WorkingDir() string {
  259. return c.workingDir
  260. }
  261. func (c *Config) EnabledProviders() []ProviderConfig {
  262. var enabled []ProviderConfig
  263. for p := range c.Providers.Seq() {
  264. if !p.Disable {
  265. enabled = append(enabled, p)
  266. }
  267. }
  268. return enabled
  269. }
  270. // IsConfigured return true if at least one provider is configured
  271. func (c *Config) IsConfigured() bool {
  272. return len(c.EnabledProviders()) > 0
  273. }
  274. func (c *Config) GetModel(provider, model string) *catwalk.Model {
  275. if providerConfig, ok := c.Providers.Get(provider); ok {
  276. for _, m := range providerConfig.Models {
  277. if m.ID == model {
  278. return &m
  279. }
  280. }
  281. }
  282. return nil
  283. }
  284. func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
  285. model, ok := c.Models[modelType]
  286. if !ok {
  287. return nil
  288. }
  289. if providerConfig, ok := c.Providers.Get(model.Provider); ok {
  290. return &providerConfig
  291. }
  292. return nil
  293. }
  294. func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
  295. model, ok := c.Models[modelType]
  296. if !ok {
  297. return nil
  298. }
  299. return c.GetModel(model.Provider, model.Model)
  300. }
  301. func (c *Config) LargeModel() *catwalk.Model {
  302. model, ok := c.Models[SelectedModelTypeLarge]
  303. if !ok {
  304. return nil
  305. }
  306. return c.GetModel(model.Provider, model.Model)
  307. }
  308. func (c *Config) SmallModel() *catwalk.Model {
  309. model, ok := c.Models[SelectedModelTypeSmall]
  310. if !ok {
  311. return nil
  312. }
  313. return c.GetModel(model.Provider, model.Model)
  314. }
  315. func (c *Config) SetCompactMode(enabled bool) error {
  316. if c.Options == nil {
  317. c.Options = &Options{}
  318. }
  319. c.Options.TUI.CompactMode = enabled
  320. return c.SetConfigField("options.tui.compact_mode", enabled)
  321. }
  322. func (c *Config) Resolve(key string) (string, error) {
  323. if c.resolver == nil {
  324. return "", fmt.Errorf("no variable resolver configured")
  325. }
  326. return c.resolver.ResolveValue(key)
  327. }
  328. func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
  329. c.Models[modelType] = model
  330. if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
  331. return fmt.Errorf("failed to update preferred model: %w", err)
  332. }
  333. return nil
  334. }
  335. func (c *Config) SetConfigField(key string, value any) error {
  336. // read the data
  337. data, err := os.ReadFile(c.dataConfigDir)
  338. if err != nil {
  339. if os.IsNotExist(err) {
  340. data = []byte("{}")
  341. } else {
  342. return fmt.Errorf("failed to read config file: %w", err)
  343. }
  344. }
  345. newValue, err := sjson.Set(string(data), key, value)
  346. if err != nil {
  347. return fmt.Errorf("failed to set config field %s: %w", key, err)
  348. }
  349. if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
  350. return fmt.Errorf("failed to write config file: %w", err)
  351. }
  352. return nil
  353. }
  354. func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
  355. // First save to the config file
  356. err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
  357. if err != nil {
  358. return fmt.Errorf("failed to save API key to config file: %w", err)
  359. }
  360. providerConfig, exists := c.Providers.Get(providerID)
  361. if exists {
  362. providerConfig.APIKey = apiKey
  363. c.Providers.Set(providerID, providerConfig)
  364. return nil
  365. }
  366. var foundProvider *catwalk.Provider
  367. for _, p := range c.knownProviders {
  368. if string(p.ID) == providerID {
  369. foundProvider = &p
  370. break
  371. }
  372. }
  373. if foundProvider != nil {
  374. // Create new provider config based on known provider
  375. providerConfig = ProviderConfig{
  376. ID: providerID,
  377. Name: foundProvider.Name,
  378. BaseURL: foundProvider.APIEndpoint,
  379. Type: foundProvider.Type,
  380. APIKey: apiKey,
  381. Disable: false,
  382. ExtraHeaders: make(map[string]string),
  383. ExtraParams: make(map[string]string),
  384. Models: foundProvider.Models,
  385. }
  386. } else {
  387. return fmt.Errorf("provider with ID %s not found in known providers", providerID)
  388. }
  389. // Store the updated provider config
  390. c.Providers.Set(providerID, providerConfig)
  391. return nil
  392. }
  393. func allToolNames() []string {
  394. return []string{
  395. "agent",
  396. "bash",
  397. "download",
  398. "edit",
  399. "multiedit",
  400. "lsp_diagnostics",
  401. "lsp_references",
  402. "fetch",
  403. "agentic_fetch",
  404. "glob",
  405. "grep",
  406. "ls",
  407. "sourcegraph",
  408. "view",
  409. "write",
  410. }
  411. }
  412. func resolveAllowedTools(allTools []string, disabledTools []string) []string {
  413. if disabledTools == nil {
  414. return allTools
  415. }
  416. // filter out disabled tools (exclude mode)
  417. return filterSlice(allTools, disabledTools, false)
  418. }
  419. func resolveReadOnlyTools(tools []string) []string {
  420. readOnlyTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
  421. // filter to only include tools that are in allowedtools (include mode)
  422. return filterSlice(tools, readOnlyTools, true)
  423. }
  424. func filterSlice(data []string, mask []string, include bool) []string {
  425. filtered := []string{}
  426. for _, s := range data {
  427. // if include is true, we include items that ARE in the mask
  428. // if include is false, we include items that are NOT in the mask
  429. if include == slices.Contains(mask, s) {
  430. filtered = append(filtered, s)
  431. }
  432. }
  433. return filtered
  434. }
  435. func (c *Config) SetupAgents() {
  436. allowedTools := resolveAllowedTools(allToolNames(), c.Options.DisabledTools)
  437. agents := map[string]Agent{
  438. AgentCoder: {
  439. ID: AgentCoder,
  440. Name: "Coder",
  441. Description: "An agent that helps with executing coding tasks.",
  442. Model: SelectedModelTypeLarge,
  443. ContextPaths: c.Options.ContextPaths,
  444. AllowedTools: allowedTools,
  445. },
  446. AgentTask: {
  447. ID: AgentCoder,
  448. Name: "Task",
  449. Description: "An agent that helps with searching for context and finding implementation details.",
  450. Model: SelectedModelTypeLarge,
  451. ContextPaths: c.Options.ContextPaths,
  452. AllowedTools: resolveReadOnlyTools(allowedTools),
  453. // NO MCPs or LSPs by default
  454. AllowedMCP: map[string][]string{},
  455. },
  456. }
  457. c.Agents = agents
  458. }
  459. func (c *Config) Resolver() VariableResolver {
  460. return c.resolver
  461. }
  462. func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
  463. testURL := ""
  464. headers := make(map[string]string)
  465. apiKey, _ := resolver.ResolveValue(c.APIKey)
  466. switch c.Type {
  467. case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter:
  468. baseURL, _ := resolver.ResolveValue(c.BaseURL)
  469. if baseURL == "" {
  470. baseURL = "https://api.openai.com/v1"
  471. }
  472. if c.ID == string(catwalk.InferenceProviderOpenRouter) {
  473. testURL = baseURL + "/credits"
  474. } else {
  475. testURL = baseURL + "/models"
  476. }
  477. headers["Authorization"] = "Bearer " + apiKey
  478. case catwalk.TypeAnthropic:
  479. baseURL, _ := resolver.ResolveValue(c.BaseURL)
  480. if baseURL == "" {
  481. baseURL = "https://api.anthropic.com/v1"
  482. }
  483. testURL = baseURL + "/models"
  484. headers["x-api-key"] = apiKey
  485. headers["anthropic-version"] = "2023-06-01"
  486. case catwalk.TypeGoogle:
  487. baseURL, _ := resolver.ResolveValue(c.BaseURL)
  488. if baseURL == "" {
  489. baseURL = "https://generativelanguage.googleapis.com"
  490. }
  491. testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
  492. }
  493. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  494. defer cancel()
  495. client := &http.Client{}
  496. req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
  497. if err != nil {
  498. return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
  499. }
  500. for k, v := range headers {
  501. req.Header.Set(k, v)
  502. }
  503. for k, v := range c.ExtraHeaders {
  504. req.Header.Set(k, v)
  505. }
  506. b, err := client.Do(req)
  507. if err != nil {
  508. return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
  509. }
  510. if c.ID == string(catwalk.InferenceProviderZAI) {
  511. if b.StatusCode == http.StatusUnauthorized {
  512. // for z.ai just check if the http response is not 401
  513. return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
  514. }
  515. } else {
  516. if b.StatusCode != http.StatusOK {
  517. return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
  518. }
  519. }
  520. _ = b.Body.Close()
  521. return nil
  522. }
  523. func resolveEnvs(envs map[string]string) []string {
  524. resolver := NewShellVariableResolver(env.New())
  525. for e, v := range envs {
  526. var err error
  527. envs[e], err = resolver.ResolveValue(v)
  528. if err != nil {
  529. slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
  530. continue
  531. }
  532. }
  533. res := make([]string, 0, len(envs))
  534. for k, v := range envs {
  535. res = append(res, fmt.Sprintf("%s=%s", k, v))
  536. }
  537. return res
  538. }
  539. func ptrValOr[T any](t *T, el T) T {
  540. if t == nil {
  541. return el
  542. }
  543. return *t
  544. }