config.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package config
  2. import (
  3. "fmt"
  4. "os"
  5. "strings"
  6. "github.com/kujtimiihoxha/termai/internal/llm/models"
  7. "github.com/spf13/viper"
  8. )
  // MCPType names the transport used to reach an MCP server.
  type MCPType string

  // Known MCP transports.
  const (
  	// MCPStdio identifies the "stdio" transport.
  	MCPStdio MCPType = "stdio"
  	// MCPSse identifies the "sse" transport.
  	MCPSse MCPType = "sse"
  )

  // MCPServer is one entry of the "mcpServers" section of the config file.
  //
  // NOTE(review): Command/Args/Env presumably apply to stdio servers and
  // URL/Headers to SSE servers — confirm against the MCP client code.
  type MCPServer struct {
  	// Command is the executable to run.
  	Command string `json:"command"`
  	// Env holds extra environment entries for the command.
  	Env []string `json:"env"`
  	// Args are the command-line arguments for Command.
  	Args []string `json:"args"`
  	// Type selects the transport (see MCPStdio and MCPSse).
  	Type MCPType `json:"type"`
  	// URL is the server endpoint.
  	URL string `json:"url"`
  	// Headers are extra headers sent to URL.
  	Headers map[string]string `json:"headers"`
  	// TODO: add permissions configuration
  	// TODO: add the ability to specify the tools to import
  }
  24. type Model struct {
  25. Coder models.ModelID `json:"coder"`
  26. CoderMaxTokens int64 `json:"coderMaxTokens"`
  27. Task models.ModelID `json:"task"`
  28. TaskMaxTokens int64 `json:"taskMaxTokens"`
  29. // TODO: Maybe support multiple models for different purposes
  30. }
  // Provider holds the settings for a single LLM provider; Config.Providers
  // keys these by provider.
  type Provider struct {
  	// APIKey is the credential sent to the provider.
  	APIKey string `json:"apiKey"`
  	// Enabled reports whether the provider may be used.
  	Enabled bool `json:"enabled"`
  }
  // Data configures on-disk storage. Load defaults Directory to
  // defaultDataDirectory.
  type Data struct {
  	// Directory is where termai keeps its data.
  	Directory string `json:"directory"`
  }
  // Log configures logging. Load defaults Level to defaultLogLevel and forces
  // "debug" when started in debug mode.
  type Log struct {
  	// Level is the logging level name.
  	Level string `json:"level"`
  }
  // LSPConfig configures one language server; Config.LSP keys these by name.
  type LSPConfig struct {
  	// Disabled turns this server off. The JSON key must be "disabled": a key
  	// named "enabled" on a field meaning "disabled" inverts the setting for
  	// anything that encodes or decodes this struct as JSON.
  	Disabled bool `json:"disabled"`
  	// Command is the language server executable.
  	Command string `json:"command"`
  	// Args are the command-line arguments for Command.
  	Args []string `json:"args"`
  	// Options is kept untyped. NOTE(review): presumably server-specific
  	// initialization options passed through to the server — confirm.
  	Options any `json:"options"`
  }
  47. type Config struct {
  48. Data *Data `json:"data,omitempty"`
  49. Log *Log `json:"log,omitempty"`
  50. MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
  51. Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
  52. LSP map[string]LSPConfig `json:"lsp,omitempty"`
  53. Model *Model `json:"model,omitempty"`
  54. }
  55. var cfg *Config
  // Built-in defaults consulted by Load.
  const (
  	// defaultDataDirectory is the default for the data.directory setting.
  	defaultDataDirectory = ".termai"
  	// defaultLogLevel is the log level used outside debug mode when none is set.
  	defaultLogLevel = "info"
  	// defaultMaxTokens replaces a non-positive coder/task token limit.
  	defaultMaxTokens int64 = 5000
  	// termai is the application name; it forms the config file name
  	// (".termai"), the XDG config subdirectory and the env prefix.
  	termai = "termai"
  )
  62. func Load(debug bool) error {
  63. if cfg != nil {
  64. return nil
  65. }
  66. viper.SetConfigName(fmt.Sprintf(".%s", termai))
  67. viper.SetConfigType("json")
  68. viper.AddConfigPath("$HOME")
  69. viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", termai))
  70. viper.SetEnvPrefix(strings.ToUpper(termai))
  71. // Add defaults
  72. viper.SetDefault("data.directory", defaultDataDirectory)
  73. if debug {
  74. viper.Set("log.level", "debug")
  75. } else {
  76. viper.SetDefault("log.level", defaultLogLevel)
  77. }
  78. defaultModelSet := false
  79. if os.Getenv("ANTHROPIC_API_KEY") != "" {
  80. viper.SetDefault("providers.anthropic.apiKey", os.Getenv("ANTHROPIC_API_KEY"))
  81. viper.SetDefault("providers.anthropic.enabled", true)
  82. viper.SetDefault("model.coder", models.Claude37Sonnet)
  83. viper.SetDefault("model.task", models.Claude37Sonnet)
  84. defaultModelSet = true
  85. }
  86. if os.Getenv("OPENAI_API_KEY") != "" {
  87. viper.SetDefault("providers.openai.apiKey", os.Getenv("OPENAI_API_KEY"))
  88. viper.SetDefault("providers.openai.enabled", true)
  89. if !defaultModelSet {
  90. viper.SetDefault("model.coder", models.GPT4o)
  91. viper.SetDefault("model.task", models.GPT4o)
  92. defaultModelSet = true
  93. }
  94. }
  95. if os.Getenv("GEMINI_API_KEY") != "" {
  96. viper.SetDefault("providers.gemini.apiKey", os.Getenv("GEMINI_API_KEY"))
  97. viper.SetDefault("providers.gemini.enabled", true)
  98. if !defaultModelSet {
  99. viper.SetDefault("model.coder", models.GRMINI20Flash)
  100. viper.SetDefault("model.task", models.GRMINI20Flash)
  101. defaultModelSet = true
  102. }
  103. }
  104. if os.Getenv("GROQ_API_KEY") != "" {
  105. viper.SetDefault("providers.groq.apiKey", os.Getenv("GROQ_API_KEY"))
  106. viper.SetDefault("providers.groq.enabled", true)
  107. if !defaultModelSet {
  108. viper.SetDefault("model.coder", models.QWENQwq)
  109. viper.SetDefault("model.task", models.QWENQwq)
  110. defaultModelSet = true
  111. }
  112. }
  113. // TODO: add more providers
  114. cfg = &Config{}
  115. err := viper.ReadInConfig()
  116. if err != nil {
  117. if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
  118. return err
  119. }
  120. }
  121. local := viper.New()
  122. local.SetConfigName(fmt.Sprintf(".%s", termai))
  123. local.SetConfigType("json")
  124. local.AddConfigPath(".")
  125. // load local config, this will override the global config
  126. if err = local.ReadInConfig(); err == nil {
  127. viper.MergeConfigMap(local.AllSettings())
  128. }
  129. viper.Unmarshal(cfg)
  130. if cfg.Model != nil && cfg.Model.CoderMaxTokens <= 0 {
  131. cfg.Model.CoderMaxTokens = defaultMaxTokens
  132. }
  133. if cfg.Model != nil && cfg.Model.TaskMaxTokens <= 0 {
  134. cfg.Model.TaskMaxTokens = defaultMaxTokens
  135. }
  136. for _, v := range cfg.MCPServers {
  137. if v.Type == "" {
  138. v.Type = MCPStdio
  139. }
  140. }
  141. workdir, err := os.Getwd()
  142. if err != nil {
  143. return err
  144. }
  145. viper.Set("wd", workdir)
  146. return nil
  147. }
  148. func Get() *Config {
  149. if cfg == nil {
  150. err := Load(false)
  151. if err != nil {
  152. panic(err)
  153. }
  154. }
  155. return cfg
  156. }
  157. func WorkingDirectory() string {
  158. return viper.GetString("wd")
  159. }
  160. func Write() error {
  161. return viper.WriteConfig()
  162. }