local.tsx 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import { createStore } from "solid-js/store"
  2. import { batch, createEffect, createMemo } from "solid-js"
  3. import { useSync } from "@tui/context/sync"
  4. import { useTheme } from "@tui/context/theme"
  5. import { uniqueBy } from "remeda"
  6. import path from "path"
  7. import { Global } from "@/global"
  8. import { iife } from "@/util/iife"
  9. import { createSimpleContext } from "./helper"
  10. import { useToast } from "../ui/toast"
  11. import { Provider } from "@/provider/provider"
  12. import { useArgs } from "./args"
  13. import { RGBA } from "@opentui/core"
  14. export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
  15. name: "Local",
  16. init: () => {
  17. const sync = useSync()
  18. const toast = useToast()
  19. function isModelValid(model: { providerID: string; modelID: string }) {
  20. const provider = sync.data.provider.find((x) => x.id === model.providerID)
  21. return !!provider?.models[model.modelID]
  22. }
  23. function getFirstValidModel(...modelFns: (() => { providerID: string; modelID: string } | undefined)[]) {
  24. for (const modelFn of modelFns) {
  25. const model = modelFn()
  26. if (!model) continue
  27. if (isModelValid(model)) return model
  28. }
  29. }
  30. // Automatically update model when agent changes
  31. createEffect(() => {
  32. const value = agent.current()
  33. if (value.model) {
  34. if (isModelValid(value.model))
  35. model.set({
  36. providerID: value.model.providerID,
  37. modelID: value.model.modelID,
  38. })
  39. else
  40. toast.show({
  41. variant: "warning",
  42. message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
  43. duration: 3000,
  44. })
  45. }
  46. })
  47. const agent = iife(() => {
  48. const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent"))
  49. const [agentStore, setAgentStore] = createStore<{
  50. current: string
  51. }>({
  52. current: agents()[0].name,
  53. })
  54. const { theme } = useTheme()
  55. const colors = createMemo(() => [
  56. theme.secondary,
  57. theme.accent,
  58. theme.success,
  59. theme.warning,
  60. theme.primary,
  61. theme.error,
  62. ])
  63. return {
  64. list() {
  65. return agents()
  66. },
  67. current() {
  68. return agents().find((x) => x.name === agentStore.current)!
  69. },
  70. set(name: string) {
  71. if (!agents().some((x) => x.name === name))
  72. return toast.show({
  73. variant: "warning",
  74. message: `Agent not found: ${name}`,
  75. duration: 3000,
  76. })
  77. setAgentStore("current", name)
  78. },
  79. move(direction: 1 | -1) {
  80. batch(() => {
  81. let next = agents().findIndex((x) => x.name === agentStore.current) + direction
  82. if (next < 0) next = agents().length - 1
  83. if (next >= agents().length) next = 0
  84. const value = agents()[next]
  85. setAgentStore("current", value.name)
  86. })
  87. },
  88. color(name: string) {
  89. const agent = agents().find((x) => x.name === name)
  90. if (agent?.color) return RGBA.fromHex(agent.color)
  91. const index = agents().findIndex((x) => x.name === name)
  92. if (index === -1) return colors()[0]
  93. return colors()[index % colors().length]
  94. },
  95. }
  96. })
  97. const model = iife(() => {
  98. const [modelStore, setModelStore] = createStore<{
  99. ready: boolean
  100. model: Record<
  101. string,
  102. {
  103. providerID: string
  104. modelID: string
  105. }
  106. >
  107. recent: {
  108. providerID: string
  109. modelID: string
  110. }[]
  111. }>({
  112. ready: false,
  113. model: {},
  114. recent: [],
  115. })
  116. const file = Bun.file(path.join(Global.Path.state, "model.json"))
  117. file
  118. .json()
  119. .then((x) => {
  120. setModelStore("recent", x.recent)
  121. })
  122. .catch(() => {})
  123. .finally(() => {
  124. setModelStore("ready", true)
  125. })
  126. const args = useArgs()
  127. const fallbackModel = createMemo(() => {
  128. if (args.model) {
  129. const { providerID, modelID } = Provider.parseModel(args.model)
  130. if (isModelValid({ providerID, modelID })) {
  131. return {
  132. providerID,
  133. modelID,
  134. }
  135. }
  136. }
  137. if (sync.data.config.model) {
  138. const { providerID, modelID } = Provider.parseModel(sync.data.config.model)
  139. if (isModelValid({ providerID, modelID })) {
  140. return {
  141. providerID,
  142. modelID,
  143. }
  144. }
  145. }
  146. for (const item of modelStore.recent) {
  147. if (isModelValid(item)) {
  148. return item
  149. }
  150. }
  151. const provider = sync.data.provider[0]
  152. const model = sync.data.provider_default[provider.id] ?? Object.values(provider.models)[0].id
  153. return {
  154. providerID: provider.id,
  155. modelID: model,
  156. }
  157. })
  158. const currentModel = createMemo(() => {
  159. const a = agent.current()
  160. return getFirstValidModel(
  161. () => modelStore.model[a.name],
  162. () => a.model,
  163. fallbackModel,
  164. )!
  165. })
  166. return {
  167. current: currentModel,
  168. get ready() {
  169. return modelStore.ready
  170. },
  171. recent() {
  172. return modelStore.recent
  173. },
  174. parsed: createMemo(() => {
  175. const value = currentModel()
  176. const provider = sync.data.provider.find((x) => x.id === value.providerID)!
  177. const model = provider.models[value.modelID]
  178. return {
  179. provider: provider.name ?? value.providerID,
  180. model: model.name ?? value.modelID,
  181. }
  182. }),
  183. cycle(direction: 1 | -1) {
  184. const current = currentModel()
  185. if (!current) return
  186. const recent = modelStore.recent
  187. const index = recent.findIndex((x) => x.providerID === current.providerID && x.modelID === current.modelID)
  188. if (index === -1) return
  189. let next = index + direction
  190. if (next < 0) next = recent.length - 1
  191. if (next >= recent.length) next = 0
  192. const val = recent[next]
  193. if (!val) return
  194. setModelStore("model", agent.current().name, { ...val })
  195. },
  196. set(model: { providerID: string; modelID: string }, options?: { recent?: boolean }) {
  197. batch(() => {
  198. if (!isModelValid(model)) {
  199. toast.show({
  200. message: `Model ${model.providerID}/${model.modelID} is not valid`,
  201. variant: "warning",
  202. duration: 3000,
  203. })
  204. return
  205. }
  206. setModelStore("model", agent.current().name, model)
  207. if (options?.recent) {
  208. const uniq = uniqueBy([model, ...modelStore.recent], (x) => x.providerID + x.modelID)
  209. if (uniq.length > 5) uniq.pop()
  210. setModelStore("recent", uniq)
  211. Bun.write(
  212. file,
  213. JSON.stringify({
  214. recent: modelStore.recent,
  215. }),
  216. )
  217. }
  218. })
  219. },
  220. }
  221. })
  222. const result = {
  223. model,
  224. agent,
  225. }
  226. return result
  227. },
  228. })