Просмотр исходного кода

fix: agent model selection priority issue (#3572)

oribi 3 месяцев назад
Родитель
Сommit
2d5df3ad76
2 измененных файлов с 89 добавлено и 13 удалено
  1. 13 13
      packages/tui/internal/app/app.go
  2. 76 0
      packages/tui/internal/app/app_test.go

+ 13 - 13
packages/tui/internal/app/app.go

@@ -490,19 +490,7 @@ func (a *App) InitializeProvider() tea.Cmd {
 		}
 	}
 
-	// Priority 2: Config file model setting
-	if selectedProvider == nil && a.Config.Model != "" {
-		if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil &&
-			model != nil {
-			selectedProvider = provider
-			selectedModel = model
-			slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
-		} else {
-			slog.Debug("Config model not found", "model", a.Config.Model)
-		}
-	}
-
-	// Priority 3: Current agent's preferred model
+	// Priority 2: Current agent's preferred model
 	if selectedProvider == nil && a.Agent().Model.ModelID != "" {
 		if provider, model := findModelByProviderAndModelID(providers, a.Agent().Model.ProviderID, a.Agent().Model.ModelID); provider != nil &&
 			model != nil {
@@ -522,6 +510,18 @@ func (a *App) InitializeProvider() tea.Cmd {
 		}
 	}
 
+	// Priority 3: Config file model setting
+	if selectedProvider == nil && a.Config.Model != "" {
+		if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil &&
+			model != nil {
+			selectedProvider = provider
+			selectedModel = model
+			slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
+		} else {
+			slog.Debug("Config model not found", "model", a.Config.Model)
+		}
+	}
+
 	// Priority 4: Recent model usage (most recently used model)
 	if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 {
 		recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first

+ 76 - 0
packages/tui/internal/app/app_test.go

@@ -226,3 +226,79 @@ func TestFindProviderByID(t *testing.T) {
 		})
 	}
 }
+
+// TestModelSelectionPriority tests the priority order for model selection
+func TestModelSelectionPriority(t *testing.T) {
+	providers := []opencode.Provider{
+		{
+			ID: "anthropic",
+			Models: map[string]opencode.Model{
+				"claude-opus": {ID: "claude-opus"},
+			},
+		},
+		{
+			ID: "openai",
+			Models: map[string]opencode.Model{
+				"gpt-4": {ID: "gpt-4"},
+			},
+		},
+	}
+
+	tests := []struct {
+		name               string
+		agentProviderID    string
+		agentModelID       string
+		configModel        string
+		expectedProviderID string
+		expectedModelID    string
+		description        string
+	}{
+		{
+			name:               "agent model takes priority over config",
+			agentProviderID:    "openai",
+			agentModelID:       "gpt-4",
+			configModel:        "anthropic/claude-opus",
+			expectedProviderID: "openai",
+			expectedModelID:    "gpt-4",
+			description:        "When agent specifies a model, it should be used even if config has a different model",
+		},
+		{
+			name:               "config model used when agent has no model",
+			agentProviderID:    "",
+			agentModelID:       "",
+			configModel:        "anthropic/claude-opus",
+			expectedProviderID: "anthropic",
+			expectedModelID:    "claude-opus",
+			description:        "When agent has no model specified, config model should be used as fallback",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var selectedProvider *opencode.Provider
+			var selectedModel *opencode.Model
+
+			// Simulate priority 2: Agent model check
+			if tt.agentModelID != "" {
+				selectedProvider, selectedModel = findModelByProviderAndModelID(providers, tt.agentProviderID, tt.agentModelID)
+			}
+
+			// Simulate priority 3: Config model fallback
+			if selectedProvider == nil && tt.configModel != "" {
+				selectedProvider, selectedModel = findModelByFullID(providers, tt.configModel)
+			}
+
+			if selectedProvider == nil || selectedModel == nil {
+				t.Fatalf("Expected to find model, but got nil - %s", tt.description)
+			}
+
+			if selectedProvider.ID != tt.expectedProviderID {
+				t.Errorf("Expected provider %s, got %s - %s", tt.expectedProviderID, selectedProvider.ID, tt.description)
+			}
+
+			if selectedModel.ID != tt.expectedModelID {
+				t.Errorf("Expected model %s, got %s - %s", tt.expectedModelID, selectedModel.ID, tt.description)
+			}
+		})
+	}
+}