stats.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql"
  6. _ "embed"
  7. "encoding/json"
  8. "fmt"
  9. "html/template"
  10. "os"
  11. "os/user"
  12. "path/filepath"
  13. "strings"
  14. "time"
  15. "github.com/charmbracelet/crush/internal/config"
  16. "github.com/charmbracelet/crush/internal/db"
  17. "github.com/charmbracelet/crush/internal/event"
  18. "github.com/pkg/browser"
  19. "github.com/spf13/cobra"
  20. )
  21. //go:embed stats/index.html
  22. var statsTemplate string
  23. //go:embed stats/index.css
  24. var statsCSS string
  25. //go:embed stats/index.js
  26. var statsJS string
  27. //go:embed stats/header.svg
  28. var headerSVG string
  29. //go:embed stats/heartbit.svg
  30. var heartbitSVG string
  31. //go:embed stats/footer.svg
  32. var footerSVG string
  33. var statsCmd = &cobra.Command{
  34. Use: "stats",
  35. Short: "Show usage statistics",
  36. Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
  37. RunE: runStats,
  38. }
  // dayNames maps a day-of-week index to its English name, with index 0 being
  // Sunday. gatherStats uses it to label day-of-week rows.
  var dayNames = []string{
  	0: "Sunday",
  	1: "Monday",
  	2: "Tuesday",
  	3: "Wednesday",
  	4: "Thursday",
  	5: "Friday",
  	6: "Saturday",
  }
  41. // Stats holds all the statistics data.
  42. type Stats struct {
  43. GeneratedAt time.Time `json:"generated_at"`
  44. Total TotalStats `json:"total"`
  45. UsageByDay []DailyUsage `json:"usage_by_day"`
  46. UsageByModel []ModelUsage `json:"usage_by_model"`
  47. UsageByHour []HourlyUsage `json:"usage_by_hour"`
  48. UsageByDayOfWeek []DayOfWeekUsage `json:"usage_by_day_of_week"`
  49. RecentActivity []DailyActivity `json:"recent_activity"`
  50. AvgResponseTimeMs float64 `json:"avg_response_time_ms"`
  51. ToolUsage []ToolUsage `json:"tool_usage"`
  52. HourDayHeatmap []HourDayHeatmapPt `json:"hour_day_heatmap"`
  53. }
  // TotalStats holds the aggregate figures across all recorded sessions.
  type TotalStats struct {
  	// TotalSessions is the number of sessions found in the database.
  	TotalSessions int64 `json:"total_sessions"`
  	// TotalPromptTokens is the summed prompt token count.
  	TotalPromptTokens int64 `json:"total_prompt_tokens"`
  	// TotalCompletionTokens is the summed completion token count.
  	TotalCompletionTokens int64 `json:"total_completion_tokens"`
  	// TotalTokens is prompt plus completion tokens.
  	TotalTokens int64 `json:"total_tokens"`
  	// TotalCost is the summed cost (the unit is whatever the database stores).
  	TotalCost float64 `json:"total_cost"`
  	// TotalMessages is the total number of messages.
  	TotalMessages int64 `json:"total_messages"`
  	// AvgTokensPerSession is the average token count per session.
  	AvgTokensPerSession float64 `json:"avg_tokens_per_session"`
  	// AvgMessagesPerSession is the average message count per session.
  	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
  }
  // DailyUsage is the token, cost and session summary for a single day.
  type DailyUsage struct {
  	// Day is the string form of the day value returned by the query.
  	Day string `json:"day"`
  	// PromptTokens is the prompt token count for the day.
  	PromptTokens int64 `json:"prompt_tokens"`
  	// CompletionTokens is the completion token count for the day.
  	CompletionTokens int64 `json:"completion_tokens"`
  	// TotalTokens is prompt plus completion tokens.
  	TotalTokens int64 `json:"total_tokens"`
  	// Cost is the cost recorded for the day.
  	Cost float64 `json:"cost"`
  	// SessionCount is the number of sessions on the day.
  	SessionCount int64 `json:"session_count"`
  }
  // ModelUsage is the message count attributed to one model of one provider.
  type ModelUsage struct {
  	// Model is the model identifier as stored in the database.
  	Model string `json:"model"`
  	// Provider is the provider identifier as stored in the database.
  	Provider string `json:"provider"`
  	// MessageCount is the number of messages recorded for this pair.
  	MessageCount int64 `json:"message_count"`
  }
  // HourlyUsage is the session count for one hour bucket.
  type HourlyUsage struct {
  	// Hour is the hour value returned by the query (presumably 0-23; the
  	// range is determined by the query, which is not defined in this file).
  	Hour int `json:"hour"`
  	// SessionCount is the number of sessions in this hour bucket.
  	SessionCount int64 `json:"session_count"`
  }
  // DayOfWeekUsage is the session and token summary for one day of the week.
  type DayOfWeekUsage struct {
  	// DayOfWeek is the day index; gatherStats uses it to index dayNames,
  	// where 0 is Sunday.
  	DayOfWeek int `json:"day_of_week"`
  	// DayName is the English name looked up for DayOfWeek.
  	DayName string `json:"day_name"`
  	// SessionCount is the number of sessions on this day of the week.
  	SessionCount int64 `json:"session_count"`
  	// PromptTokens is the prompt token count for this day of the week.
  	PromptTokens int64 `json:"prompt_tokens"`
  	// CompletionTokens is the completion token count for this day of the week.
  	CompletionTokens int64 `json:"completion_tokens"`
  }
  // DailyActivity is one day's entry in the recent-activity series.
  type DailyActivity struct {
  	// Day is the string form of the day value returned by the query.
  	Day string `json:"day"`
  	// SessionCount is the number of sessions on the day.
  	SessionCount int64 `json:"session_count"`
  	// TotalTokens is the token count for the day.
  	TotalTokens int64 `json:"total_tokens"`
  	// Cost is the cost recorded for the day.
  	Cost float64 `json:"cost"`
  }
  // ToolUsage is the number of recorded calls for a single tool.
  type ToolUsage struct {
  	// ToolName is the tool's name; gatherStats drops rows with an empty name.
  	ToolName string `json:"tool_name"`
  	// CallCount is the number of calls recorded for the tool.
  	CallCount int64 `json:"call_count"`
  }
  // HourDayHeatmapPt is one cell of the hour-by-weekday heatmap.
  type HourDayHeatmapPt struct {
  	// DayOfWeek is the day index returned by the query (presumably 0 is
  	// Sunday, matching dayNames; confirm against the query).
  	DayOfWeek int `json:"day_of_week"`
  	// Hour is the hour value returned by the query.
  	Hour int `json:"hour"`
  	// SessionCount is the number of sessions in this cell.
  	SessionCount int64 `json:"session_count"`
  }
  103. func runStats(cmd *cobra.Command, _ []string) error {
  104. dataDir, _ := cmd.Flags().GetString("data-dir")
  105. ctx := cmd.Context()
  106. cfg, err := config.Init("", dataDir, false)
  107. if err != nil {
  108. return fmt.Errorf("failed to initialize config: %w", err)
  109. }
  110. if dataDir == "" {
  111. dataDir = cfg.Config().Options.DataDirectory
  112. }
  113. if shouldEnableMetrics(cfg.Config()) {
  114. event.Init()
  115. }
  116. event.StatsViewed()
  117. conn, err := db.Connect(ctx, dataDir)
  118. if err != nil {
  119. return fmt.Errorf("failed to connect to database: %w", err)
  120. }
  121. defer conn.Close()
  122. stats, err := gatherStats(ctx, conn)
  123. if err != nil {
  124. return fmt.Errorf("failed to gather stats: %w", err)
  125. }
  126. if stats.Total.TotalSessions == 0 {
  127. return fmt.Errorf("no data available: no sessions found in database")
  128. }
  129. currentUser, err := user.Current()
  130. if err != nil {
  131. return fmt.Errorf("failed to get current user: %w", err)
  132. }
  133. username := currentUser.Username
  134. project, err := os.Getwd()
  135. if err != nil {
  136. return fmt.Errorf("failed to get current directory: %w", err)
  137. }
  138. project = strings.Replace(project, currentUser.HomeDir, "~", 1)
  139. htmlPath := filepath.Join(dataDir, "stats/index.html")
  140. if err := generateHTML(stats, project, username, htmlPath); err != nil {
  141. return fmt.Errorf("failed to generate HTML: %w", err)
  142. }
  143. fmt.Printf("Stats generated: %s\n", htmlPath)
  144. if err := browser.OpenFile(htmlPath); err != nil {
  145. fmt.Printf("Could not open browser: %v\n", err)
  146. fmt.Println("Please open the file manually.")
  147. }
  148. return nil
  149. }
  150. func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
  151. queries := db.New(conn)
  152. stats := &Stats{
  153. GeneratedAt: time.Now(),
  154. }
  155. // Total stats.
  156. total, err := queries.GetTotalStats(ctx)
  157. if err != nil {
  158. return nil, fmt.Errorf("get total stats: %w", err)
  159. }
  160. stats.Total = TotalStats{
  161. TotalSessions: total.TotalSessions,
  162. TotalPromptTokens: toInt64(total.TotalPromptTokens),
  163. TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
  164. TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
  165. TotalCost: toFloat64(total.TotalCost),
  166. TotalMessages: toInt64(total.TotalMessages),
  167. AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
  168. AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
  169. }
  170. // Usage by day.
  171. dailyUsage, err := queries.GetUsageByDay(ctx)
  172. if err != nil {
  173. return nil, fmt.Errorf("get usage by day: %w", err)
  174. }
  175. for _, d := range dailyUsage {
  176. prompt := nullFloat64ToInt64(d.PromptTokens)
  177. completion := nullFloat64ToInt64(d.CompletionTokens)
  178. stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
  179. Day: fmt.Sprintf("%v", d.Day),
  180. PromptTokens: prompt,
  181. CompletionTokens: completion,
  182. TotalTokens: prompt + completion,
  183. Cost: d.Cost.Float64,
  184. SessionCount: d.SessionCount,
  185. })
  186. }
  187. // Usage by model.
  188. modelUsage, err := queries.GetUsageByModel(ctx)
  189. if err != nil {
  190. return nil, fmt.Errorf("get usage by model: %w", err)
  191. }
  192. for _, m := range modelUsage {
  193. stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
  194. Model: m.Model,
  195. Provider: m.Provider,
  196. MessageCount: m.MessageCount,
  197. })
  198. }
  199. // Usage by hour.
  200. hourlyUsage, err := queries.GetUsageByHour(ctx)
  201. if err != nil {
  202. return nil, fmt.Errorf("get usage by hour: %w", err)
  203. }
  204. for _, h := range hourlyUsage {
  205. stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
  206. Hour: int(h.Hour),
  207. SessionCount: h.SessionCount,
  208. })
  209. }
  210. // Usage by day of week.
  211. dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
  212. if err != nil {
  213. return nil, fmt.Errorf("get usage by day of week: %w", err)
  214. }
  215. for _, d := range dowUsage {
  216. stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
  217. DayOfWeek: int(d.DayOfWeek),
  218. DayName: dayNames[int(d.DayOfWeek)],
  219. SessionCount: d.SessionCount,
  220. PromptTokens: nullFloat64ToInt64(d.PromptTokens),
  221. CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
  222. })
  223. }
  224. // Recent activity (last 30 days).
  225. recent, err := queries.GetRecentActivity(ctx)
  226. if err != nil {
  227. return nil, fmt.Errorf("get recent activity: %w", err)
  228. }
  229. for _, r := range recent {
  230. stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
  231. Day: fmt.Sprintf("%v", r.Day),
  232. SessionCount: r.SessionCount,
  233. TotalTokens: nullFloat64ToInt64(r.TotalTokens),
  234. Cost: r.Cost.Float64,
  235. })
  236. }
  237. // Average response time.
  238. avgResp, err := queries.GetAverageResponseTime(ctx)
  239. if err != nil {
  240. return nil, fmt.Errorf("get average response time: %w", err)
  241. }
  242. stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
  243. // Tool usage.
  244. toolUsage, err := queries.GetToolUsage(ctx)
  245. if err != nil {
  246. return nil, fmt.Errorf("get tool usage: %w", err)
  247. }
  248. for _, t := range toolUsage {
  249. if name, ok := t.ToolName.(string); ok && name != "" {
  250. stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
  251. ToolName: name,
  252. CallCount: t.CallCount,
  253. })
  254. }
  255. }
  256. // Hour/day heatmap.
  257. heatmap, err := queries.GetHourDayHeatmap(ctx)
  258. if err != nil {
  259. return nil, fmt.Errorf("get hour day heatmap: %w", err)
  260. }
  261. for _, h := range heatmap {
  262. stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
  263. DayOfWeek: int(h.DayOfWeek),
  264. Hour: int(h.Hour),
  265. SessionCount: h.SessionCount,
  266. })
  267. }
  268. return stats, nil
  269. }
  // toInt64 converts a dynamically typed numeric value to int64. It accepts
  // int64, float64 (truncated toward zero) and int; any other dynamic type,
  // including nil, yields 0.
  func toInt64(v any) int64 {
  	if i, ok := v.(int64); ok {
  		return i
  	}
  	if f, ok := v.(float64); ok {
  		return int64(f)
  	}
  	if i, ok := v.(int); ok {
  		return int64(i)
  	}
  	return 0
  }
  // toFloat64 converts a dynamically typed numeric value to float64. It
  // accepts float64, int64 and int; any other dynamic type, including nil,
  // yields 0.
  func toFloat64(v any) float64 {
  	if f, ok := v.(float64); ok {
  		return f
  	}
  	if i, ok := v.(int64); ok {
  		return float64(i)
  	}
  	if i, ok := v.(int); ok {
  		return float64(i)
  	}
  	return 0
  }
  // nullFloat64ToInt64 returns n's value truncated toward zero, or 0 when n
  // is not valid (SQL NULL).
  func nullFloat64ToInt64(n sql.NullFloat64) int64 {
  	if !n.Valid {
  		return 0
  	}
  	return int64(n.Float64)
  }
  300. func generateHTML(stats *Stats, projName, username, path string) error {
  301. statsJSON, err := json.Marshal(stats)
  302. if err != nil {
  303. return err
  304. }
  305. tmpl, err := template.New("stats").Parse(statsTemplate)
  306. if err != nil {
  307. return fmt.Errorf("parse template: %w", err)
  308. }
  309. data := struct {
  310. StatsJSON template.JS
  311. CSS template.CSS
  312. JS template.JS
  313. Header template.HTML
  314. Heartbit template.HTML
  315. Footer template.HTML
  316. GeneratedAt string
  317. ProjectName string
  318. Username string
  319. }{
  320. StatsJSON: template.JS(statsJSON),
  321. CSS: template.CSS(statsCSS),
  322. JS: template.JS(statsJS),
  323. Header: template.HTML(headerSVG),
  324. Heartbit: template.HTML(heartbitSVG),
  325. Footer: template.HTML(footerSVG),
  326. GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
  327. ProjectName: projName,
  328. Username: username,
  329. }
  330. var buf bytes.Buffer
  331. if err := tmpl.Execute(&buf, data); err != nil {
  332. return fmt.Errorf("execute template: %w", err)
  333. }
  334. // Ensure parent directory exists.
  335. if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
  336. return fmt.Errorf("create directory: %w", err)
  337. }
  338. return os.WriteFile(path, buf.Bytes(), 0o644)
  339. }