session_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package session
  2. import (
  3. "database/sql"
  4. "testing"
  5. "github.com/charmbracelet/crush/internal/config"
  6. "github.com/charmbracelet/crush/internal/db"
  7. "github.com/stretchr/testify/require"
  8. )
  9. func TestMarshalModels(t *testing.T) {
  10. t.Parallel()
  11. t.Run("empty", func(t *testing.T) {
  12. t.Parallel()
  13. result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{})
  14. require.NoError(t, err)
  15. require.Equal(t, "", result)
  16. })
  17. t.Run("nil", func(t *testing.T) {
  18. t.Parallel()
  19. result, err := marshalModels(nil)
  20. require.NoError(t, err)
  21. require.Equal(t, "", result)
  22. })
  23. t.Run("single entry", func(t *testing.T) {
  24. t.Parallel()
  25. models := map[config.SelectedModelType]config.SelectedModel{
  26. config.SelectedModelTypeLarge: {
  27. Model: "claude-sonnet-4-20250514",
  28. Provider: "anthropic",
  29. },
  30. }
  31. result, err := marshalModels(models)
  32. require.NoError(t, err)
  33. require.Contains(t, result, "claude-sonnet-4-20250514")
  34. require.Contains(t, result, "anthropic")
  35. })
  36. t.Run("round-trip", func(t *testing.T) {
  37. t.Parallel()
  38. temp := 0.7
  39. topP := 0.9
  40. topK := int64(50)
  41. freqPen := 0.1
  42. presPen := 0.2
  43. models := map[config.SelectedModelType]config.SelectedModel{
  44. config.SelectedModelTypeLarge: {
  45. Model: "gpt-4o",
  46. Provider: "openai",
  47. ReasoningEffort: "high",
  48. Think: true,
  49. MaxTokens: 4096,
  50. Temperature: &temp,
  51. TopP: &topP,
  52. TopK: &topK,
  53. FrequencyPenalty: &freqPen,
  54. PresencePenalty: &presPen,
  55. ProviderOptions: map[string]any{"key": "value"},
  56. },
  57. config.SelectedModelTypeSmall: {
  58. Model: "gpt-4o-mini",
  59. Provider: "openai",
  60. },
  61. }
  62. data, err := marshalModels(models)
  63. require.NoError(t, err)
  64. result, err := unmarshalModels(data)
  65. require.NoError(t, err)
  66. require.Equal(t, models, result)
  67. })
  68. }
  69. func TestUnmarshalModels(t *testing.T) {
  70. t.Parallel()
  71. t.Run("empty string", func(t *testing.T) {
  72. t.Parallel()
  73. result, err := unmarshalModels("")
  74. require.NoError(t, err)
  75. require.Nil(t, result)
  76. })
  77. t.Run("valid JSON", func(t *testing.T) {
  78. t.Parallel()
  79. data := `{"large":{"model":"gpt-4o","provider":"openai"}}`
  80. result, err := unmarshalModels(data)
  81. require.NoError(t, err)
  82. require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
  83. require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
  84. })
  85. t.Run("invalid JSON", func(t *testing.T) {
  86. t.Parallel()
  87. _, err := unmarshalModels("{invalid}")
  88. require.Error(t, err)
  89. })
  90. }
  91. func TestFromDBItemWithModels(t *testing.T) {
  92. t.Parallel()
  93. t.Run("null models", func(t *testing.T) {
  94. t.Parallel()
  95. item := testDBSession()
  96. item.Models = sql.NullString{Valid: false}
  97. result := service{}.fromDBItem(item)
  98. require.Nil(t, result.Models)
  99. })
  100. t.Run("empty models", func(t *testing.T) {
  101. t.Parallel()
  102. item := testDBSession()
  103. item.Models = sql.NullString{String: "", Valid: true}
  104. result := service{}.fromDBItem(item)
  105. require.Nil(t, result.Models)
  106. })
  107. t.Run("valid models", func(t *testing.T) {
  108. t.Parallel()
  109. item := testDBSession()
  110. item.Models = sql.NullString{
  111. String: `{"large":{"model":"gpt-4o","provider":"openai"}}`,
  112. Valid: true,
  113. }
  114. result := service{}.fromDBItem(item)
  115. require.NotNil(t, result.Models)
  116. require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model)
  117. })
  118. t.Run("invalid JSON models", func(t *testing.T) {
  119. t.Parallel()
  120. item := testDBSession()
  121. item.Models = sql.NullString{
  122. String: "{invalid}",
  123. Valid: true,
  124. }
  125. result := service{}.fromDBItem(item)
  126. require.Nil(t, result.Models)
  127. })
  128. }
  129. func testDBSession() db.Session {
  130. return db.Session{
  131. ID: "test-id",
  132. Title: "Test",
  133. }
  134. }