local.tsx 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import { createStore } from "solid-js/store"
  2. import { batch, createEffect, createMemo, createSignal, onMount } from "solid-js"
  3. import { useSync } from "@tui/context/sync"
  4. import { useTheme } from "@tui/context/theme"
  5. import { uniqueBy } from "remeda"
  6. import path from "path"
  7. import { Global } from "@/global"
  8. import { iife } from "@/util/iife"
  9. import { createSimpleContext } from "./helper"
  10. import { useToast } from "../ui/toast"
  11. import { createEventBus } from "@solid-primitives/event-bus"
  12. import { Provider } from "@/provider/provider"
  /**
   * Session-local TUI state: the selected primary agent, the model chosen per agent
   * (plus a short "recent models" list persisted to model.json in the global state
   * directory), and an event bus that delivers `initialPrompt` once mounted.
   */
  export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
    name: "Local",
    init: (props: { initialModel?: string; initialAgent?: string; initialPrompt?: string }) => {
      const sync = useSync()
      const toast = useToast()

      type ModelRef = { providerID: string; modelID: string }

      /** True when the referenced provider is loaded and exposes the referenced model ID. */
      function isModelValid(model: ModelRef) {
        const provider = sync.data.provider.find((x) => x.id === model.providerID)
        return !!provider?.models[model.modelID]
      }

      /**
       * Evaluates the candidate getters in order (lazily) and returns the first model
       * that passes `isModelValid`, or `undefined` when none does.
       */
      function getFirstValidModel(...modelFns: (() => ModelRef | undefined)[]) {
        for (const modelFn of modelFns) {
          const model = modelFn()
          if (!model) continue
          if (isModelValid(model)) return model
        }
      }

      /** Identity key for de-duplication; the "/" keeps "ab"+"c" distinct from "a"+"bc". */
      function modelKey(model: ModelRef) {
        return `${model.providerID}/${model.modelID}`
      }

      /** Runtime shape check for model entries read back from model.json (untrusted JSON). */
      function isModelRef(value: unknown): value is ModelRef {
        if (typeof value !== "object" || value === null) return false
        const candidate = value as { providerID?: unknown; modelID?: unknown }
        return typeof candidate.providerID === "string" && typeof candidate.modelID === "string"
      }

      const agent = iife(() => {
        // Subagents are filtered out, so they can never be selected or cycled to here.
        const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent"))
        const [agentStore, setAgentStore] = createStore<{
          current: string
        }>({
          // NOTE(review): assumes at least one non-subagent agent exists at init.
          current: agents()[0].name,
        })
        const { theme } = useTheme()
        // Palette assigned to agents by their position in the list, wrapping around.
        const colors = createMemo(() => [
          theme.secondary,
          theme.accent,
          theme.success,
          theme.warning,
          theme.primary,
          theme.error,
        ])
        return {
          /** Selectable (non-subagent) agents. */
          list() {
            return agents()
          },
          /**
           * The selected agent. Falls back to the first agent when the stored name is no
           * longer in the list (e.g. after the synced agent list changed).
           */
          current() {
            return agents().find((x) => x.name === agentStore.current) ?? agents()[0]
          },
          /** Selects an agent by name; unknown names only produce a warning toast. */
          set(name: string) {
            if (!agents().some((x) => x.name === name))
              return toast.show({
                variant: "warning",
                message: `Agent not found: ${name}`,
                duration: 3000,
              })
            setAgentStore("current", name)
          },
          /** Selects the next (1) or previous (-1) agent, wrapping at both ends. */
          move(direction: 1 | -1) {
            batch(() => {
              let next = agents().findIndex((x) => x.name === agentStore.current) + direction
              if (next < 0) next = agents().length - 1
              if (next >= agents().length) next = 0
              const value = agents()[next]
              setAgentStore("current", value.name)
            })
          },
          /** Display color for an agent; names not in the list get the first palette color. */
          color(name: string) {
            const index = agents().findIndex((x) => x.name === name)
            // findIndex yields -1 for unknown names, and -1 % n is -1 in JS (out of range).
            if (index === -1) return colors()[0]
            return colors()[index % colors().length]
          },
        }
      })

      const model = iife(() => {
        const [modelStore, setModelStore] = createStore<{
          // True once the persisted recent list has been loaded (or failed to load).
          ready: boolean
          // Model chosen via set/cycle, keyed by agent name.
          model: Record<string, ModelRef>
          // Recently used models, newest first; `set` keeps at most 5.
          recent: ModelRef[]
        }>({
          ready: false,
          model: {},
          recent: [],
        })
        const file = Bun.file(path.join(Global.Path.state, "model.json"))
        file
          .json()
          .then((data: unknown) => {
            // The file is external input: accept only a well-formed `recent` array of model refs.
            const recent = (data as { recent?: unknown } | null)?.recent
            if (Array.isArray(recent)) setModelStore("recent", recent.filter(isModelRef))
          })
          // Best-effort: a missing or unreadable file just means no recent history.
          .catch(() => {})
          .finally(() => {
            setModelStore("ready", true)
          })

        // Default model: the configured model, else the first valid recent model,
        // else the first model of the first provider.
        const fallbackModel = createMemo(() => {
          if (sync.data.config.model) {
            const { providerID, modelID } = Provider.parseModel(sync.data.config.model)
            if (isModelValid({ providerID, modelID })) {
              return {
                providerID,
                modelID,
              }
            }
          }
          for (const item of modelStore.recent) {
            if (isModelValid(item)) {
              return item
            }
          }
          // NOTE(review): assumes at least one provider with at least one model is loaded.
          const provider = sync.data.provider[0]
          const model = Object.values(provider.models)[0]
          return {
            providerID: provider.id,
            modelID: model.id,
          }
        })

        // Priority: model chosen for this agent via set/cycle, then the agent's configured model, then the fallback.
        const currentModel = createMemo(() => {
          const a = agent.current()
          return getFirstValidModel(
            () => modelStore.model[a.name],
            () => a.model,
            fallbackModel,
          )!
        })

        return {
          current: currentModel,
          get ready() {
            return modelStore.ready
          },
          recent() {
            return modelStore.recent
          },
          /** Display names for the current model, falling back to the raw IDs when a name is missing. */
          parsed: createMemo(() => {
            const value = currentModel()
            const provider = sync.data.provider.find((x) => x.id === value.providerID)!
            const model = provider.models[value.modelID]
            return {
              provider: provider.name ?? value.providerID,
              model: model.name ?? value.modelID,
            }
          }),
          /**
           * Switches the current agent to the next (1) or previous (-1) entry of the recent
           * list, wrapping; no-op when the current model is not in that list.
           */
          cycle(direction: 1 | -1) {
            const current = currentModel()
            if (!current) return
            const recent = modelStore.recent
            const index = recent.findIndex(
              (x) => x.providerID === current.providerID && x.modelID === current.modelID,
            )
            if (index === -1) return
            let next = index + direction
            if (next < 0) next = recent.length - 1
            if (next >= recent.length) next = 0
            const val = recent[next]
            if (!val) return
            // Copy so the per-agent entry never aliases an object inside `recent`.
            setModelStore("model", agent.current().name, { ...val })
          },
          /**
           * Selects `model` for the current agent; invalid models only produce a warning toast.
           * With `options.recent`, also moves it to the front of the recent list (max 5 entries)
           * and writes that list to model.json.
           */
          set(model: ModelRef, options?: { recent?: boolean }) {
            batch(() => {
              if (!isModelValid(model)) {
                toast.show({
                  message: `Model ${model.providerID}/${model.modelID} is not valid`,
                  variant: "warning",
                  duration: 3000,
                })
                return
              }
              // Store copies, never the caller's object, and never one object in both places:
              // setting an object path on a Solid store merges into the existing object, so a
              // shared object would let a later set/cycle silently rewrite a `recent` entry.
              setModelStore("model", agent.current().name, {
                providerID: model.providerID,
                modelID: model.modelID,
              })
              if (options?.recent) {
                const entry = { providerID: model.providerID, modelID: model.modelID }
                const uniq = uniqueBy([entry, ...modelStore.recent], modelKey)
                if (uniq.length > 5) uniq.pop()
                setModelStore("recent", uniq)
                // Fire-and-forget persistence; catch so a failed write is not an unhandled rejection.
                Bun.write(
                  file,
                  JSON.stringify({
                    recent: modelStore.recent,
                  }),
                ).catch(() => {})
              }
            })
          },
        }
      })

      // Apply the initial agent and model, if provided, once mounted.
      onMount(() => {
        batch(() => {
          if (props.initialAgent) {
            agent.set(props.initialAgent)
          }
          if (props.initialModel) {
            const { providerID, modelID } = Provider.parseModel(props.initialModel)
            if (!providerID || !modelID)
              return toast.show({
                variant: "warning",
                message: `Invalid model format: ${props.initialModel}`,
                duration: 3000,
              })
            // NOTE(review): the agent→model effect registered below may run after this and,
            // when the selected agent configures a valid model, replace this one — confirm precedence.
            model.set({ providerID, modelID }, { recent: true })
          }
        })
      })

      // Automatically update model when agent changes
      createEffect(() => {
        const value = agent.current()
        if (value.model) {
          if (isModelValid(value.model))
            model.set({
              providerID: value.model.providerID,
              modelID: value.model.modelID,
            })
          else
            toast.show({
              variant: "warning",
              message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
              duration: 3000,
            })
        }
      })

      // Delivers props.initialPrompt (if any) through the event bus once mounted.
      const setInitialPrompt = createEventBus<string>()
      onMount(() => {
        if (props.initialPrompt) setInitialPrompt.emit(props.initialPrompt)
      })

      const result = {
        model,
        agent,
        get setInitialPrompt() {
          return setInitialPrompt
        },
      }
      return result
    },
  })