Procházet zdrojové kódy

Choose model according to the docs (#1536)

Dominik Engelhardt před 6 měsíci
rodič
revize
42a5fcead4

+ 109 - 52
packages/tui/internal/app/app.go

@@ -270,37 +270,58 @@ func (a *App) SwitchModeReverse() (*App, tea.Cmd) {
 	return a.cycleMode(false)
 }
 
-func (a *App) InitializeProvider() tea.Cmd {
-	providersResponse, err := a.Client.App.Providers(context.Background())
-	if err != nil {
-		slog.Error("Failed to list providers", "error", err)
-		// TODO: notify user
-		return nil
+// findModelByFullID finds a model by its full ID in the format "provider/model"
+func findModelByFullID(providers []opencode.Provider, fullModelID string) (*opencode.Provider, *opencode.Model) {
+	modelParts := strings.SplitN(fullModelID, "/", 2)
+	if len(modelParts) < 2 {
+		return nil, nil
 	}
-	providers := providersResponse.Providers
-	var defaultProvider *opencode.Provider
-	var defaultModel *opencode.Model
 
-	var anthropic *opencode.Provider
+	providerID := modelParts[0]
+	modelID := modelParts[1]
+
+	return findModelByProviderAndModelID(providers, providerID, modelID)
+}
+
+// findModelByProviderAndModelID finds a model by provider ID and model ID
+func findModelByProviderAndModelID(providers []opencode.Provider, providerID, modelID string) (*opencode.Provider, *opencode.Model) {
 	for _, provider := range providers {
-		if provider.ID == "anthropic" {
-			anthropic = &provider
+		if provider.ID != providerID {
+			continue
 		}
-	}
 
-	// default to anthropic if available
-	if anthropic != nil {
-		defaultProvider = anthropic
-		defaultModel = getDefaultModel(providersResponse, *anthropic)
+		for _, model := range provider.Models {
+			if model.ID == modelID {
+				return &provider, &model
+			}
+		}
+
+		// Provider found but model not found
+		return nil, nil
 	}
 
+	// Provider not found
+	return nil, nil
+}
+
+// findProviderByID finds a provider by its ID
+func findProviderByID(providers []opencode.Provider, providerID string) *opencode.Provider {
 	for _, provider := range providers {
-		if defaultProvider == nil || defaultModel == nil {
-			defaultProvider = &provider
-			defaultModel = getDefaultModel(providersResponse, provider)
+		if provider.ID == providerID {
+			return &provider
 		}
-		providers = append(providers, provider)
 	}
+	return nil
+}
+
+func (a *App) InitializeProvider() tea.Cmd {
+	providersResponse, err := a.Client.App.Providers(context.Background())
+	if err != nil {
+		slog.Error("Failed to list providers", "error", err)
+		// TODO: notify user
+		return nil
+	}
+	providers := providersResponse.Providers
 	if len(providers) == 0 {
 		slog.Error("No providers configured")
 		return nil
@@ -314,50 +335,86 @@ func (a *App) InitializeProvider() tea.Cmd {
 		a.State.Model = model.ModelID
 	}
 
-	var currentProvider *opencode.Provider
-	var currentModel *opencode.Model
-	for _, provider := range providers {
-		if provider.ID == a.State.Provider {
-			currentProvider = &provider
+	var selectedProvider *opencode.Provider
+	var selectedModel *opencode.Model
 
-			for _, model := range provider.Models {
-				if model.ID == a.State.Model {
-					currentModel = &model
-				}
-			}
+	// Priority 1: Command line --model flag (InitialModel)
+	if a.InitialModel != nil && *a.InitialModel != "" {
+		if provider, model := findModelByFullID(providers, *a.InitialModel); provider != nil && model != nil {
+			selectedProvider = provider
+			selectedModel = model
+			slog.Debug("Selected model from command line", "provider", provider.ID, "model", model.ID)
+		} else {
+			slog.Debug("Command line model not found", "model", *a.InitialModel)
 		}
 	}
-	if currentProvider == nil || currentModel == nil {
-		currentProvider = defaultProvider
-		currentModel = defaultModel
+
+	// Priority 2: Config file model setting
+	if selectedProvider == nil && a.Config.Model != "" {
+		if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil && model != nil {
+			selectedProvider = provider
+			selectedModel = model
+			slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
+		} else {
+			slog.Debug("Config model not found", "model", a.Config.Model)
+		}
 	}
 
-	var initialProvider *opencode.Provider
-	var initialModel *opencode.Model
-	if a.InitialModel != nil && *a.InitialModel != "" {
-		splits := strings.Split(*a.InitialModel, "/")
-		for _, provider := range providers {
-			if provider.ID == splits[0] {
-				initialProvider = &provider
-				for _, model := range provider.Models {
-					modelID := strings.Join(splits[1:], "/")
-					if model.ID == modelID {
-						initialModel = &model
-					}
-				}
+	// Priority 3: Recent model usage (most recently used model)
+	if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 {
+		recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first
+		if provider, model := findModelByProviderAndModelID(providers, recentUsage.ProviderID, recentUsage.ModelID); provider != nil && model != nil {
+			selectedProvider = provider
+			selectedModel = model
+			slog.Debug("Selected model from recent usage", "provider", provider.ID, "model", model.ID)
+		} else {
+			slog.Debug("Recent model not found", "provider", recentUsage.ProviderID, "model", recentUsage.ModelID)
+		}
+	}
+
+	// Priority 4: State-based model (backwards compatibility)
+	if selectedProvider == nil && a.State.Provider != "" && a.State.Model != "" {
+		if provider, model := findModelByProviderAndModelID(providers, a.State.Provider, a.State.Model); provider != nil && model != nil {
+			selectedProvider = provider
+			selectedModel = model
+			slog.Debug("Selected model from state", "provider", provider.ID, "model", model.ID)
+		} else {
+			slog.Debug("State model not found", "provider", a.State.Provider, "model", a.State.Model)
+		}
+	}
+
+	// Priority 5: Internal priority fallback (Anthropic preferred, then first available)
+	if selectedProvider == nil {
+		// Try Anthropic first as internal priority
+		if provider := findProviderByID(providers, "anthropic"); provider != nil {
+			if model := getDefaultModel(providersResponse, *provider); model != nil {
+				selectedProvider = provider
+				selectedModel = model
+				slog.Debug("Selected model from internal priority (Anthropic)", "provider", provider.ID, "model", model.ID)
+			}
+		}
+
+		// If Anthropic not available, use first available provider
+		if selectedProvider == nil && len(providers) > 0 {
+			provider := &providers[0]
+			if model := getDefaultModel(providersResponse, *provider); model != nil {
+				selectedProvider = provider
+				selectedModel = model
+				slog.Debug("Selected model from fallback (first available)", "provider", provider.ID, "model", model.ID)
 			}
 		}
 	}
 
-	if initialProvider != nil && initialModel != nil {
-		currentProvider = initialProvider
-		currentModel = initialModel
+	// Final safety check
+	if selectedProvider == nil || selectedModel == nil {
+		slog.Error("Failed to select any model")
+		return nil
 	}
 
 	var cmds []tea.Cmd
 	cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
-		Provider: *currentProvider,
-		Model:    *currentModel,
+		Provider: *selectedProvider,
+		Model:    *selectedModel,
 	}))
 	if a.InitialPrompt != nil && *a.InitialPrompt != "" {
 		cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt}))

+ 228 - 0
packages/tui/internal/app/app_test.go

@@ -0,0 +1,228 @@
+package app
+
+import (
+	"testing"
+
+	"github.com/sst/opencode-sdk-go"
+)
+
+// TestFindModelByFullID tests the findModelByFullID function
+func TestFindModelByFullID(t *testing.T) {
+	// Create test providers with models
+	providers := []opencode.Provider{
+		{
+			ID: "anthropic",
+			Models: map[string]opencode.Model{
+				"claude-3-opus-20240229":   {ID: "claude-3-opus-20240229"},
+				"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
+			},
+		},
+		{
+			ID: "openai",
+			Models: map[string]opencode.Model{
+				"gpt-4":         {ID: "gpt-4"},
+				"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
+			},
+		},
+	}
+
+	tests := []struct {
+		name               string
+		fullModelID        string
+		expectedFound      bool
+		expectedProviderID string
+		expectedModelID    string
+	}{
+		{
+			name:               "valid full model ID",
+			fullModelID:        "anthropic/claude-3-opus-20240229",
+			expectedFound:      true,
+			expectedProviderID: "anthropic",
+			expectedModelID:    "claude-3-opus-20240229",
+		},
+		{
+			name:               "valid full model ID with slash in model name",
+			fullModelID:        "openai/gpt-3.5-turbo",
+			expectedFound:      true,
+			expectedProviderID: "openai",
+			expectedModelID:    "gpt-3.5-turbo",
+		},
+		{
+			name:          "invalid format - missing slash",
+			fullModelID:   "anthropic",
+			expectedFound: false,
+		},
+		{
+			name:          "invalid format - empty string",
+			fullModelID:   "",
+			expectedFound: false,
+		},
+		{
+			name:          "provider not found",
+			fullModelID:   "nonexistent/model",
+			expectedFound: false,
+		},
+		{
+			name:          "model not found",
+			fullModelID:   "anthropic/nonexistent-model",
+			expectedFound: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			provider, model := findModelByFullID(providers, tt.fullModelID)
+
+			if tt.expectedFound {
+				if provider == nil || model == nil {
+					t.Errorf("Expected to find provider/model, but got nil")
+					return
+				}
+
+				if provider.ID != tt.expectedProviderID {
+					t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+				}
+
+				if model.ID != tt.expectedModelID {
+					t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
+				}
+			} else {
+				if provider != nil || model != nil {
+					t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
+				}
+			}
+		})
+	}
+}
+
+// TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function
+func TestFindModelByProviderAndModelID(t *testing.T) {
+	// Create test providers with models
+	providers := []opencode.Provider{
+		{
+			ID: "anthropic",
+			Models: map[string]opencode.Model{
+				"claude-3-opus-20240229":   {ID: "claude-3-opus-20240229"},
+				"claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
+			},
+		},
+		{
+			ID: "openai",
+			Models: map[string]opencode.Model{
+				"gpt-4":         {ID: "gpt-4"},
+				"gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
+			},
+		},
+	}
+
+	tests := []struct {
+		name               string
+		providerID         string
+		modelID            string
+		expectedFound      bool
+		expectedProviderID string
+		expectedModelID    string
+	}{
+		{
+			name:               "valid provider and model",
+			providerID:         "anthropic",
+			modelID:            "claude-3-opus-20240229",
+			expectedFound:      true,
+			expectedProviderID: "anthropic",
+			expectedModelID:    "claude-3-opus-20240229",
+		},
+		{
+			name:          "provider not found",
+			providerID:    "nonexistent",
+			modelID:       "claude-3-opus-20240229",
+			expectedFound: false,
+		},
+		{
+			name:          "model not found",
+			providerID:    "anthropic",
+			modelID:       "nonexistent-model",
+			expectedFound: false,
+		},
+		{
+			name:          "both provider and model not found",
+			providerID:    "nonexistent",
+			modelID:       "nonexistent-model",
+			expectedFound: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID)
+
+			if tt.expectedFound {
+				if provider == nil || model == nil {
+					t.Errorf("Expected to find provider/model, but got nil")
+					return
+				}
+
+				if provider.ID != tt.expectedProviderID {
+					t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+				}
+
+				if model.ID != tt.expectedModelID {
+					t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
+				}
+			} else {
+				if provider != nil || model != nil {
+					t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
+				}
+			}
+		})
+	}
+}
+
+// TestFindProviderByID tests the findProviderByID function
+func TestFindProviderByID(t *testing.T) {
+	// Create test providers
+	providers := []opencode.Provider{
+		{ID: "anthropic"},
+		{ID: "openai"},
+		{ID: "google"},
+	}
+
+	tests := []struct {
+		name               string
+		providerID         string
+		expectedFound      bool
+		expectedProviderID string
+	}{
+		{
+			name:               "provider found",
+			providerID:         "anthropic",
+			expectedFound:      true,
+			expectedProviderID: "anthropic",
+		},
+		{
+			name:          "provider not found",
+			providerID:    "nonexistent",
+			expectedFound: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			provider := findProviderByID(providers, tt.providerID)
+
+			if tt.expectedFound {
+				if provider == nil {
+					t.Errorf("Expected to find provider, but got nil")
+					return
+				}
+
+				if provider.ID != tt.expectedProviderID {
+					t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+				}
+			} else {
+				if provider != nil {
+					t.Errorf("Expected not to find provider, but got %v", provider)
+				}
+			}
+		})
+	}
+}

+ 6 - 4
packages/web/src/content/docs/docs/models.mdx

@@ -66,9 +66,11 @@ If you've configured a [custom provider](/docs/providers#custom), the `provider_
 
 ## Loading models
 
-When opencode starts up, it checks for the following:
+When opencode starts up, it checks for models in the following priority order:
 
-1. The model list in the opencode config.
+1. The `--model` or `-m` command line flag. The format is the same as in the config file: `provider_id/model_id`.
+
+2. The model list in the opencode config.
 
    ```json title="opencode.json"
    {
@@ -79,6 +81,6 @@ When opencode starts up, it checks for the following:
 
    The format here is `provider/model`.
 
-2. The last used model.
+3. The last used model.
 
-3. The first model using an internal priority.
+4. The first model using an internal priority.