state.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package app
  2. import (
  3. "bufio"
  4. "fmt"
  5. "log/slog"
  6. "os"
  7. "time"
  8. "github.com/BurntSushi/toml"
  9. )
  // ModelUsage is one entry of State.RecentlyUsedModels: a provider/model
  // pair together with the time UpdateModelUsage last stamped it.
  type ModelUsage struct {
  	ProviderID string    `toml:"provider_id"`
  	ModelID    string    `toml:"model_id"`
  	LastUsed   time.Time `toml:"last_used"`
  }
  // AgentModel is a provider/model pair stored as a value of State.AgentModel.
  // NOTE(review): the map key is presumably the agent name — confirm at call sites.
  type AgentModel struct {
  	ProviderID string `toml:"provider_id"`
  	ModelID    string `toml:"model_id"`
  }
  19. type State struct {
  20. Theme string `toml:"theme"`
  21. ScrollSpeed *int `toml:"scroll_speed"`
  22. AgentModel map[string]AgentModel `toml:"agent_model"`
  23. Provider string `toml:"provider"`
  24. Model string `toml:"model"`
  25. Agent string `toml:"agent"`
  26. RecentlyUsedModels []ModelUsage `toml:"recently_used_models"`
  27. MessagesRight bool `toml:"messages_right"`
  28. SplitDiff bool `toml:"split_diff"`
  29. MessageHistory []Prompt `toml:"message_history"`
  30. }
  31. func NewState() *State {
  32. return &State{
  33. Theme: "opencode",
  34. Agent: "build",
  35. AgentModel: make(map[string]AgentModel),
  36. RecentlyUsedModels: make([]ModelUsage, 0),
  37. MessageHistory: make([]Prompt, 0),
  38. }
  39. }
  40. // UpdateModelUsage updates the recently used models list with the specified model
  41. func (s *State) UpdateModelUsage(providerID, modelID string) {
  42. now := time.Now()
  43. // Check if this model is already in the list
  44. for i, usage := range s.RecentlyUsedModels {
  45. if usage.ProviderID == providerID && usage.ModelID == modelID {
  46. s.RecentlyUsedModels[i].LastUsed = now
  47. usage := s.RecentlyUsedModels[i]
  48. copy(s.RecentlyUsedModels[1:i+1], s.RecentlyUsedModels[0:i])
  49. s.RecentlyUsedModels[0] = usage
  50. return
  51. }
  52. }
  53. newUsage := ModelUsage{
  54. ProviderID: providerID,
  55. ModelID: modelID,
  56. LastUsed: now,
  57. }
  58. // Prepend to slice and limit to last 50 entries
  59. s.RecentlyUsedModels = append([]ModelUsage{newUsage}, s.RecentlyUsedModels...)
  60. if len(s.RecentlyUsedModels) > 50 {
  61. s.RecentlyUsedModels = s.RecentlyUsedModels[:50]
  62. }
  63. }
  64. func (s *State) RemoveModelFromRecentlyUsed(providerID, modelID string) {
  65. for i, usage := range s.RecentlyUsedModels {
  66. if usage.ProviderID == providerID && usage.ModelID == modelID {
  67. s.RecentlyUsedModels = append(s.RecentlyUsedModels[:i], s.RecentlyUsedModels[i+1:]...)
  68. return
  69. }
  70. }
  71. }
  72. func (s *State) AddPromptToHistory(prompt Prompt) {
  73. s.MessageHistory = append([]Prompt{prompt}, s.MessageHistory...)
  74. if len(s.MessageHistory) > 50 {
  75. s.MessageHistory = s.MessageHistory[:50]
  76. }
  77. }
  78. // SaveState writes the provided Config struct to the specified TOML file.
  79. // It will create the file if it doesn't exist, or overwrite it if it does.
  80. func SaveState(filePath string, state *State) error {
  81. file, err := os.Create(filePath)
  82. if err != nil {
  83. return fmt.Errorf("failed to create/open config file %s: %w", filePath, err)
  84. }
  85. defer file.Close()
  86. writer := bufio.NewWriter(file)
  87. encoder := toml.NewEncoder(writer)
  88. if err := encoder.Encode(state); err != nil {
  89. return fmt.Errorf("failed to encode state to TOML file %s: %w", filePath, err)
  90. }
  91. if err := writer.Flush(); err != nil {
  92. return fmt.Errorf("failed to flush writer for state file %s: %w", filePath, err)
  93. }
  94. slog.Debug("State saved to file", "file", filePath)
  95. return nil
  96. }
  97. // LoadState loads the state from the specified TOML file.
  98. // It returns a pointer to the State struct and an error if any issues occur.
  99. func LoadState(filePath string) (*State, error) {
  100. var state State
  101. if _, err := toml.DecodeFile(filePath, &state); err != nil {
  102. if _, statErr := os.Stat(filePath); os.IsNotExist(statErr) {
  103. return nil, fmt.Errorf("state file not found at %s: %w", filePath, statErr)
  104. }
  105. return nil, fmt.Errorf("failed to decode TOML from file %s: %w", filePath, err)
  106. }
  107. // Restore attachment sources types that were deserialized as map[string]any
  108. for _, prompt := range state.MessageHistory {
  109. for _, att := range prompt.Attachments {
  110. att.RestoreSourceType()
  111. }
  112. }
  113. return &state, nil
  114. }