Parcourir la source

wip: desktop work

Adam il y a 3 mois
Parent
commit
48f50cf55e

+ 12 - 3
packages/desktop/src/components/prompt-input.tsx

@@ -71,7 +71,7 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
     }
   })
 
-  const { flat, active, onInput, onKeyDown } = useFilteredList<string>({
+  const { flat, active, onInput, onKeyDown, refetch } = useFilteredList<string>({
     items: local.file.search,
     key: (x) => x,
     onSelect: (path) => {
@@ -81,6 +81,11 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
     },
   })
 
+  createEffect(() => {
+    local.model.recent()
+    refetch()
+  })
+
   createEffect(
     on(
       () => store.contentParts,
@@ -369,16 +374,20 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
               items={local.model.list()}
               current={local.model.current()}
               filterKeys={["provider.name", "name", "id"]}
-              groupBy={(x) => x.provider.name}
+              groupBy={(x) => (local.model.recent().includes(x) ? "Recent" : x.provider.name)}
               sortGroupsBy={(a, b) => {
                 const order = ["opencode", "anthropic", "github-copilot", "openai", "google", "openrouter", "vercel"]
+                if (a.category === "Recent" && b.category !== "Recent") return -1
+                if (b.category === "Recent" && a.category !== "Recent") return 1
                 const aProvider = a.items[0].provider.id
                 const bProvider = b.items[0].provider.id
                 if (order.includes(aProvider) && !order.includes(bProvider)) return -1
                 if (!order.includes(aProvider) && order.includes(bProvider)) return 1
                 return order.indexOf(aProvider) - order.indexOf(bProvider)
               }}
-              onSelect={(x) => local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined)}
+              onSelect={(x) =>
+                local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { recent: true })
+              }
               trigger={
                 <Button as="div" variant="ghost">
                   {local.model.current()?.name ?? "Select model"}

+ 66 - 13
packages/desktop/src/context/local.tsx

@@ -45,6 +45,37 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
     const sdk = useSDK()
     const sync = useSync()
 
+    function isModelValid(model: ModelKey) {
+      const provider = sync.data.provider.find((x) => x.id === model.providerID)
+      return !!provider?.models[model.modelID]
+    }
+
+    function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) {
+      for (const modelFn of modelFns) {
+        const model = modelFn()
+        if (!model) continue
+        if (isModelValid(model)) return model
+      }
+    }
+
+    // Automatically update model when agent changes
+    createEffect(() => {
+      const value = agent.current()
+      if (value.model) {
+        if (isModelValid(value.model))
+          model.set({
+            providerID: value.model.providerID,
+            modelID: value.model.modelID,
+          })
+        // else
+        //   toast.show({
+        //     type: "warning",
+        //     message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
+        //     duration: 3000,
+        //   })
+      }
+    })
+
     const agent = (() => {
       const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent"))
       const [store, setStore] = createStore<{
@@ -76,11 +107,6 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
     })()
 
     const model = (() => {
-      const list = createMemo(() =>
-        sync.data.provider.flatMap((p) => Object.values(p.models).map((m) => ({ ...m, provider: p }) as LocalModel)),
-      )
-      const find = (key: ModelKey) => list().find((m) => m.id === key?.modelID && m.provider.id === key.providerID)
-
       const [store, setStore] = createStore<{
         model: Record<string, ModelKey>
         recent: ModelKey[]
@@ -95,27 +121,54 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
         localStorage.setItem("model", JSON.stringify(store.recent))
       })
 
-      const fallback = createMemo(() => {
-        if (store.recent.length) return store.recent[0]
+      const list = createMemo(() =>
+        sync.data.provider.flatMap((p) => Object.values(p.models).map((m) => ({ ...m, provider: p }) as LocalModel)),
+      )
+      const find = (key: ModelKey) => list().find((m) => m.id === key?.modelID && m.provider.id === key.providerID)
+
+      const fallbackModel = createMemo(() => {
+        if (sync.data.config.model) {
+          const [providerID, modelID] = sync.data.config.model.split("/")
+          if (isModelValid({ providerID, modelID })) {
+            return {
+              providerID,
+              modelID,
+            }
+          }
+        }
+
+        for (const item of store.recent) {
+          if (isModelValid(item)) {
+            return item
+          }
+        }
         const provider = sync.data.provider[0]
         const model = Object.values(provider.models)[0]
-        return { modelID: model.id, providerID: provider.id }
+        return {
+          providerID: provider.id,
+          modelID: model.id,
+        }
       })
 
-      const current = createMemo(() => {
+      const currentModel = createMemo(() => {
         const a = agent.current()
-        return find(store.model[agent.current().name]) ?? find(a.model ?? fallback())
+        const key = getFirstValidModel(
+          () => store.model[a.name],
+          () => a.model,
+          fallbackModel,
+        )!
+        return find(key)
       })
 
       const recent = createMemo(() => store.recent.map(find).filter(Boolean))
 
       return {
-        list,
-        current,
+        current: currentModel,
         recent,
+        list,
         set(model: ModelKey | undefined, options?: { recent?: boolean }) {
           batch(() => {
-            setStore("model", agent.current().name, model ?? fallback())
+            setStore("model", agent.current().name, model ?? fallbackModel())
             if (options?.recent && model) {
               const uniq = uniqueBy([model, ...store.recent], (x) => x.providerID + x.modelID)
               if (uniq.length > 5) uniq.pop()

+ 0 - 3
packages/ui/src/components/tabs.css

@@ -57,9 +57,6 @@
     border-bottom: 1px solid var(--border-weak-base);
     border-right: 1px solid var(--border-weak-base);
     background-color: var(--background-base);
-    transition:
-      background-color 0.15s ease,
-      color 0.15s ease;
 
     &:disabled {
       pointer-events: none;

+ 9 - 4
packages/ui/src/hooks/use-filtered-list.tsx

@@ -11,18 +11,22 @@ export interface FilteredListProps<T> {
   current?: T
   groupBy?: (x: T) => string
   sortBy?: (a: T, b: T) => number
-  sortGroupsBy?: (a: { category: string; items: T[] }, b: { category: string; items: T[] }) => number
+  sortGroupsBy?: (
+    a: { category: string; items: T[] },
+    b: { category: string; items: T[] },
+  ) => number
   onSelect?: (value: T | undefined) => void
 }
 
 export function useFilteredList<T>(props: FilteredListProps<T>) {
   const [store, setStore] = createStore<{ filter: string }>({ filter: "" })
 
-  const [grouped] = createResource(
+  const [grouped, { refetch }] = createResource(
     () => store.filter,
     async (filter) => {
       const needle = filter?.toLowerCase()
-      const all = (typeof props.items === "function" ? await props.items(needle) : props.items) || []
+      const all =
+        (typeof props.items === "function" ? await props.items(needle) : props.items) || []
       const result = pipe(
         all,
         (x) => {
@@ -76,10 +80,11 @@ export function useFilteredList<T>(props: FilteredListProps<T>) {
   }
 
   return {
-    filter: () => store.filter,
     grouped,
+    filter: () => store.filter,
     flat,
     reset,
+    refetch,
     clear: () => setStore("filter", ""),
     onKeyDown,
     onInput,