|
|
@@ -226,3 +226,79 @@ func TestFindProviderByID(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+// TestModelSelectionPriority tests the priority order for model selection
|
|
|
+func TestModelSelectionPriority(t *testing.T) {
|
|
|
+ providers := []opencode.Provider{
|
|
|
+ {
|
|
|
+ ID: "anthropic",
|
|
|
+ Models: map[string]opencode.Model{
|
|
|
+ "claude-opus": {ID: "claude-opus"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ ID: "openai",
|
|
|
+ Models: map[string]opencode.Model{
|
|
|
+ "gpt-4": {ID: "gpt-4"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ agentProviderID string
|
|
|
+ agentModelID string
|
|
|
+ configModel string
|
|
|
+ expectedProviderID string
|
|
|
+ expectedModelID string
|
|
|
+ description string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "agent model takes priority over config",
|
|
|
+ agentProviderID: "openai",
|
|
|
+ agentModelID: "gpt-4",
|
|
|
+ configModel: "anthropic/claude-opus",
|
|
|
+ expectedProviderID: "openai",
|
|
|
+ expectedModelID: "gpt-4",
|
|
|
+ description: "When agent specifies a model, it should be used even if config has a different model",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "config model used when agent has no model",
|
|
|
+ agentProviderID: "",
|
|
|
+ agentModelID: "",
|
|
|
+ configModel: "anthropic/claude-opus",
|
|
|
+ expectedProviderID: "anthropic",
|
|
|
+ expectedModelID: "claude-opus",
|
|
|
+ description: "When agent has no model specified, config model should be used as fallback",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ var selectedProvider *opencode.Provider
|
|
|
+ var selectedModel *opencode.Model
|
|
|
+
|
|
|
+ // Simulate priority 2: Agent model check
|
|
|
+ if tt.agentModelID != "" {
|
|
|
+ selectedProvider, selectedModel = findModelByProviderAndModelID(providers, tt.agentProviderID, tt.agentModelID)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Simulate priority 3: Config model fallback
|
|
|
+ if selectedProvider == nil && tt.configModel != "" {
|
|
|
+ selectedProvider, selectedModel = findModelByFullID(providers, tt.configModel)
|
|
|
+ }
|
|
|
+
|
|
|
+ if selectedProvider == nil || selectedModel == nil {
|
|
|
+ t.Fatalf("Expected to find model, but got nil - %s", tt.description)
|
|
|
+ }
|
|
|
+
|
|
|
+ if selectedProvider.ID != tt.expectedProviderID {
|
|
|
+ t.Errorf("Expected provider %s, got %s - %s", tt.expectedProviderID, selectedProvider.ID, tt.description)
|
|
|
+ }
|
|
|
+
|
|
|
+ if selectedModel.ID != tt.expectedModelID {
|
|
|
+ t.Errorf("Expected model %s, got %s - %s", tt.expectedModelID, selectedModel.ID, tt.description)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|