local.tsx 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  import { createStore } from "solid-js/store"
  import { batch, createEffect, createMemo, untrack } from "solid-js"
  import { useSync } from "@tui/context/sync"
  import { useTheme } from "@tui/context/theme"
  import { uniqueBy } from "remeda"
  import path from "path"
  import { Global } from "@/global"
  import { iife } from "@/util/iife"
  import { createSimpleContext } from "./helper"
  import { useToast } from "../ui/toast"
  import { Provider } from "@/provider/provider"
  import { useArgs } from "./args"
  import { RGBA } from "@opentui/core"
  export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
    name: "Local",
    /**
     * Builds the TUI-local state exposed through `useLocal()`:
     * - `agent`: the selectable (non-"subagent") agents and which one is active.
     * - `model`: the model chosen per agent, plus recent/favorite model lists that are
     *   loaded from and saved to `model.json` under `Global.Path.state`.
     */
    init: () => {
      const sync = useSync()
      const toast = useToast()

      /** Maximum number of entries kept in the "recent models" list. */
      const recentLimit = 10

      type ModelRef = { providerID: string; modelID: string }

      /** True when a synced provider with `model.providerID` exposes `model.modelID`. */
      function isModelValid(model: ModelRef) {
        const provider = sync.data.provider.find((x) => x.id === model.providerID)
        return !!provider?.models[model.modelID]
      }

      /** Calls each candidate in order and returns the first result that is a valid model, else `undefined`. */
      function getFirstValidModel(...modelFns: (() => ModelRef | undefined)[]) {
        for (const modelFn of modelFns) {
          const candidate = modelFn()
          if (candidate && isModelValid(candidate)) return candidate
        }
        return undefined
      }

      /** True when both refs name the same provider/model pair. */
      function sameModel(a: ModelRef, b: ModelRef) {
        return a.providerID === b.providerID && a.modelID === b.modelID
      }

      /**
       * De-duplication key for a model. A separator is required: plain concatenation made
       * e.g. "ab"+"c" and "a"+"bc" collide, silently dropping a distinct recent entry.
       */
      function modelKey(model: ModelRef) {
        return `${model.providerID}\u0000${model.modelID}`
      }

      /** Plain copy of the ids, so callers' objects (or store proxies) are never shared into the store. */
      function copyModel(model: ModelRef): ModelRef {
        return { providerID: model.providerID, modelID: model.modelID }
      }

      /** Narrows untrusted JSON (entries read back from model.json) to a well-formed model ref. */
      function isModelRef(value: unknown): value is ModelRef {
        return (
          typeof value === "object" &&
          value !== null &&
          "providerID" in value &&
          typeof value.providerID === "string" &&
          "modelID" in value &&
          typeof value.modelID === "string"
        )
      }

      const agent = iife(() => {
        // Agents whose mode is "subagent" are excluded from the selectable list.
        const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent"))
        const [agentStore, setAgentStore] = createStore<{ current: string }>({
          current: agents()[0]?.name ?? "",
        })
        const { theme } = useTheme()
        // Palette for agents without an explicit color, assigned by list position.
        const colors = createMemo(() => [
          theme.secondary,
          theme.accent,
          theme.success,
          theme.warning,
          theme.primary,
          theme.error,
        ])

        return {
          /** The selectable agents. */
          list() {
            return agents()
          },
          /**
           * The active agent. Falls back to the first agent when the selected name is no
           * longer in the list (e.g. after a sync update) instead of returning `undefined`.
           */
          current() {
            const list = agents()
            return list.find((x) => x.name === agentStore.current) ?? list[0]!
          },
          /** Selects an agent by name; shows a warning toast if it is not selectable. */
          set(name: string) {
            if (!agents().some((x) => x.name === name))
              return toast.show({
                variant: "warning",
                message: `Agent not found: ${name}`,
                duration: 3000,
              })
            setAgentStore("current", name)
          },
          /** Steps to the next (1) or previous (-1) agent, wrapping around at either end. */
          move(direction: 1 | -1) {
            batch(() => {
              const list = agents()
              if (list.length === 0) return
              // An unknown selection is displayed as the first agent (see current()), so step from there.
              const index = Math.max(0, list.findIndex((x) => x.name === agentStore.current))
              const next = (index + direction + list.length) % list.length
              setAgentStore("current", list[next]!.name)
            })
          },
          /** The agent's configured color, or a palette color picked by list position (first color if unknown). */
          color(name: string) {
            const list = agents()
            const index = list.findIndex((x) => x.name === name)
            const configured = list[index]?.color
            if (configured) return RGBA.fromHex(configured)
            const palette = colors()
            return palette[Math.max(index, 0) % palette.length]
          },
        }
      })

      const model = iife(() => {
        const [modelStore, setModelStore] = createStore<{
          ready: boolean
          // Per-agent model selection, keyed by agent name.
          model: Record<string, ModelRef>
          recent: ModelRef[]
          favorite: ModelRef[]
        }>({
          ready: false,
          model: {},
          recent: [],
          favorite: [],
        })

        const file = Bun.file(path.join(Global.Path.state, "model.json"))

        /** Persists the recent/favorite lists. Write failures are reported instead of becoming unhandled rejections. */
        function save() {
          Bun.write(
            file,
            JSON.stringify({
              recent: modelStore.recent,
              favorite: modelStore.favorite,
            }),
          ).catch(() => {
            toast.show({
              variant: "warning",
              message: "Failed to save model preferences",
              duration: 3000,
            })
          })
        }

        /** Moves `entry` to the front of the recent list (de-duplicated, capped at `recentLimit`) and saves. */
        function pushRecent(entry: ModelRef) {
          const next = uniqueBy([copyModel(entry), ...modelStore.recent], (x) => modelKey(x)).slice(0, recentLimit)
          setModelStore("recent", next)
          save()
        }

        file
          .json()
          .then((data: unknown) => {
            if (typeof data !== "object" || data === null) return
            const { recent, favorite } = data as { recent?: unknown; favorite?: unknown }
            if (Array.isArray(recent)) setModelStore("recent", recent.filter(isModelRef))
            if (Array.isArray(favorite)) setModelStore("favorite", favorite.filter(isModelRef))
          })
          .catch(() => {
            // Best-effort load: e.g. no model.json yet just means there is no saved history.
          })
          .finally(() => {
            setModelStore("ready", true)
          })

        const args = useArgs()

        /** Parses a configured model string with `Provider.parseModel`, keeping only the two ids. */
        function parseConfigured(raw: string | undefined): ModelRef | undefined {
          if (!raw) return undefined
          const { providerID, modelID } = Provider.parseModel(raw)
          return { providerID, modelID }
        }

        // Model used when neither the per-agent selection nor the agent's own model is valid.
        // Priority: `model` launch argument, config.model, most recent valid model, then the
        // first provider's default (or first listed) model.
        const fallbackModel = createMemo(() =>
          getFirstValidModel(
            () => parseConfigured(args.model),
            () => parseConfigured(sync.data.config.model),
            () => modelStore.recent.find((item) => isModelValid(item)),
            () => {
              const provider = sync.data.provider[0]
              if (!provider) return undefined
              const modelID = sync.data.provider_default[provider.id] ?? Object.values(provider.models)[0]?.id
              return modelID ? { providerID: provider.id, modelID } : undefined
            },
          ),
        )

        const currentModel = createMemo(() => {
          const a = agent.current()
          return getFirstValidModel(
            () => modelStore.model[a.name],
            () => a.model,
            fallbackModel,
          )!
        })

        return {
          current: currentModel,
          /** True once the persisted recent/favorite lists have been loaded (or the load failed). */
          get ready() {
            return modelStore.ready
          },
          recent() {
            return modelStore.recent
          },
          favorite() {
            return modelStore.favorite
          },
          /** Display names for the current model, falling back to raw ids when metadata is missing. */
          parsed: createMemo(() => {
            const value = currentModel()
            const provider = sync.data.provider.find((x) => x.id === value.providerID)
            const info = provider?.models[value.modelID]
            return {
              provider: provider?.name ?? value.providerID,
              model: info?.name ?? value.modelID,
            }
          }),
          /** Steps through the recent list from the current model; no-op if the current model is not in it. */
          cycle(direction: 1 | -1) {
            const current = currentModel()
            if (!current) return
            const recent = modelStore.recent
            const index = recent.findIndex((x) => sameModel(x, current))
            if (index === -1) return
            const target = recent[(index + direction + recent.length) % recent.length]
            if (!target) return
            setModelStore("model", agent.current().name, copyModel(target))
          },
          /** Steps through the valid favorites, selecting the result and recording it as recent. */
          cycleFavorite(direction: 1 | -1) {
            const favorites = modelStore.favorite.filter((item) => isModelValid(item))
            if (!favorites.length) {
              toast.show({
                variant: "info",
                message: "Add a favorite model to use this shortcut",
                duration: 3000,
              })
              return
            }
            const current = currentModel()
            const index = favorites.findIndex((x) => sameModel(x, current))
            let nextIndex: number
            if (index === -1) {
              // Not currently on a favorite: enter the list from the end matching the direction.
              nextIndex = direction === 1 ? 0 : favorites.length - 1
            } else {
              nextIndex = (index + direction + favorites.length) % favorites.length
            }
            const next = favorites[nextIndex]
            if (!next) return
            batch(() => {
              setModelStore("model", agent.current().name, copyModel(next))
              pushRecent(next)
            })
          },
          /** Selects `model` for the current agent; with `options.recent`, also records it as recent. */
          set(model: ModelRef, options?: { recent?: boolean }) {
            batch(() => {
              if (!isModelValid(model)) {
                toast.show({
                  message: `Model ${model.providerID}/${model.modelID} is not valid`,
                  variant: "warning",
                  duration: 3000,
                })
                return
              }
              setModelStore("model", agent.current().name, copyModel(model))
              if (options?.recent) pushRecent(model)
            })
          },
          /** Adds `model` to the front of the favorites, or removes it if already present, then saves. */
          toggleFavorite(model: ModelRef) {
            batch(() => {
              if (!isModelValid(model)) {
                toast.show({
                  message: `Model ${model.providerID}/${model.modelID} is not valid`,
                  variant: "warning",
                  duration: 3000,
                })
                return
              }
              const exists = modelStore.favorite.some((x) => sameModel(x, model))
              const next = exists
                ? modelStore.favorite.filter((x) => !sameModel(x, model))
                : [copyModel(model), ...modelStore.favorite]
              setModelStore("favorite", next)
              save()
            })
          },
        }
      })

      // When the active agent (or its configured model) changes, switch to that model.
      // Declared after `agent`/`model` so it never relies on deferred effect execution to avoid the TDZ.
      // Validation and the store update run untracked: otherwise the effect subscribes to the
      // provider list, and every provider sync would re-apply the agent's model over a manual
      // selection or re-show the warning toast.
      createEffect(() => {
        const value = agent.current()
        const configured = value.model
        if (!configured) return
        const name = value.name
        const target = { providerID: configured.providerID, modelID: configured.modelID }
        untrack(() => {
          if (isModelValid(target)) {
            model.set(target)
            return
          }
          toast.show({
            variant: "warning",
            message: `Agent ${name}'s configured model ${target.providerID}/${target.modelID} is not valid`,
            duration: 3000,
          })
        })
      })

      return {
        model,
        agent,
      }
    },
  })