stats.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql"
  6. _ "embed"
  7. "encoding/json"
  8. "fmt"
  9. "html/template"
  10. "os"
  11. "os/user"
  12. "path/filepath"
  13. "strings"
  14. "time"
  15. "github.com/charmbracelet/crush/internal/config"
  16. "github.com/charmbracelet/crush/internal/db"
  17. "github.com/pkg/browser"
  18. "github.com/spf13/cobra"
  19. )
  20. //go:embed stats/index.html
  21. var statsTemplate string
  22. //go:embed stats/index.css
  23. var statsCSS string
  24. //go:embed stats/index.js
  25. var statsJS string
  26. //go:embed stats/header.svg
  27. var headerSVG string
  28. //go:embed stats/heartbit.svg
  29. var heartbitSVG string
  30. //go:embed stats/footer.svg
  31. var footerSVG string
  32. var statsCmd = &cobra.Command{
  33. Use: "stats",
  34. Short: "Show usage statistics",
  35. Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
  36. RunE: runStats,
  37. }
  38. // Day names for day of week statistics.
  39. var dayNames = []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
  40. // Stats holds all the statistics data.
  41. type Stats struct {
  42. GeneratedAt time.Time `json:"generated_at"`
  43. Total TotalStats `json:"total"`
  44. UsageByDay []DailyUsage `json:"usage_by_day"`
  45. UsageByModel []ModelUsage `json:"usage_by_model"`
  46. UsageByHour []HourlyUsage `json:"usage_by_hour"`
  47. UsageByDayOfWeek []DayOfWeekUsage `json:"usage_by_day_of_week"`
  48. RecentActivity []DailyActivity `json:"recent_activity"`
  49. AvgResponseTimeMs float64 `json:"avg_response_time_ms"`
  50. ToolUsage []ToolUsage `json:"tool_usage"`
  51. HourDayHeatmap []HourDayHeatmapPt `json:"hour_day_heatmap"`
  52. }
  53. type TotalStats struct {
  54. TotalSessions int64 `json:"total_sessions"`
  55. TotalPromptTokens int64 `json:"total_prompt_tokens"`
  56. TotalCompletionTokens int64 `json:"total_completion_tokens"`
  57. TotalTokens int64 `json:"total_tokens"`
  58. TotalCost float64 `json:"total_cost"`
  59. TotalMessages int64 `json:"total_messages"`
  60. AvgTokensPerSession float64 `json:"avg_tokens_per_session"`
  61. AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
  62. }
  63. type DailyUsage struct {
  64. Day string `json:"day"`
  65. PromptTokens int64 `json:"prompt_tokens"`
  66. CompletionTokens int64 `json:"completion_tokens"`
  67. TotalTokens int64 `json:"total_tokens"`
  68. Cost float64 `json:"cost"`
  69. SessionCount int64 `json:"session_count"`
  70. }
  71. type ModelUsage struct {
  72. Model string `json:"model"`
  73. Provider string `json:"provider"`
  74. MessageCount int64 `json:"message_count"`
  75. }
  76. type HourlyUsage struct {
  77. Hour int `json:"hour"`
  78. SessionCount int64 `json:"session_count"`
  79. }
  80. type DayOfWeekUsage struct {
  81. DayOfWeek int `json:"day_of_week"`
  82. DayName string `json:"day_name"`
  83. SessionCount int64 `json:"session_count"`
  84. PromptTokens int64 `json:"prompt_tokens"`
  85. CompletionTokens int64 `json:"completion_tokens"`
  86. }
  87. type DailyActivity struct {
  88. Day string `json:"day"`
  89. SessionCount int64 `json:"session_count"`
  90. TotalTokens int64 `json:"total_tokens"`
  91. Cost float64 `json:"cost"`
  92. }
  93. type ToolUsage struct {
  94. ToolName string `json:"tool_name"`
  95. CallCount int64 `json:"call_count"`
  96. }
  97. type HourDayHeatmapPt struct {
  98. DayOfWeek int `json:"day_of_week"`
  99. Hour int `json:"hour"`
  100. SessionCount int64 `json:"session_count"`
  101. }
  102. func runStats(cmd *cobra.Command, _ []string) error {
  103. dataDir, _ := cmd.Flags().GetString("data-dir")
  104. ctx := cmd.Context()
  105. if dataDir == "" {
  106. cfg, err := config.Init("", "", false)
  107. if err != nil {
  108. return fmt.Errorf("failed to initialize config: %w", err)
  109. }
  110. dataDir = cfg.Options.DataDirectory
  111. }
  112. conn, err := db.Connect(ctx, dataDir)
  113. if err != nil {
  114. return fmt.Errorf("failed to connect to database: %w", err)
  115. }
  116. defer conn.Close()
  117. stats, err := gatherStats(ctx, conn)
  118. if err != nil {
  119. return fmt.Errorf("failed to gather stats: %w", err)
  120. }
  121. if stats.Total.TotalSessions == 0 {
  122. return fmt.Errorf("no data available: no sessions found in database")
  123. }
  124. currentUser, err := user.Current()
  125. if err != nil {
  126. return fmt.Errorf("failed to get current user: %w", err)
  127. }
  128. username := currentUser.Username
  129. project, err := os.Getwd()
  130. if err != nil {
  131. return fmt.Errorf("failed to get current directory: %w", err)
  132. }
  133. project = strings.Replace(project, currentUser.HomeDir, "~", 1)
  134. htmlPath := filepath.Join(dataDir, "stats/index.html")
  135. if err := generateHTML(stats, project, username, htmlPath); err != nil {
  136. return fmt.Errorf("failed to generate HTML: %w", err)
  137. }
  138. fmt.Printf("Stats generated: %s\n", htmlPath)
  139. if err := browser.OpenFile(htmlPath); err != nil {
  140. fmt.Printf("Could not open browser: %v\n", err)
  141. fmt.Println("Please open the file manually.")
  142. }
  143. return nil
  144. }
  145. func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
  146. queries := db.New(conn)
  147. stats := &Stats{
  148. GeneratedAt: time.Now(),
  149. }
  150. // Total stats.
  151. total, err := queries.GetTotalStats(ctx)
  152. if err != nil {
  153. return nil, fmt.Errorf("get total stats: %w", err)
  154. }
  155. stats.Total = TotalStats{
  156. TotalSessions: total.TotalSessions,
  157. TotalPromptTokens: toInt64(total.TotalPromptTokens),
  158. TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
  159. TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
  160. TotalCost: toFloat64(total.TotalCost),
  161. TotalMessages: toInt64(total.TotalMessages),
  162. AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
  163. AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
  164. }
  165. // Usage by day.
  166. dailyUsage, err := queries.GetUsageByDay(ctx)
  167. if err != nil {
  168. return nil, fmt.Errorf("get usage by day: %w", err)
  169. }
  170. for _, d := range dailyUsage {
  171. prompt := nullFloat64ToInt64(d.PromptTokens)
  172. completion := nullFloat64ToInt64(d.CompletionTokens)
  173. stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
  174. Day: fmt.Sprintf("%v", d.Day),
  175. PromptTokens: prompt,
  176. CompletionTokens: completion,
  177. TotalTokens: prompt + completion,
  178. Cost: d.Cost.Float64,
  179. SessionCount: d.SessionCount,
  180. })
  181. }
  182. // Usage by model.
  183. modelUsage, err := queries.GetUsageByModel(ctx)
  184. if err != nil {
  185. return nil, fmt.Errorf("get usage by model: %w", err)
  186. }
  187. for _, m := range modelUsage {
  188. stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
  189. Model: m.Model,
  190. Provider: m.Provider,
  191. MessageCount: m.MessageCount,
  192. })
  193. }
  194. // Usage by hour.
  195. hourlyUsage, err := queries.GetUsageByHour(ctx)
  196. if err != nil {
  197. return nil, fmt.Errorf("get usage by hour: %w", err)
  198. }
  199. for _, h := range hourlyUsage {
  200. stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
  201. Hour: int(h.Hour),
  202. SessionCount: h.SessionCount,
  203. })
  204. }
  205. // Usage by day of week.
  206. dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
  207. if err != nil {
  208. return nil, fmt.Errorf("get usage by day of week: %w", err)
  209. }
  210. for _, d := range dowUsage {
  211. stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
  212. DayOfWeek: int(d.DayOfWeek),
  213. DayName: dayNames[int(d.DayOfWeek)],
  214. SessionCount: d.SessionCount,
  215. PromptTokens: nullFloat64ToInt64(d.PromptTokens),
  216. CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
  217. })
  218. }
  219. // Recent activity (last 30 days).
  220. recent, err := queries.GetRecentActivity(ctx)
  221. if err != nil {
  222. return nil, fmt.Errorf("get recent activity: %w", err)
  223. }
  224. for _, r := range recent {
  225. stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
  226. Day: fmt.Sprintf("%v", r.Day),
  227. SessionCount: r.SessionCount,
  228. TotalTokens: nullFloat64ToInt64(r.TotalTokens),
  229. Cost: r.Cost.Float64,
  230. })
  231. }
  232. // Average response time.
  233. avgResp, err := queries.GetAverageResponseTime(ctx)
  234. if err != nil {
  235. return nil, fmt.Errorf("get average response time: %w", err)
  236. }
  237. stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
  238. // Tool usage.
  239. toolUsage, err := queries.GetToolUsage(ctx)
  240. if err != nil {
  241. return nil, fmt.Errorf("get tool usage: %w", err)
  242. }
  243. for _, t := range toolUsage {
  244. if name, ok := t.ToolName.(string); ok && name != "" {
  245. stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
  246. ToolName: name,
  247. CallCount: t.CallCount,
  248. })
  249. }
  250. }
  251. // Hour/day heatmap.
  252. heatmap, err := queries.GetHourDayHeatmap(ctx)
  253. if err != nil {
  254. return nil, fmt.Errorf("get hour day heatmap: %w", err)
  255. }
  256. for _, h := range heatmap {
  257. stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
  258. DayOfWeek: int(h.DayOfWeek),
  259. Hour: int(h.Hour),
  260. SessionCount: h.SessionCount,
  261. })
  262. }
  263. return stats, nil
  264. }
  265. func toInt64(v any) int64 {
  266. switch val := v.(type) {
  267. case int64:
  268. return val
  269. case float64:
  270. return int64(val)
  271. case int:
  272. return int64(val)
  273. default:
  274. return 0
  275. }
  276. }
  277. func toFloat64(v any) float64 {
  278. switch val := v.(type) {
  279. case float64:
  280. return val
  281. case int64:
  282. return float64(val)
  283. case int:
  284. return float64(val)
  285. default:
  286. return 0
  287. }
  288. }
  289. func nullFloat64ToInt64(n sql.NullFloat64) int64 {
  290. if n.Valid {
  291. return int64(n.Float64)
  292. }
  293. return 0
  294. }
  295. func generateHTML(stats *Stats, projName, username, path string) error {
  296. statsJSON, err := json.Marshal(stats)
  297. if err != nil {
  298. return err
  299. }
  300. tmpl, err := template.New("stats").Parse(statsTemplate)
  301. if err != nil {
  302. return fmt.Errorf("parse template: %w", err)
  303. }
  304. data := struct {
  305. StatsJSON template.JS
  306. CSS template.CSS
  307. JS template.JS
  308. Header template.HTML
  309. Heartbit template.HTML
  310. Footer template.HTML
  311. GeneratedAt string
  312. ProjectName string
  313. Username string
  314. }{
  315. StatsJSON: template.JS(statsJSON),
  316. CSS: template.CSS(statsCSS),
  317. JS: template.JS(statsJS),
  318. Header: template.HTML(headerSVG),
  319. Heartbit: template.HTML(heartbitSVG),
  320. Footer: template.HTML(footerSVG),
  321. GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
  322. ProjectName: projName,
  323. Username: username,
  324. }
  325. var buf bytes.Buffer
  326. if err := tmpl.Execute(&buf, data); err != nil {
  327. return fmt.Errorf("execute template: %w", err)
  328. }
  329. // Ensure parent directory exists.
  330. if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
  331. return fmt.Errorf("create directory: %w", err)
  332. }
  333. return os.WriteFile(path, buf.Bytes(), 0o644)
  334. }