local.tsx 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. import { createStore } from "solid-js/store"
  2. import { batch, createEffect, createMemo } from "solid-js"
  3. import { useSync } from "@tui/context/sync"
  4. import { useTheme } from "@tui/context/theme"
  5. import { uniqueBy } from "remeda"
  6. import path from "path"
  7. import { Global } from "@/global"
  8. import { iife } from "@/util/iife"
  9. import { createSimpleContext } from "./helper"
  10. import { useToast } from "../ui/toast"
  11. import { Provider } from "@/provider/provider"
  12. import { useArgs } from "./args"
  13. import { useSDK } from "./sdk"
  14. import { RGBA } from "@opentui/core"
  15. export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
  16. name: "Local",
  17. init: () => {
  18. const sync = useSync()
  19. const sdk = useSDK()
  20. const toast = useToast()
  21. function isModelValid(model: { providerID: string; modelID: string }) {
  22. const provider = sync.data.provider.find((x) => x.id === model.providerID)
  23. return !!provider?.models[model.modelID]
  24. }
  25. function getFirstValidModel(...modelFns: (() => { providerID: string; modelID: string } | undefined)[]) {
  26. for (const modelFn of modelFns) {
  27. const model = modelFn()
  28. if (!model) continue
  29. if (isModelValid(model)) return model
  30. }
  31. }
  32. // Automatically update model when agent changes
  33. createEffect(() => {
  34. const value = agent.current()
  35. if (value.model) {
  36. if (isModelValid(value.model))
  37. model.set({
  38. providerID: value.model.providerID,
  39. modelID: value.model.modelID,
  40. })
  41. else
  42. toast.show({
  43. variant: "warning",
  44. message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
  45. duration: 3000,
  46. })
  47. }
  48. })
  49. const agent = iife(() => {
  50. const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent"))
  51. const [agentStore, setAgentStore] = createStore<{
  52. current: string
  53. }>({
  54. current: agents()[0].name,
  55. })
  56. const { theme } = useTheme()
  57. const colors = createMemo(() => [
  58. theme.secondary,
  59. theme.accent,
  60. theme.success,
  61. theme.warning,
  62. theme.primary,
  63. theme.error,
  64. ])
  65. return {
  66. list() {
  67. return agents()
  68. },
  69. current() {
  70. return agents().find((x) => x.name === agentStore.current)!
  71. },
  72. set(name: string) {
  73. if (!agents().some((x) => x.name === name))
  74. return toast.show({
  75. variant: "warning",
  76. message: `Agent not found: ${name}`,
  77. duration: 3000,
  78. })
  79. setAgentStore("current", name)
  80. },
  81. move(direction: 1 | -1) {
  82. batch(() => {
  83. let next = agents().findIndex((x) => x.name === agentStore.current) + direction
  84. if (next < 0) next = agents().length - 1
  85. if (next >= agents().length) next = 0
  86. const value = agents()[next]
  87. setAgentStore("current", value.name)
  88. })
  89. },
  90. color(name: string) {
  91. const agent = agents().find((x) => x.name === name)
  92. if (agent?.color) return RGBA.fromHex(agent.color)
  93. const index = agents().findIndex((x) => x.name === name)
  94. if (index === -1) return colors()[0]
  95. return colors()[index % colors().length]
  96. },
  97. }
  98. })
  99. const model = iife(() => {
  100. const [modelStore, setModelStore] = createStore<{
  101. ready: boolean
  102. model: Record<
  103. string,
  104. {
  105. providerID: string
  106. modelID: string
  107. }
  108. >
  109. recent: {
  110. providerID: string
  111. modelID: string
  112. }[]
  113. favorite: {
  114. providerID: string
  115. modelID: string
  116. }[]
  117. }>({
  118. ready: false,
  119. model: {},
  120. recent: [],
  121. favorite: [],
  122. })
  123. const file = Bun.file(path.join(Global.Path.state, "model.json"))
  124. function save() {
  125. Bun.write(
  126. file,
  127. JSON.stringify({
  128. recent: modelStore.recent,
  129. favorite: modelStore.favorite,
  130. }),
  131. )
  132. }
  133. file
  134. .json()
  135. .then((x) => {
  136. if (Array.isArray(x.recent)) setModelStore("recent", x.recent)
  137. if (Array.isArray(x.favorite)) setModelStore("favorite", x.favorite)
  138. })
  139. .catch(() => {})
  140. .finally(() => {
  141. setModelStore("ready", true)
  142. })
  143. const args = useArgs()
  144. const fallbackModel = createMemo(() => {
  145. if (args.model) {
  146. const { providerID, modelID } = Provider.parseModel(args.model)
  147. if (isModelValid({ providerID, modelID })) {
  148. return {
  149. providerID,
  150. modelID,
  151. }
  152. }
  153. }
  154. if (sync.data.config.model) {
  155. const { providerID, modelID } = Provider.parseModel(sync.data.config.model)
  156. if (isModelValid({ providerID, modelID })) {
  157. return {
  158. providerID,
  159. modelID,
  160. }
  161. }
  162. }
  163. for (const item of modelStore.recent) {
  164. if (isModelValid(item)) {
  165. return item
  166. }
  167. }
  168. const provider = sync.data.provider[0]
  169. if (!provider) return undefined
  170. const defaultModel = sync.data.provider_default[provider.id]
  171. const firstModel = Object.values(provider.models)[0]
  172. const model = defaultModel ?? firstModel?.id
  173. if (!model) return undefined
  174. return {
  175. providerID: provider.id,
  176. modelID: model,
  177. }
  178. })
  179. const currentModel = createMemo(() => {
  180. const a = agent.current()
  181. return (
  182. getFirstValidModel(
  183. () => modelStore.model[a.name],
  184. () => a.model,
  185. fallbackModel,
  186. ) ?? undefined
  187. )
  188. })
  189. return {
  190. current: currentModel,
  191. get ready() {
  192. return modelStore.ready
  193. },
  194. recent() {
  195. return modelStore.recent
  196. },
  197. favorite() {
  198. return modelStore.favorite
  199. },
  200. parsed: createMemo(() => {
  201. const value = currentModel()
  202. if (!value) {
  203. return {
  204. provider: "Connect a provider",
  205. model: "No provider selected",
  206. }
  207. }
  208. const provider = sync.data.provider.find((x) => x.id === value.providerID)
  209. const info = provider?.models[value.modelID]
  210. return {
  211. provider: provider?.name ?? value.providerID,
  212. model: info?.name ?? value.modelID,
  213. }
  214. }),
  215. cycle(direction: 1 | -1) {
  216. const current = currentModel()
  217. if (!current) return
  218. const recent = modelStore.recent
  219. const index = recent.findIndex((x) => x.providerID === current.providerID && x.modelID === current.modelID)
  220. if (index === -1) return
  221. let next = index + direction
  222. if (next < 0) next = recent.length - 1
  223. if (next >= recent.length) next = 0
  224. const val = recent[next]
  225. if (!val) return
  226. setModelStore("model", agent.current().name, { ...val })
  227. },
  228. cycleFavorite(direction: 1 | -1) {
  229. const favorites = modelStore.favorite.filter((item) => isModelValid(item))
  230. if (!favorites.length) {
  231. toast.show({
  232. variant: "info",
  233. message: "Add a favorite model to use this shortcut",
  234. duration: 3000,
  235. })
  236. return
  237. }
  238. const current = currentModel()
  239. let index = -1
  240. if (current) {
  241. index = favorites.findIndex((x) => x.providerID === current.providerID && x.modelID === current.modelID)
  242. }
  243. if (index === -1) {
  244. index = direction === 1 ? 0 : favorites.length - 1
  245. } else {
  246. index += direction
  247. if (index < 0) index = favorites.length - 1
  248. if (index >= favorites.length) index = 0
  249. }
  250. const next = favorites[index]
  251. if (!next) return
  252. setModelStore("model", agent.current().name, { ...next })
  253. const uniq = uniqueBy([next, ...modelStore.recent], (x) => x.providerID + x.modelID)
  254. if (uniq.length > 10) uniq.pop()
  255. setModelStore("recent", uniq)
  256. save()
  257. },
  258. set(model: { providerID: string; modelID: string }, options?: { recent?: boolean }) {
  259. batch(() => {
  260. if (!isModelValid(model)) {
  261. toast.show({
  262. message: `Model ${model.providerID}/${model.modelID} is not valid`,
  263. variant: "warning",
  264. duration: 3000,
  265. })
  266. return
  267. }
  268. setModelStore("model", agent.current().name, model)
  269. if (options?.recent) {
  270. const uniq = uniqueBy([model, ...modelStore.recent], (x) => x.providerID + x.modelID)
  271. if (uniq.length > 10) uniq.pop()
  272. setModelStore("recent", uniq)
  273. save()
  274. }
  275. })
  276. },
  277. toggleFavorite(model: { providerID: string; modelID: string }) {
  278. batch(() => {
  279. if (!isModelValid(model)) {
  280. toast.show({
  281. message: `Model ${model.providerID}/${model.modelID} is not valid`,
  282. variant: "warning",
  283. duration: 3000,
  284. })
  285. return
  286. }
  287. const exists = modelStore.favorite.some(
  288. (x) => x.providerID === model.providerID && x.modelID === model.modelID,
  289. )
  290. const next = exists
  291. ? modelStore.favorite.filter((x) => x.providerID !== model.providerID || x.modelID !== model.modelID)
  292. : [model, ...modelStore.favorite]
  293. setModelStore("favorite", next)
  294. save()
  295. })
  296. },
  297. }
  298. })
  299. const mcp = {
  300. isEnabled(name: string) {
  301. const status = sync.data.mcp[name]
  302. return status?.status === "connected"
  303. },
  304. async toggle(name: string) {
  305. const status = sync.data.mcp[name]
  306. if (status?.status === "connected") {
  307. // Disable: disconnect the MCP
  308. await sdk.client.mcp.disconnect({ name })
  309. } else {
  310. // Enable/Retry: connect the MCP (handles disabled, failed, and other states)
  311. await sdk.client.mcp.connect({ name })
  312. }
  313. },
  314. }
  315. const result = {
  316. model,
  317. agent,
  318. mcp,
  319. }
  320. return result
  321. },
  322. })