trajectory.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package cmd
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "os"
  6. "github.com/charmbracelet/crush/internal/config"
  7. "github.com/charmbracelet/crush/internal/db"
  8. "github.com/charmbracelet/crush/internal/message"
  9. "github.com/charmbracelet/crush/internal/session"
  10. "github.com/charmbracelet/crush/internal/trajectory"
  11. "github.com/charmbracelet/crush/internal/version"
  12. "github.com/spf13/cobra"
  13. )
  14. var trajectoryCmd = &cobra.Command{
  15. Use: "trajectory",
  16. Short: "Trajectory export utilities",
  17. Long: "Export session trajectories in Harbor ATIF format for analysis and sharing",
  18. }
  19. var trajectoryExportCmd = &cobra.Command{
  20. Use: "export",
  21. Short: "Export a session as ATIF trajectory",
  22. Long: "Export a Crush session in Harbor ATIF (Agent Trajectory Interchange Format) v1.4",
  23. Example: `
  24. # Export a session as JSON to stdout
  25. crush trajectory export --session <session-id>
  26. # Export a session to a JSON file
  27. crush trajectory export --session <session-id> --output trajectory.json
  28. # Export as HTML for visualization
  29. crush trajectory export --session <session-id> --format html --output trajectory.html
  30. # Validate with Harbor validator
  31. crush trajectory export --session <session-id> > out.json
  32. python -m harbor.utils.trajectory_validator out.json
  33. `,
  34. RunE: func(cmd *cobra.Command, args []string) error {
  35. sessionID, _ := cmd.Flags().GetString("session")
  36. outputFile, _ := cmd.Flags().GetString("output")
  37. format, _ := cmd.Flags().GetString("format")
  38. dataDir, _ := cmd.Flags().GetString("data-dir")
  39. if sessionID == "" {
  40. return fmt.Errorf("--session flag is required")
  41. }
  42. ctx := cmd.Context()
  43. cwd, err := ResolveCwd(cmd)
  44. if err != nil {
  45. return err
  46. }
  47. // Load config (lightweight, no full app init).
  48. cfg, err := config.Load(cwd, dataDir, false)
  49. if err != nil {
  50. return fmt.Errorf("failed to load config: %w", err)
  51. }
  52. // Connect to DB.
  53. conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
  54. if err != nil {
  55. return fmt.Errorf("failed to connect to database: %w", err)
  56. }
  57. defer conn.Close()
  58. querier := db.New(conn)
  59. sessionSvc := session.NewService(querier)
  60. messageSvc := message.NewService(querier)
  61. // Load session.
  62. sess, err := sessionSvc.Get(ctx, sessionID)
  63. if err != nil {
  64. return fmt.Errorf("failed to get session: %w", err)
  65. }
  66. // Load messages.
  67. messages, err := messageSvc.List(ctx, sessionID)
  68. if err != nil {
  69. return fmt.Errorf("failed to list messages: %w", err)
  70. }
  71. // Determine model name from first assistant message.
  72. var modelName string
  73. for _, msg := range messages {
  74. if msg.Role == message.Assistant && msg.Model != "" {
  75. modelName = msg.Model
  76. break
  77. }
  78. }
  79. // Export to ATIF.
  80. traj, err := trajectory.ExportSession(sess, messages, "Crush", version.Version, modelName)
  81. if err != nil {
  82. return fmt.Errorf("failed to export trajectory: %w", err)
  83. }
  84. var data []byte
  85. switch format {
  86. case "html":
  87. data, err = trajectory.RenderHTML(traj)
  88. if err != nil {
  89. return fmt.Errorf("failed to render HTML: %w", err)
  90. }
  91. case "json":
  92. data, err = json.MarshalIndent(traj, "", " ")
  93. if err != nil {
  94. return fmt.Errorf("failed to marshal trajectory: %w", err)
  95. }
  96. default:
  97. return fmt.Errorf("unknown format: %s (use 'json' or 'html')", format)
  98. }
  99. // Write output.
  100. if outputFile != "" {
  101. if err := os.WriteFile(outputFile, data, 0o644); err != nil {
  102. return fmt.Errorf("failed to write output file: %w", err)
  103. }
  104. cmd.Printf("Exported trajectory to %s\n", outputFile)
  105. } else {
  106. cmd.Println(string(data))
  107. }
  108. return nil
  109. },
  110. }
  111. func init() {
  112. trajectoryExportCmd.Flags().StringP("session", "s", "", "Session ID to export (required)")
  113. trajectoryExportCmd.Flags().StringP("output", "o", "", "Output file path (defaults to stdout)")
  114. trajectoryExportCmd.Flags().StringP("format", "f", "json", "Output format: json or html")
  115. _ = trajectoryExportCmd.MarkFlagRequired("session")
  116. trajectoryCmd.AddCommand(trajectoryExportCmd)
  117. }