stats.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql"
  6. _ "embed"
  7. "encoding/json"
  8. "fmt"
  9. "html/template"
  10. "os"
  11. "os/user"
  12. "path/filepath"
  13. "strings"
  14. "time"
  15. "github.com/charmbracelet/crush/internal/config"
  16. "github.com/charmbracelet/crush/internal/db"
  17. "github.com/charmbracelet/crush/internal/event"
  18. "github.com/pkg/browser"
  19. "github.com/spf13/cobra"
  20. )
  21. //go:embed stats/index.html
  22. var statsTemplate string
  23. //go:embed stats/index.css
  24. var statsCSS string
  25. //go:embed stats/index.js
  26. var statsJS string
  27. //go:embed stats/header.svg
  28. var headerSVG string
  29. //go:embed stats/heartbit.svg
  30. var heartbitSVG string
  31. //go:embed stats/footer.svg
  32. var footerSVG string
  33. var statsCmd = &cobra.Command{
  34. Use: "stats",
  35. Short: "Show usage statistics",
  36. Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
  37. RunE: runStats,
  38. }
  39. // Day names for day of week statistics.
  40. var dayNames = []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
  41. // Stats holds all the statistics data.
  42. type Stats struct {
  43. GeneratedAt time.Time `json:"generated_at"`
  44. Total TotalStats `json:"total"`
  45. UsageByDay []DailyUsage `json:"usage_by_day"`
  46. UsageByModel []ModelUsage `json:"usage_by_model"`
  47. UsageByHour []HourlyUsage `json:"usage_by_hour"`
  48. UsageByDayOfWeek []DayOfWeekUsage `json:"usage_by_day_of_week"`
  49. RecentActivity []DailyActivity `json:"recent_activity"`
  50. AvgResponseTimeMs float64 `json:"avg_response_time_ms"`
  51. ToolUsage []ToolUsage `json:"tool_usage"`
  52. HourDayHeatmap []HourDayHeatmapPt `json:"hour_day_heatmap"`
  53. }
  54. type TotalStats struct {
  55. TotalSessions int64 `json:"total_sessions"`
  56. TotalPromptTokens int64 `json:"total_prompt_tokens"`
  57. TotalCompletionTokens int64 `json:"total_completion_tokens"`
  58. TotalTokens int64 `json:"total_tokens"`
  59. TotalCost float64 `json:"total_cost"`
  60. TotalMessages int64 `json:"total_messages"`
  61. AvgTokensPerSession float64 `json:"avg_tokens_per_session"`
  62. AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
  63. }
  64. type DailyUsage struct {
  65. Day string `json:"day"`
  66. PromptTokens int64 `json:"prompt_tokens"`
  67. CompletionTokens int64 `json:"completion_tokens"`
  68. TotalTokens int64 `json:"total_tokens"`
  69. Cost float64 `json:"cost"`
  70. SessionCount int64 `json:"session_count"`
  71. }
  72. type ModelUsage struct {
  73. Model string `json:"model"`
  74. Provider string `json:"provider"`
  75. MessageCount int64 `json:"message_count"`
  76. }
  77. type HourlyUsage struct {
  78. Hour int `json:"hour"`
  79. SessionCount int64 `json:"session_count"`
  80. }
  81. type DayOfWeekUsage struct {
  82. DayOfWeek int `json:"day_of_week"`
  83. DayName string `json:"day_name"`
  84. SessionCount int64 `json:"session_count"`
  85. PromptTokens int64 `json:"prompt_tokens"`
  86. CompletionTokens int64 `json:"completion_tokens"`
  87. }
  88. type DailyActivity struct {
  89. Day string `json:"day"`
  90. SessionCount int64 `json:"session_count"`
  91. TotalTokens int64 `json:"total_tokens"`
  92. Cost float64 `json:"cost"`
  93. }
  94. type ToolUsage struct {
  95. ToolName string `json:"tool_name"`
  96. CallCount int64 `json:"call_count"`
  97. }
  98. type HourDayHeatmapPt struct {
  99. DayOfWeek int `json:"day_of_week"`
  100. Hour int `json:"hour"`
  101. SessionCount int64 `json:"session_count"`
  102. }
  103. func runStats(cmd *cobra.Command, _ []string) error {
  104. event.StatsViewed()
  105. dataDir, _ := cmd.Flags().GetString("data-dir")
  106. ctx := cmd.Context()
  107. if dataDir == "" {
  108. cfg, err := config.Init("", "", false)
  109. if err != nil {
  110. return fmt.Errorf("failed to initialize config: %w", err)
  111. }
  112. dataDir = cfg.Config().Options.DataDirectory
  113. }
  114. conn, err := db.Connect(ctx, dataDir)
  115. if err != nil {
  116. return fmt.Errorf("failed to connect to database: %w", err)
  117. }
  118. defer conn.Close()
  119. stats, err := gatherStats(ctx, conn)
  120. if err != nil {
  121. return fmt.Errorf("failed to gather stats: %w", err)
  122. }
  123. if stats.Total.TotalSessions == 0 {
  124. return fmt.Errorf("no data available: no sessions found in database")
  125. }
  126. currentUser, err := user.Current()
  127. if err != nil {
  128. return fmt.Errorf("failed to get current user: %w", err)
  129. }
  130. username := currentUser.Username
  131. project, err := os.Getwd()
  132. if err != nil {
  133. return fmt.Errorf("failed to get current directory: %w", err)
  134. }
  135. project = strings.Replace(project, currentUser.HomeDir, "~", 1)
  136. htmlPath := filepath.Join(dataDir, "stats/index.html")
  137. if err := generateHTML(stats, project, username, htmlPath); err != nil {
  138. return fmt.Errorf("failed to generate HTML: %w", err)
  139. }
  140. fmt.Printf("Stats generated: %s\n", htmlPath)
  141. if err := browser.OpenFile(htmlPath); err != nil {
  142. fmt.Printf("Could not open browser: %v\n", err)
  143. fmt.Println("Please open the file manually.")
  144. }
  145. return nil
  146. }
  147. func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
  148. queries := db.New(conn)
  149. stats := &Stats{
  150. GeneratedAt: time.Now(),
  151. }
  152. // Total stats.
  153. total, err := queries.GetTotalStats(ctx)
  154. if err != nil {
  155. return nil, fmt.Errorf("get total stats: %w", err)
  156. }
  157. stats.Total = TotalStats{
  158. TotalSessions: total.TotalSessions,
  159. TotalPromptTokens: toInt64(total.TotalPromptTokens),
  160. TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
  161. TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
  162. TotalCost: toFloat64(total.TotalCost),
  163. TotalMessages: toInt64(total.TotalMessages),
  164. AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
  165. AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
  166. }
  167. // Usage by day.
  168. dailyUsage, err := queries.GetUsageByDay(ctx)
  169. if err != nil {
  170. return nil, fmt.Errorf("get usage by day: %w", err)
  171. }
  172. for _, d := range dailyUsage {
  173. prompt := nullFloat64ToInt64(d.PromptTokens)
  174. completion := nullFloat64ToInt64(d.CompletionTokens)
  175. stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
  176. Day: fmt.Sprintf("%v", d.Day),
  177. PromptTokens: prompt,
  178. CompletionTokens: completion,
  179. TotalTokens: prompt + completion,
  180. Cost: d.Cost.Float64,
  181. SessionCount: d.SessionCount,
  182. })
  183. }
  184. // Usage by model.
  185. modelUsage, err := queries.GetUsageByModel(ctx)
  186. if err != nil {
  187. return nil, fmt.Errorf("get usage by model: %w", err)
  188. }
  189. for _, m := range modelUsage {
  190. stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
  191. Model: m.Model,
  192. Provider: m.Provider,
  193. MessageCount: m.MessageCount,
  194. })
  195. }
  196. // Usage by hour.
  197. hourlyUsage, err := queries.GetUsageByHour(ctx)
  198. if err != nil {
  199. return nil, fmt.Errorf("get usage by hour: %w", err)
  200. }
  201. for _, h := range hourlyUsage {
  202. stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
  203. Hour: int(h.Hour),
  204. SessionCount: h.SessionCount,
  205. })
  206. }
  207. // Usage by day of week.
  208. dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
  209. if err != nil {
  210. return nil, fmt.Errorf("get usage by day of week: %w", err)
  211. }
  212. for _, d := range dowUsage {
  213. stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
  214. DayOfWeek: int(d.DayOfWeek),
  215. DayName: dayNames[int(d.DayOfWeek)],
  216. SessionCount: d.SessionCount,
  217. PromptTokens: nullFloat64ToInt64(d.PromptTokens),
  218. CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
  219. })
  220. }
  221. // Recent activity (last 30 days).
  222. recent, err := queries.GetRecentActivity(ctx)
  223. if err != nil {
  224. return nil, fmt.Errorf("get recent activity: %w", err)
  225. }
  226. for _, r := range recent {
  227. stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
  228. Day: fmt.Sprintf("%v", r.Day),
  229. SessionCount: r.SessionCount,
  230. TotalTokens: nullFloat64ToInt64(r.TotalTokens),
  231. Cost: r.Cost.Float64,
  232. })
  233. }
  234. // Average response time.
  235. avgResp, err := queries.GetAverageResponseTime(ctx)
  236. if err != nil {
  237. return nil, fmt.Errorf("get average response time: %w", err)
  238. }
  239. stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
  240. // Tool usage.
  241. toolUsage, err := queries.GetToolUsage(ctx)
  242. if err != nil {
  243. return nil, fmt.Errorf("get tool usage: %w", err)
  244. }
  245. for _, t := range toolUsage {
  246. if name, ok := t.ToolName.(string); ok && name != "" {
  247. stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
  248. ToolName: name,
  249. CallCount: t.CallCount,
  250. })
  251. }
  252. }
  253. // Hour/day heatmap.
  254. heatmap, err := queries.GetHourDayHeatmap(ctx)
  255. if err != nil {
  256. return nil, fmt.Errorf("get hour day heatmap: %w", err)
  257. }
  258. for _, h := range heatmap {
  259. stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
  260. DayOfWeek: int(h.DayOfWeek),
  261. Hour: int(h.Hour),
  262. SessionCount: h.SessionCount,
  263. })
  264. }
  265. return stats, nil
  266. }
  267. func toInt64(v any) int64 {
  268. switch val := v.(type) {
  269. case int64:
  270. return val
  271. case float64:
  272. return int64(val)
  273. case int:
  274. return int64(val)
  275. default:
  276. return 0
  277. }
  278. }
  279. func toFloat64(v any) float64 {
  280. switch val := v.(type) {
  281. case float64:
  282. return val
  283. case int64:
  284. return float64(val)
  285. case int:
  286. return float64(val)
  287. default:
  288. return 0
  289. }
  290. }
  291. func nullFloat64ToInt64(n sql.NullFloat64) int64 {
  292. if n.Valid {
  293. return int64(n.Float64)
  294. }
  295. return 0
  296. }
  297. func generateHTML(stats *Stats, projName, username, path string) error {
  298. statsJSON, err := json.Marshal(stats)
  299. if err != nil {
  300. return err
  301. }
  302. tmpl, err := template.New("stats").Parse(statsTemplate)
  303. if err != nil {
  304. return fmt.Errorf("parse template: %w", err)
  305. }
  306. data := struct {
  307. StatsJSON template.JS
  308. CSS template.CSS
  309. JS template.JS
  310. Header template.HTML
  311. Heartbit template.HTML
  312. Footer template.HTML
  313. GeneratedAt string
  314. ProjectName string
  315. Username string
  316. }{
  317. StatsJSON: template.JS(statsJSON),
  318. CSS: template.CSS(statsCSS),
  319. JS: template.JS(statsJS),
  320. Header: template.HTML(headerSVG),
  321. Heartbit: template.HTML(heartbitSVG),
  322. Footer: template.HTML(footerSVG),
  323. GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
  324. ProjectName: projName,
  325. Username: username,
  326. }
  327. var buf bytes.Buffer
  328. if err := tmpl.Execute(&buf, data); err != nil {
  329. return fmt.Errorf("execute template: %w", err)
  330. }
  331. // Ensure parent directory exists.
  332. if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
  333. return fmt.Errorf("create directory: %w", err)
  334. }
  335. return os.WriteFile(path, buf.Bytes(), 0o644)
  336. }