Dax Raad 8 месяцев назад
Родитель
Сommit
db2bb32bcf

+ 1 - 1
packages/opencode/src/config/config.ts

@@ -50,7 +50,7 @@ export namespace Config {
 
   export const Info = z
     .object({
-      provider: z.lazy(() => Provider.Info.array().optional()),
+      provider: z.record(z.string(), z.record(z.string(), z.any())).optional(),
       tool: z
         .object({
           provider: z.record(z.string(), z.string().array()).optional(),

+ 0 - 77
packages/opencode/src/provider/database.ts

@@ -1,77 +0,0 @@
-import type { Provider } from "./provider"
-
-export const PROVIDER_DATABASE: Provider.Info[] = [
-  {
-    id: "anthropic",
-    name: "Anthropic",
-    models: [
-      {
-        id: "claude-sonnet-4-20250514",
-        name: "Claude Sonnet 4",
-        cost: {
-          input: 3.0 / 1_000_000,
-          output: 15.0 / 1_000_000,
-          inputCached: 3.75 / 1_000_000,
-          outputCached: 0.3 / 1_000_000,
-        },
-        contextWindow: 200_000,
-        maxOutputTokens: 50_000,
-        reasoning: true,
-        attachment: true,
-      },
-      {
-        id: "claude-opus-4-20250514",
-        name: "Claude Opus 4",
-        cost: {
-          input: 15.0 / 1_000_000,
-          output: 75.0 / 1_000_000,
-          inputCached: 18.75 / 1_000_000,
-          outputCached: 1.5 / 1_000_000,
-        },
-        contextWindow: 200_000,
-        maxOutputTokens: 32_000,
-        reasoning: true,
-        attachment: true,
-      },
-    ],
-  },
-  {
-    id: "openai",
-    name: "OpenAI",
-    models: [
-      {
-        id: "codex-mini-latest",
-        name: "Codex Mini",
-        cost: {
-          input: 1.5 / 1_000_000,
-          inputCached: 0.375 / 1_000_000,
-          output: 6.0 / 1_000_000,
-          outputCached: 0.0 / 1_000_000,
-        },
-        contextWindow: 200_000,
-        maxOutputTokens: 100_000,
-        attachment: true,
-        reasoning: true,
-      },
-    ],
-  },
-  {
-    id: "google",
-    name: "Google",
-    models: [
-      {
-        id: "gemini-2.5-pro-preview-03-25",
-        name: "Gemini 2.5 Pro",
-        cost: {
-          input: 1.25 / 1_000_000,
-          inputCached: 0 / 1_000_000,
-          output: 10 / 1_000_000,
-          outputCached: 0 / 1_000_000,
-        },
-        contextWindow: 1_000_000,
-        maxOutputTokens: 50_000,
-        attachment: true,
-      },
-    ],
-  },
-]

+ 29 - 0
packages/opencode/src/provider/models.ts

@@ -0,0 +1,29 @@
+import { Global } from "../global"
+import { Log } from "../util/log"
+import path from "path"
+
+export namespace ModelsDev {
+  const log = Log.create({ service: "models.dev" })
+
+  function filepath() {
+    return path.join(Global.Path.data, "models.json")
+  }
+
+  export async function get() {
+    const file = Bun.file(filepath())
+    if (await file.exists()) {
+      refresh()
+      return file.json()
+    }
+    await refresh()
+    return get()
+  }
+
+  async function refresh() {
+    log.info("refreshing")
+    const result = await fetch("https://models.dev/api.json")
+    if (!result.ok)
+      throw new Error(`Failed to fetch models.dev: ${result.statusText}`)
+    await Bun.write(filepath(), result)
+  }
+}

+ 70 - 42
packages/opencode/src/provider/provider.ts

@@ -1,7 +1,7 @@
 import z from "zod"
 import { App } from "../app/app"
 import { Config } from "../config/config"
-import { PROVIDER_DATABASE } from "./database"
+import { mapValues, sortBy } from "remeda"
 import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
 import { Log } from "../util/log"
 import path from "path"
@@ -22,6 +22,7 @@ import type { Tool } from "../tool/tool"
 import { WriteTool } from "../tool/write"
 import { TodoReadTool, TodoWriteTool } from "../tool/todo"
 import { AuthAnthropic } from "../auth/anthropic"
+import { ModelsDev } from "./models"
 
 export namespace Provider {
   const log = Log.create({ service: "provider" })
@@ -30,16 +31,18 @@ export namespace Provider {
     .object({
       id: z.string(),
       name: z.string().optional(),
+      attachment: z.boolean(),
+      reasoning: z.boolean().optional(),
       cost: z.object({
         input: z.number(),
         inputCached: z.number(),
         output: z.number(),
         outputCached: z.number(),
       }),
-      contextWindow: z.number(),
-      maxOutputTokens: z.number().optional(),
-      attachment: z.boolean(),
-      reasoning: z.boolean().optional(),
+      limit: z.object({
+        context: z.number(),
+        output: z.number(),
+      }),
     })
     .openapi({
       ref: "Provider.Model",
@@ -50,23 +53,27 @@ export namespace Provider {
     .object({
       id: z.string(),
       name: z.string(),
-      options: z.record(z.string(), z.any()).optional(),
-      models: Model.array(),
+      models: z.record(z.string(), Model),
     })
     .openapi({
       ref: "Provider.Info",
     })
   export type Info = z.output<typeof Info>
 
-  const AUTODETECT: Record<string, string[]> = {
-    anthropic: ["ANTHROPIC_API_KEY"],
-    openai: ["OPENAI_API_KEY"],
-    google: ["GOOGLE_GENERATIVE_AI_API_KEY"], // TODO: support GEMINI_API_KEY?
+  type Autodetector = (provider: Info) => Promise<Record<string, any> | false>
+
+  function env(...keys: string[]): Autodetector {
+    return async () => {
+      for (const key of keys) {
+        if (process.env[key]) return {}
+      }
+      return false
+    }
   }
 
-  const AUTODETECT2: Record<
+  const AUTODETECT: Record<
     string,
-    () => Promise<Record<string, any> | false>
+    (provider: Info) => Promise<Record<string, any> | false>
   > = {
     anthropic: async () => {
       const result = await AuthAnthropic.load()
@@ -78,44 +85,53 @@ export namespace Provider {
             "anthropic-beta": "oauth-2025-04-20",
           },
         }
-      if (process.env["ANTHROPIC_API_KEY"]) return {}
-      return false
+      return env("ANTHROPIC_API_KEY")
     },
+    google: env("GOOGLE_GENERATIVE_AI_API_KEY"),
+    openai: env("OPENAI_API_KEY"),
   }
 
   const state = App.state("provider", async () => {
     log.info("loading config")
     const config = await Config.get()
     log.info("loading providers")
-    const providers = new Map<string, Info>()
+    const database: Record<string, Provider.Info> = await ModelsDev.get()
+
+    const providers: {
+      [providerID: string]: {
+        info: Provider.Info
+        options: Record<string, any>
+      }
+    } = {}
     const models = new Map<string, { info: Model; language: LanguageModel }>()
     const sdk = new Map<string, SDK>()
 
     log.info("loading")
 
-    for (const [providerID, fn] of Object.entries(AUTODETECT2)) {
-      const provider = PROVIDER_DATABASE.find((x) => x.id === providerID)
+    for (const [providerID, fn] of Object.entries(AUTODETECT)) {
+      const provider = database[providerID]
       if (!provider) continue
-      const result = await fn()
-      if (!result) continue
-      providers.set(providerID, {
-        ...provider,
-        options: {
-          ...provider.options,
-          ...result,
-        },
-      })
-    }
-
-    for (const item of PROVIDER_DATABASE) {
-      if (!AUTODETECT[item.id].some((env) => process.env[env])) continue
-      log.info("found", { providerID: item.id })
-      providers.set(item.id, item)
+      const options = await fn(provider)
+      if (!options) continue
+      providers[providerID] = {
+        info: provider,
+        options,
+      }
     }
 
-    for (const item of config.provider ?? []) {
-      log.info("found", { providerID: item.id })
-      providers.set(item.id, item)
+    for (const [providerID, options] of Object.entries(config.provider ?? {})) {
+      const existing = providers[providerID]
+      if (existing) {
+        existing.options = {
+          ...existing.options,
+          ...options,
+        }
+        continue
+      }
+      providers[providerID] = {
+        info: database[providerID],
+        options,
+      }
     }
 
     return {
@@ -126,7 +142,9 @@ export namespace Provider {
   })
 
   export async function active() {
-    return state().then((state) => state.providers)
+    return state().then((state) =>
+      mapValues(state.providers, (item) => item.info),
+    )
   }
 
   async function getSDK(providerID: string) {
@@ -149,7 +167,7 @@ export namespace Provider {
     }
     const mod = await import(path.join(dir))
     const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
-    const loaded = fn(s.providers.get(providerID)?.options)
+    const loaded = fn(s.providers[providerID]?.options)
     s.sdk.set(providerID, loaded)
     return loaded as SDK
   }
@@ -164,9 +182,9 @@ export namespace Provider {
       modelID,
     })
 
-    const provider = s.providers.get(providerID)
+    const provider = s.providers[providerID]
     if (!provider) throw new ModelNotFoundError(modelID)
-    const info = provider.models.find((m) => m.id === modelID)
+    const info = provider.info.models[modelID]
     if (!info) throw new ModelNotFoundError(modelID)
 
     const sdk = await getSDK(providerID)
@@ -189,10 +207,20 @@ export namespace Provider {
     }
   }
 
+  const priority = ["claude-sonnet-4", "gemini-2.5-pro-preview", "codex-mini"]
+  export function sort(models: Model[]) {
+    return sortBy(
+      models,
+      [(model) => priority.indexOf(model.id), "desc"],
+      [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
+      [(model) => model.id, "desc"],
+    )
+  }
+
   export async function defaultModel() {
-    const [provider] = await active().then((val) => val.values().toArray())
+    const [provider] = await active().then((val) => Object.values(val))
     if (!provider) throw new Error("no providers found")
-    const model = provider.models[0]
+    const [model] = sort(Object.values(provider.models))
     if (!model) throw new Error("no models found")
     return {
       providerID: provider.id,

+ 14 - 2
packages/opencode/src/server/server.ts

@@ -10,6 +10,7 @@ import { Message } from "../session/message"
 import { Provider } from "../provider/provider"
 import { App } from "../app/app"
 import { Global } from "../global"
+import { mapValues } from "remeda"
 
 export namespace Server {
   const log = Log.create({ service: "server" })
@@ -379,7 +380,12 @@ export namespace Server {
               description: "List of providers",
               content: {
                 "application/json": {
-                  schema: resolver(Provider.Info.array()),
+                  schema: resolver(
+                    z.object({
+                      providers: Provider.Info.array(),
+                      default: z.record(z.string(), z.string()),
+                    }),
+                  ),
                 },
               },
             },
@@ -387,7 +393,13 @@ export namespace Server {
         }),
         async (c) => {
           const providers = await Provider.active()
-          return c.json(providers.values().toArray())
+          return c.json({
+            providers: Object.values(providers),
+            defaults: mapValues(
+              providers,
+              (item) => Provider.sort(Object.values(item.models))[0].id,
+            ),
+          })
         },
       )
 

+ 5 - 3
packages/tui/cmd/opencode/main.go

@@ -23,7 +23,10 @@ func main() {
 		slog.Error("Failed to create client", "error", err)
 		os.Exit(1)
 	}
-	paths, _ := httpClient.PostPathGetWithResponse(context.Background())
+	paths, err := httpClient.PostPathGetWithResponse(context.Background())
+	if err != nil {
+		panic(err)
+	}
 	logfile := filepath.Join(paths.JSON200.Data, "log", "tui.log")
 
 	if _, err := os.Stat(filepath.Dir(logfile)); os.IsNotExist(err) {
@@ -48,8 +51,7 @@ func main() {
 
 	app_, err := app.New(ctx, httpClient)
 	if err != nil {
-		slog.Error("Failed to create app", "error", err)
-		// return err
+		panic(err)
 	}
 
 	// Set up the TUI

+ 14 - 14
packages/tui/internal/app/app.go

@@ -43,18 +43,24 @@ func New(ctx context.Context, httpClient *client.ClientWithResponses) (*App, err
 
 	appInfoResponse, _ := httpClient.PostAppInfoWithResponse(ctx)
 	appInfo := appInfoResponse.JSON200
-	providersResponse, _ := httpClient.PostProviderListWithResponse(ctx)
+	providersResponse, err := httpClient.PostProviderListWithResponse(ctx)
+	if err != nil {
+		return nil, err
+	}
 	providers := []client.ProviderInfo{}
 	var defaultProvider *client.ProviderInfo
 	var defaultModel *client.ProviderModel
 
-	for _, provider := range *providersResponse.JSON200 {
-		if provider.Id == "anthropic" {
-			defaultProvider = &provider
-
-			for _, model := range provider.Models {
-				if model.Id == "claude-sonnet-4-20250514" {
+	for i, provider := range providersResponse.JSON200.Providers {
+		if i == 0 || provider.Id == "anthropic" {
+			defaultProvider = &providersResponse.JSON200.Providers[i]
+			if match, ok := providersResponse.JSON200.Default[provider.Id]; ok {
+				model := defaultProvider.Models[match]
+				defaultModel = &model
+			} else {
+				for _, model := range provider.Models {
 					defaultModel = &model
+					break
 				}
 			}
 		}
@@ -63,12 +69,6 @@ func New(ctx context.Context, httpClient *client.ClientWithResponses) (*App, err
 	if len(providers) == 0 {
 		return nil, fmt.Errorf("no providers found")
 	}
-	if defaultProvider == nil {
-		defaultProvider = &providers[0]
-	}
-	if defaultModel == nil {
-		defaultModel = &defaultProvider.Models[0]
-	}
 
 	appConfigPath := filepath.Join(appInfo.Path.Config, "tui.toml")
 	appConfig, err := config.LoadConfig(appConfigPath)
@@ -296,7 +296,7 @@ func (a *App) ListProviders(ctx context.Context) ([]client.ProviderInfo, error)
 	}
 
 	providers := *resp.JSON200
-	return providers, nil
+	return providers.Providers, nil
 }
 
 // IsFilepickerOpen returns whether the filepicker is currently open

+ 2 - 2
packages/tui/internal/components/core/status.go

@@ -7,9 +7,9 @@ import (
 
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/charmbracelet/lipgloss"
+	"github.com/sst/opencode/internal/app"
 	"github.com/sst/opencode/internal/pubsub"
 	"github.com/sst/opencode/internal/status"
-	"github.com/sst/opencode/internal/app"
 	"github.com/sst/opencode/internal/styles"
 	"github.com/sst/opencode/internal/theme"
 )
@@ -145,7 +145,7 @@ func (m statusCmp) View() string {
 	if m.app.Session.Id != "" {
 		tokens := float32(0)
 		cost := float32(0)
-		contextWindow := m.app.Model.ContextWindow
+		contextWindow := m.app.Model.Limit.Context
 
 		for _, message := range m.app.Messages {
 			if message.Metadata.Assistant != nil {

+ 14 - 3
packages/tui/internal/components/dialog/models.go

@@ -3,6 +3,9 @@ package dialog
 import (
 	"context"
 	"fmt"
+	"maps"
+	"slices"
+	"strings"
 
 	"github.com/charmbracelet/bubbles/key"
 	tea "github.com/charmbracelet/bubbletea"
@@ -38,7 +41,6 @@ type modelDialogCmp struct {
 	app                *app.App
 	availableProviders []client.ProviderInfo
 	provider           client.ProviderInfo
-	model              *client.ProviderModel
 
 	selectedIdx     int
 	width           int
@@ -144,7 +146,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 				m.switchProvider(1)
 			}
 		case key.Matches(msg, modelKeys.Enter):
-			return m, util.CmdHandler(CloseModelDialogMsg{Provider: &m.provider, Model: &m.provider.Models[m.selectedIdx]})
+			models := m.models()
+			return m, util.CmdHandler(CloseModelDialogMsg{Provider: &m.provider, Model: &models[m.selectedIdx]})
 		case key.Matches(msg, modelKeys.Escape):
 			return m, util.CmdHandler(CloseModelDialogMsg{})
 		}
@@ -156,6 +159,13 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	return m, nil
 }
 
+func (m *modelDialogCmp) models() []client.ProviderModel {
+	models := slices.SortedFunc(maps.Values(m.provider.Models), func(a, b client.ProviderModel) int {
+		return strings.Compare(*a.Name, *b.Name)
+	})
+	return models
+}
+
 // moveSelectionUp moves the selection up or wraps to bottom
 func (m *modelDialogCmp) moveSelectionUp() {
 	if m.selectedIdx > 0 {
@@ -218,13 +228,14 @@ func (m *modelDialogCmp) View() string {
 	endIdx := min(m.scrollOffset+numVisibleModels, len(m.provider.Models))
 	modelItems := make([]string, 0, endIdx-m.scrollOffset)
 
+	models := m.models()
 	for i := m.scrollOffset; i < endIdx; i++ {
 		itemStyle := baseStyle.Width(maxDialogWidth)
 		if i == m.selectedIdx {
 			itemStyle = itemStyle.Background(t.Primary()).
 				Foreground(t.Background()).Bold(true)
 		}
-		modelItems = append(modelItems, itemStyle.Render(*m.provider.Models[i].Name))
+		modelItems = append(modelItems, itemStyle.Render(*models[i].Name))
 	}
 
 	scrollIndicator := m.getScrollIndicators(maxDialogWidth)

+ 43 - 23
packages/tui/pkg/client/gen/openapi.json

@@ -401,10 +401,25 @@
             "content": {
               "application/json": {
                 "schema": {
-                  "type": "array",
-                  "items": {
-                    "$ref": "#/components/schemas/Provider.Info"
-                  }
+                  "type": "object",
+                  "properties": {
+                    "providers": {
+                      "type": "array",
+                      "items": {
+                        "$ref": "#/components/schemas/Provider.Info"
+                      }
+                    },
+                    "default": {
+                      "type": "object",
+                      "additionalProperties": {
+                        "type": "string"
+                      }
+                    }
+                  },
+                  "required": [
+                    "providers",
+                    "default"
+                  ]
                 }
               }
             }
@@ -1080,13 +1095,9 @@
           "name": {
             "type": "string"
           },
-          "options": {
-            "type": "object",
-            "additionalProperties": {}
-          },
           "models": {
-            "type": "array",
-            "items": {
+            "type": "object",
+            "additionalProperties": {
               "$ref": "#/components/schemas/Provider.Model"
             }
           }
@@ -1106,6 +1117,12 @@
           "name": {
             "type": "string"
           },
+          "attachment": {
+            "type": "boolean"
+          },
+          "reasoning": {
+            "type": "boolean"
+          },
           "cost": {
             "type": "object",
             "properties": {
@@ -1129,24 +1146,27 @@
               "outputCached"
             ]
           },
-          "contextWindow": {
-            "type": "number"
-          },
-          "maxOutputTokens": {
-            "type": "number"
-          },
-          "attachment": {
-            "type": "boolean"
-          },
-          "reasoning": {
-            "type": "boolean"
+          "limit": {
+            "type": "object",
+            "properties": {
+              "context": {
+                "type": "number"
+              },
+              "output": {
+                "type": "number"
+              }
+            },
+            "required": [
+              "context",
+              "output"
+            ]
           }
         },
         "required": [
           "id",
+          "attachment",
           "cost",
-          "contextWindow",
-          "attachment"
+          "limit"
         ]
       }
     }

+ 20 - 13
packages/tui/pkg/client/generated-client.go

@@ -203,26 +203,27 @@ type MessageToolInvocationToolResult struct {
 
 // ProviderInfo defines model for Provider.Info.
 type ProviderInfo struct {
-	Id      string                  `json:"id"`
-	Models  []ProviderModel         `json:"models"`
-	Name    string                  `json:"name"`
-	Options *map[string]interface{} `json:"options,omitempty"`
+	Id     string                   `json:"id"`
+	Models map[string]ProviderModel `json:"models"`
+	Name   string                   `json:"name"`
 }
 
 // ProviderModel defines model for Provider.Model.
 type ProviderModel struct {
-	Attachment    bool    `json:"attachment"`
-	ContextWindow float32 `json:"contextWindow"`
-	Cost          struct {
+	Attachment bool `json:"attachment"`
+	Cost       struct {
 		Input        float32 `json:"input"`
 		InputCached  float32 `json:"inputCached"`
 		Output       float32 `json:"output"`
 		OutputCached float32 `json:"outputCached"`
 	} `json:"cost"`
-	Id              string   `json:"id"`
-	MaxOutputTokens *float32 `json:"maxOutputTokens,omitempty"`
-	Name            *string  `json:"name,omitempty"`
-	Reasoning       *bool    `json:"reasoning,omitempty"`
+	Id    string `json:"id"`
+	Limit struct {
+		Context float32 `json:"context"`
+		Output  float32 `json:"output"`
+	} `json:"limit"`
+	Name      *string `json:"name,omitempty"`
+	Reasoning *bool   `json:"reasoning,omitempty"`
 }
 
 // PermissionInfo defines model for permission.info.
@@ -1815,7 +1816,10 @@ func (r PostPathGetResponse) StatusCode() int {
 type PostProviderListResponse struct {
 	Body         []byte
 	HTTPResponse *http.Response
-	JSON200      *[]ProviderInfo
+	JSON200      *struct {
+		Default   map[string]string `json:"default"`
+		Providers []ProviderInfo    `json:"providers"`
+	}
 }
 
 // Status returns HTTPResponse.Status
@@ -2299,7 +2303,10 @@ func ParsePostProviderListResponse(rsp *http.Response) (*PostProviderListRespons
 
 	switch {
 	case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
-		var dest []ProviderInfo
+		var dest struct {
+			Default   map[string]string `json:"default"`
+			Providers []ProviderInfo    `json:"providers"`
+		}
 		if err := json.Unmarshal(bodyBytes, &dest); err != nil {
 			return nil, err
 		}