Просмотр исходного кода

feat: ollama-deepseek reasoning support

System233 10 месяцев назад
Родитель
Коммит
0eecda2dc3
Изменено файлов: 3 (добавлено строк: 244, удалено: 3)
  1. 15 3
      src/api/providers/ollama.ts
  2. 124 0
      src/utils/__tests__/xml-matcher.test.ts
  3. 105 0
      src/utils/xml-matcher.ts

+ 15 - 3
src/api/providers/ollama.ts

@@ -6,6 +6,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { ApiStream } from "../transform/stream"
 import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
+import { XmlMatcher } from "../../utils/xml-matcher"
 
 const OLLAMA_DEFAULT_TEMPERATURE = 0
 
@@ -35,15 +36,26 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 			temperature: this.options.modelTemperature ?? OLLAMA_DEFAULT_TEMPERATURE,
 			stream: true,
 		})
+		const matcher = new XmlMatcher(
+			"think",
+			(chunk) =>
+				({
+					type: chunk.matched ? "reasoning" : "text",
+					text: chunk.data,
+				}) as const,
+		)
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
+
 			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+				for (const chunk of matcher.update(delta.content)) {
+					yield chunk
 				}
 			}
 		}
+		for (const chunk of matcher.final()) {
+			yield chunk
+		}
 	}
 
 	getModel(): { id: string; info: ModelInfo } {

+ 124 - 0
src/utils/__tests__/xml-matcher.test.ts

@@ -0,0 +1,124 @@
+import { XmlMatcher } from "../xml-matcher"
+
describe("XmlMatcher", () => {
	/**
	 * Streams every piece through a fresh matcher for the "think" tag and gathers
	 * the emitted chunks. When `flush` is true the matcher is finalized afterwards
	 * so that any buffered characters are emitted as well.
	 */
	const feed = (pieces: string[], flush = true) => {
		const matcher = new XmlMatcher("think")
		const emitted = pieces.flatMap((piece) => matcher.update(piece))
		return flush ? emitted.concat(matcher.final()) : emitted
	}

	/** Expected chunk for text found inside the tag. */
	const inside = (data: string) => ({ matched: true, data })

	/** Expected chunk for text found outside the tag. */
	const outside = (data: string) => ({ matched: false, data })

	it("only match at position 0", () => {
		const chunks = feed(["<think>data</think>"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([inside("data")])
	})

	it("tag with space", () => {
		// Spaces right after "<" / "</" and right before ">" are tolerated.
		const chunks = feed(["< think >data</ think >"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([inside("data")])
	})

	it("invalid tag", () => {
		// Anything other than padding after the tag name makes the tag plain text.
		const chunks = feed(["< think 1>data</ think >"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([outside("< think 1>data</ think >")])
	})

	it("anonymous tag", () => {
		const chunks = feed(["<>data</>"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([outside("<>data</>")])
	})

	it("streaming push", () => {
		// Tags split across pushes are still recognized; no final() call here.
		const chunks = feed(["<thi", "nk", ">dat", "a</", "think>"], false)
		expect(chunks).toHaveLength(2)
		expect(chunks).toEqual([inside("dat"), inside("a")])
	})

	it("nested tag", () => {
		// Inner tags of the same name are kept verbatim in the matched data.
		const chunks = feed(["<think>X<think>Y</think>Z</think>"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([inside("X<think>Y</think>Z")])
	})

	it("nested invalid tag", () => {
		// The malformed inner close leaves one level open, so the last
		// "</think>" is only flushed by final() and arrives as its own chunk.
		const chunks = feed(["<think>X<think>Y</thxink>Z</think>"])
		expect(chunks).toHaveLength(2)
		expect(chunks).toEqual([inside("X<think>Y</thxink>Z"), inside("</think>")])
	})

	it("Wrong matching position", () => {
		// With the default position the tag must start at the very first character.
		const chunks = feed(["1<think>data</think>"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([outside("1<think>data</think>")])
	})

	it("Unclosed tag", () => {
		const chunks = feed(["<think>data"])
		expect(chunks).toHaveLength(1)
		expect(chunks).toEqual([inside("data")])
	})
})

+ 105 - 0
src/utils/xml-matcher.ts

@@ -0,0 +1,105 @@
/** One classified slice of the stream produced by {@link XmlMatcher}. */
export interface XmlMatcherResult {
	/** `true` when the text was found inside the tracked tag, `false` otherwise. */
	matched: boolean
	/** The text of this slice (the outermost tag markup itself is not included). */
	data: string
}

/**
 * Incremental matcher that splits a streamed string into text located inside
 * `<tagName>...</tagName>` (reported with `matched: true`) and all other text
 * (reported with `matched: false`).
 *
 * Rules implemented below:
 * - The outermost opening tag is only recognized when its `<` sits at an offset
 *   no greater than `position` from the start of the stream.
 * - While inside the tag, nested tags with the same name are tracked through
 *   `depth` and kept verbatim in the emitted data.
 * - Spaces directly after `<` / `</` and directly before `>` are tolerated.
 * - A closing tag that has no matching opening tag is emitted as plain text.
 *
 * @typeParam Result - Element type returned by `update()` / `final()`; it is
 *   whatever `transform` returns, or {@link XmlMatcherResult} when no transform
 *   is supplied.
 */
export class XmlMatcher<Result = XmlMatcherResult> {
	/** Number of `tagName` characters matched so far in the current tag candidate. */
	index = 0
	/** Classified chunks waiting to be handed out by the next `update()` / `final()`. */
	chunks: XmlMatcherResult[] = []
	/** Characters that were read but not classified yet (e.g. a partially received tag). */
	cached: string[] = []
	/** Whether the stream is currently inside the tracked tag. */
	matched: boolean = false
	/** Parser state: plain text, inside `<...`, or inside `</...`. */
	state: "TEXT" | "TAG_OPEN" | "TAG_CLOSE" = "TEXT"
	/** Number of tracked tags that are currently open. */
	depth = 0
	/** Total number of UTF-16 code units consumed so far. */
	pointer = 0

	/**
	 * @param tagName - Tag name to track, without angle brackets.
	 * @param transform - Optional mapper applied to every chunk before it is returned.
	 * @param position - Largest stream offset at which the outermost opening tag may start.
	 */
	constructor(
		readonly tagName: string,
		readonly transform?: (chunks: XmlMatcherResult) => Result,
		readonly position = 0,
	) {}

	/** Moves the buffered characters into `chunks`, merging with the previous chunk when the classification is equal. */
	private collect() {
		if (!this.cached.length) {
			return
		}
		const data = this.cached.join("")
		this.cached = []
		const last = this.chunks[this.chunks.length - 1]
		if (last?.matched === this.matched) {
			last.data += data
		} else {
			this.chunks.push({ data, matched: this.matched })
		}
	}

	/** Hands out (and forgets) all pending chunks, mapped through `transform` when one was given. */
	private pop(): Result[] {
		const pending = this.chunks
		this.chunks = []
		const transform = this.transform
		if (!transform) {
			return pending as Result[]
		}
		// Call the mapper with the chunk only, so the index/array arguments of
		// Array.prototype.map never leak into a user callback.
		return pending.map((chunk) => transform(chunk))
	}

	/** A `<` may start a tag either within the allowed start window or anywhere inside the tag. */
	private canStartTag(): boolean {
		return this.pointer <= this.position + 1 || this.matched
	}

	/** Handles a completed opening tag. */
	private openTag() {
		if (!this.matched) {
			// Markup of the outermost opening tag is dropped from the output.
			this.cached = []
		}
		this.depth++
		this.matched = true
	}

	/** Handles a completed closing tag. */
	private closeTag() {
		if (this.depth === 0) {
			// Nothing is open: keep the stray closing tag as ordinary text
			// instead of dropping it and letting the depth go negative.
			this.collect()
			return
		}
		this.depth--
		this.matched = this.depth > 0
		if (!this.matched) {
			// Markup of the outermost closing tag is dropped from the output.
			this.cached = []
		}
	}

	/**
	 * Gives up on the current tag candidate because `char` does not fit.
	 * When `char` is itself a `<` that may start a tag, everything before it is
	 * flushed as text and matching restarts at that `<`.
	 */
	private abortTag(char: string) {
		this.state = "TEXT"
		if (char === "<" && this.canStartTag()) {
			this.cached.pop()
			this.collect()
			this.cached.push(char)
			this.state = "TAG_OPEN"
			this.index = 0
			return
		}
		this.collect()
	}

	/** Feeds `chunk` through the state machine one UTF-16 code unit at a time. */
	private _update(chunk: string) {
		const tagLength = this.tagName.length
		for (let i = 0; i < chunk.length; i++) {
			const char = chunk.charAt(i)
			this.cached.push(char)
			this.pointer++

			if (this.state === "TEXT") {
				if (char === "<" && this.canStartTag()) {
					this.state = "TAG_OPEN"
					this.index = 0
				} else {
					this.collect()
				}
				continue
			}

			const closing = this.state === "TAG_CLOSE"
			if (char === ">" && this.index === tagLength) {
				this.state = "TEXT"
				if (closing) {
					this.closeTag()
				} else {
					this.openTag()
				}
			} else if (!closing && this.index === 0 && char === "/") {
				this.state = "TAG_CLOSE"
			} else if (char === " " && (this.index === 0 || this.index === tagLength)) {
				// Padding before or after the tag name: stay in the current state.
			} else if (this.tagName[this.index] === char) {
				this.index++
			} else {
				this.abortTag(char)
			}
		}
	}

	/**
	 * Processes an optional last piece of input and flushes everything that is
	 * still buffered, including an unfinished tag candidate.
	 *
	 * @param chunk - Optional trailing input.
	 * @returns The remaining chunks.
	 */
	final(chunk?: string) {
		if (chunk) {
			this._update(chunk)
		}
		this.collect()
		return this.pop()
	}

	/**
	 * Processes the next piece of streamed input.
	 *
	 * @param chunk - Newly received text.
	 * @returns The chunks that could be classified so far; characters of a
	 *   possibly incomplete tag stay buffered until more input arrives.
	 */
	update(chunk: string) {
		this._update(chunk)
		return this.pop()
	}
}