root.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package cmd
  2. import (
  3. "context"
  4. "os"
  5. "sync"
  6. tea "github.com/charmbracelet/bubbletea"
  7. "github.com/kujtimiihoxha/termai/internal/app"
  8. "github.com/kujtimiihoxha/termai/internal/db"
  9. "github.com/kujtimiihoxha/termai/internal/llm/models"
  10. "github.com/kujtimiihoxha/termai/internal/tui"
  11. "github.com/spf13/cobra"
  12. "github.com/spf13/viper"
  13. )
  14. var rootCmd = &cobra.Command{
  15. Use: "termai",
  16. Short: "A terminal ai assistant",
  17. Long: `A terminal ai assistant`,
  18. RunE: func(cmd *cobra.Command, args []string) error {
  19. if cmd.Flag("help").Changed {
  20. cmd.Help()
  21. return nil
  22. }
  23. debug, _ := cmd.Flags().GetBool("debug")
  24. viper.Set("debug", debug)
  25. if debug {
  26. viper.Set("log.level", "debug")
  27. }
  28. conn, err := db.Connect()
  29. if err != nil {
  30. return err
  31. }
  32. ctx := context.Background()
  33. app := app.New(ctx, conn)
  34. app.Logger.Info("Starting termai...")
  35. tui := tea.NewProgram(
  36. tui.New(app),
  37. tea.WithAltScreen(),
  38. )
  39. app.Logger.Info("Setting up subscriptions...")
  40. ch, unsub := setupSubscriptions(app)
  41. defer unsub()
  42. go func() {
  43. for msg := range ch {
  44. tui.Send(msg)
  45. }
  46. }()
  47. if _, err := tui.Run(); err != nil {
  48. return err
  49. }
  50. return nil
  51. },
  52. }
  53. func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
  54. ch := make(chan tea.Msg)
  55. wg := sync.WaitGroup{}
  56. ctx, cancel := context.WithCancel(app.Context)
  57. {
  58. sub := app.Logger.Subscribe(ctx)
  59. wg.Add(1)
  60. go func() {
  61. for ev := range sub {
  62. ch <- ev
  63. }
  64. wg.Done()
  65. }()
  66. }
  67. {
  68. sub := app.Sessions.Subscribe(ctx)
  69. wg.Add(1)
  70. go func() {
  71. for ev := range sub {
  72. ch <- ev
  73. }
  74. wg.Done()
  75. }()
  76. }
  77. {
  78. sub := app.Messages.Subscribe(ctx)
  79. wg.Add(1)
  80. go func() {
  81. for ev := range sub {
  82. ch <- ev
  83. }
  84. wg.Done()
  85. }()
  86. }
  87. {
  88. sub := app.LLM.Subscribe(ctx)
  89. wg.Add(1)
  90. go func() {
  91. for ev := range sub {
  92. ch <- ev
  93. }
  94. wg.Done()
  95. }()
  96. }
  97. return ch, func() {
  98. cancel()
  99. wg.Wait()
  100. close(ch)
  101. }
  102. }
  103. // Execute adds all child commands to the root command and sets flags appropriately.
  104. // This is called by main.main(). It only needs to happen once to the rootCmd.
  105. func Execute() {
  106. err := rootCmd.Execute()
  107. if err != nil {
  108. os.Exit(1)
  109. }
  110. }
  111. func loadConfig() {
  112. viper.SetConfigName(".termai")
  113. viper.SetConfigType("yaml")
  114. viper.AddConfigPath("$HOME")
  115. viper.AddConfigPath("$XDG_CONFIG_HOME/termai")
  116. viper.AddConfigPath(".")
  117. viper.SetEnvPrefix("TERMAI")
  118. // SET DEFAULTS
  119. viper.SetDefault("log.level", "info")
  120. viper.SetDefault("data.dir", ".termai")
  121. // LLM
  122. viper.SetDefault("models.big", string(models.DefaultBigModel))
  123. viper.SetDefault("models.little", string(models.DefaultLittleModel))
  124. viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
  125. viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
  126. viper.SetDefault("providers.common.max_tokens", 4000)
  127. viper.SetDefault("agents.default", "coder")
  128. //
  129. viper.ReadInConfig()
  130. workdir, err := os.Getwd()
  131. if err != nil {
  132. panic(err)
  133. }
  134. viper.Set("wd", workdir)
  135. }
  136. func init() {
  137. loadConfig()
  138. rootCmd.Flags().BoolP("help", "h", false, "Help")
  139. rootCmd.Flags().BoolP("debug", "d", false, "Help")
  140. }