
feat: migrate HuggingFace provider to AI SDK (#11156)

Co-authored-by: Roo Code <[email protected]>
Co-authored-by: daniel-lxs <[email protected]>
roomote[bot] 1 week ago
commit 460cff4c3b
2 changed files with 706 additions and 79 deletions
  1. src/api/providers/__tests__/huggingface.spec.ts (+553, −0)
  2. src/api/providers/huggingface.ts (+153, −79)
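For orientation: the migration replaces the raw `openai` client with the AI SDK. `createOpenAICompatible` builds a provider against HuggingFace's OpenAI-compatible router endpoint, and `streamText` / `generateText` drive the requests. Below is a minimal standalone sketch of that call shape, using only identifiers visible in the diffs; the env var name is an assumption and the snippet is not part of the commit.

import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
import { streamText, generateText } from "ai"

// Provider factory pointing at HuggingFace's OpenAI-compatible router
// (same baseURL the handler below uses). API key env var name assumed.
const huggingface = createOpenAICompatible({
	name: "huggingface",
	baseURL: "https://router.huggingface.co/v1",
	apiKey: process.env.HUGGINGFACE_API_KEY ?? "",
})

async function demo() {
	// Streaming: iterate fullStream, then read usage, as createMessage does.
	const result = streamText({
		model: huggingface("meta-llama/Llama-3.3-70B-Instruct"),
		system: "You are a helpful assistant.",
		messages: [{ role: "user", content: "Hello!" }],
		temperature: 0.7,
	})
	for await (const part of result.fullStream) {
		if (part.type === "text-delta") process.stdout.write(part.text)
	}
	console.log("\nusage:", await result.usage)

	// Non-streaming completion, mirroring completePrompt().
	const { text } = await generateText({
		model: huggingface("meta-llama/Llama-3.3-70B-Instruct"),
		prompt: "Say hello in one word.",
		temperature: 0.7,
	})
	console.log(text)
}

demo().catch(console.error)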

+ 553 - 0
src/api/providers/__tests__/huggingface.spec.ts

@@ -0,0 +1,553 @@
+// npx vitest run src/api/providers/__tests__/huggingface.spec.ts
+
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
+
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
+	return {
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
+	}
+})
+
+vi.mock("@ai-sdk/openai-compatible", () => ({
+	createOpenAICompatible: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "meta-llama/Llama-3.3-70B-Instruct",
+			provider: "huggingface",
+		}))
+	}),
+}))
+
+// Mock the fetchers
+vi.mock("../fetchers/huggingface", () => ({
+	getHuggingFaceModels: vi.fn(() => Promise.resolve({})),
+	getCachedHuggingFaceModels: vi.fn(() => ({})),
+}))
+
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
+
+import { HuggingFaceHandler } from "../huggingface"
+
+describe("HuggingFaceHandler", () => {
+	let handler: HuggingFaceHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			huggingFaceApiKey: "test-huggingface-api-key",
+			huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
+		}
+		handler = new HuggingFaceHandler(mockOptions)
+		vi.clearAllMocks()
+	})
+
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(HuggingFaceHandler)
+			expect(handler.getModel().id).toBe(mockOptions.huggingFaceModelId)
+		})
+
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new HuggingFaceHandler({
+				...mockOptions,
+				huggingFaceModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe("meta-llama/Llama-3.3-70B-Instruct")
+		})
+
+		it("should throw error if API key is not provided", () => {
+			expect(() => {
+				new HuggingFaceHandler({
+					...mockOptions,
+					huggingFaceApiKey: undefined,
+				})
+			}).toThrow("Hugging Face API key is required")
+		})
+	})
+
+	describe("getModel", () => {
+		it("should return default model when no model is specified", () => {
+			const handlerWithoutModel = new HuggingFaceHandler({
+				huggingFaceApiKey: "test-huggingface-api-key",
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe("meta-llama/Llama-3.3-70B-Instruct")
+			expect(model.info).toBeDefined()
+		})
+
+		it("should return specified model when valid model is provided", () => {
+			const testModelId = "mistralai/Mistral-7B-Instruct-v0.3"
+			const handlerWithModel = new HuggingFaceHandler({
+				huggingFaceModelId: testModelId,
+				huggingFaceApiKey: "test-huggingface-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
+
+		it("should return fallback info when model not in cache", () => {
+			const model = handler.getModel()
+			expect(model.info).toEqual(
+				expect.objectContaining({
+					maxTokens: 8192,
+					contextWindow: 131072,
+					supportsImages: false,
+					supportsPromptCache: false,
+				}),
+			)
+		})
+	})
+
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response from HuggingFace" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response from HuggingFace")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 20,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(20)
+		})
+
+		it("should handle cached tokens in usage data from providerMetadata", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+			})
+
+			// HuggingFace provides cache metrics via providerMetadata for supported models
+			const mockProviderMetadata = Promise.resolve({
+				huggingface: {
+					promptCacheHitTokens: 30,
+					promptCacheMissTokens: 70,
+				},
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(100)
+			expect(usageChunks[0].outputTokens).toBe(50)
+			expect(usageChunks[0].cacheReadTokens).toBe(30)
+			expect(usageChunks[0].cacheWriteTokens).toBe(70)
+		})
+
+		it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 25,
+				},
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].cacheReadTokens).toBe(25)
+			expect(usageChunks[0].cacheWriteTokens).toBeUndefined()
+		})
+
+		it("should pass correct temperature (0.7 default) to streamText", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const handlerWithDefaultTemp = new HuggingFaceHandler({
+				huggingFaceApiKey: "test-key",
+				huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
+			})
+
+			const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.7,
+				}),
+			)
+		})
+
+		it("should use user-specified temperature over provider defaults", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const handlerWithCustomTemp = new HuggingFaceHandler({
+				huggingFaceApiKey: "test-key",
+				huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
+				modelTemperature: 0.7,
+			})
+
+			const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// User-specified temperature should take precedence over everything
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.7,
+				}),
+			)
+		})
+
+		it("should handle stream with multiple chunks", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Hello" }
+				yield { type: "text-delta", text: " world" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const textChunks = chunks.filter((c) => c.type === "text")
+			expect(textChunks[0]).toEqual({ type: "text", text: "Hello" })
+			expect(textChunks[1]).toEqual({ type: "text", text: " world" })
+
+			const usageChunks = chunks.filter((c) => c.type === "usage")
+			expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
+		})
+
+		it("should handle errors with handleAiSdkError", async () => {
+			async function* mockFullStream(): AsyncGenerator<any> {
+				yield { type: "text-delta", text: "" } // Yield something before error to satisfy lint
+				throw new Error("API Error")
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+
+			await expect(async () => {
+				for await (const _ of stream) {
+					// consume stream
+				}
+			}).rejects.toThrow("HuggingFace: API Error")
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion from HuggingFace",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion from HuggingFace")
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					prompt: "Test prompt",
+				}),
+			)
+		})
+
+		it("should use default temperature in completePrompt", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.7,
+				}),
+			)
+		})
+	})
+
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics including cache information from providerMetadata", () => {
+			class TestHuggingFaceHandler extends HuggingFaceHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestHuggingFaceHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const providerMetadata = {
+				huggingface: {
+					promptCacheHitTokens: 20,
+					promptCacheMissTokens: 80,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage, providerMetadata)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheWriteTokens).toBe(80)
+			expect(result.cacheReadTokens).toBe(20)
+		})
+
+		it("should handle missing cache metrics gracefully", () => {
+			class TestHuggingFaceHandler extends HuggingFaceHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestHuggingFaceHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheWriteTokens).toBeUndefined()
+			expect(result.cacheReadTokens).toBeUndefined()
+		})
+
+		it("should include reasoning tokens when provided", () => {
+			class TestHuggingFaceHandler extends HuggingFaceHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestHuggingFaceHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.reasoningTokens).toBe(30)
+		})
+	})
+
+	describe("tool handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
+						},
+					},
+				],
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
+
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
+		})
+	})
+})
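As the comment at the top of the spec notes, `vi.mock()` factory calls are hoisted above imports, so any mock functions the factory closes over must themselves be created with `vi.hoisted()`. A minimal sketch of that pattern, condensed from the spec above (not part of the commit):

import { vi } from "vitest"

// vi.mock factories run before module-level consts are initialized, so the
// mock fn must come from vi.hoisted() or it would be undefined inside the factory.
const { mockStreamText } = vi.hoisted(() => ({ mockStreamText: vi.fn() }))

vi.mock("ai", async (importOriginal) => ({
	// Keep the rest of the real "ai" module intact, as the spec does.
	...(await importOriginal<typeof import("ai")>()),
	streamText: mockStreamText,
}))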

+ 153 - 79
src/api/providers/huggingface.ts

@@ -1,22 +1,37 @@
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
+import { streamText, generateText, ToolSet } from "ai"
 
-import type { ModelRecord } from "@roo-code/types"
+import type { ModelRecord, ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import { getHuggingFaceModels, getCachedHuggingFaceModels } from "./fetchers/huggingface"
-import { handleOpenAIError } from "./utils/openai-error-handler"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const HUGGINGFACE_DEFAULT_TEMPERATURE = 0.7
 
+/**
+ * HuggingFace provider using @ai-sdk/openai-compatible for OpenAI-compatible API.
+ * Uses HuggingFace's OpenAI-compatible endpoint to enable tool message support.
+ * @see https://github.com/vercel/ai/issues/10766 - Workaround for tool messages not supported in @ai-sdk/huggingface
+ */
 export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
-	private client: OpenAI
-	private options: ApiHandlerOptions
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createOpenAICompatible>
 	private modelCache: ModelRecord | null = null
-	private readonly providerName = "HuggingFace"
 
 	constructor(options: ApiHandlerOptions) {
 		super()
@@ -26,10 +41,14 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 			throw new Error("Hugging Face API key is required")
 		}
 
-		this.client = new OpenAI({
+		// Create an OpenAI-compatible provider pointing to HuggingFace's /v1 endpoint
+		// This fixes "tool messages not supported" error - the HuggingFace SDK doesn't
+		// properly handle function_call_output format, but OpenAI SDK does
+		this.provider = createOpenAICompatible({
+			name: "huggingface",
 			baseURL: "https://router.huggingface.co/v1",
 			apiKey: this.options.huggingFaceApiKey,
-			defaultHeaders: DEFAULT_HEADERS,
+			headers: DEFAULT_HEADERS,
 		})
 
 		// Try to get cached models first
@@ -47,91 +66,146 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 		}
 	}
 
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
-		const temperature = this.options.modelTemperature ?? 0.7
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-			model: modelId,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-		}
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
 
-		// Add max_tokens if specified
-		if (this.options.includeMaxTokens && this.options.modelMaxTokens) {
-			params.max_tokens = this.options.modelMaxTokens
-		}
+		// Try to get model info from cache
+		const cachedInfo = this.modelCache?.[id]
 
-		let stream
-		try {
-			stream = await this.client.chat.completions.create(params)
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
+		const info: ModelInfo = cachedInfo || {
+			maxTokens: 8192,
+			contextWindow: 131072,
+			supportsImages: false,
+			supportsPromptCache: false,
 		}
 
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: HUGGINGFACE_DEFAULT_TEMPERATURE,
+		})
 
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
+		return { id, info, ...params }
+	}
 
-			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			huggingface?: {
+				promptCacheHitTokens?: number
+				promptCacheMissTokens?: number
 			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from HuggingFace's providerMetadata if available
+		const cacheReadTokens = providerMetadata?.huggingface?.promptCacheHitTokens ?? usage.details?.cachedInputTokens
+		const cacheWriteTokens = providerMetadata?.huggingface?.promptCacheMissTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
-
-		try {
-			const response = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [{ role: "user", content: prompt }],
-			})
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
 
-			return response.choices[0]?.message.content || ""
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? HUGGINGFACE_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
 		}
-	}
 
-	override getModel() {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
 
-		// Try to get model info from cache
-		const modelInfo = this.modelCache?.[modelId]
+		try {
+			// Process the full stream to get all events
+			for await (const part of result.fullStream) {
+				// Use the processAiSdkStreamPart utility to convert stream parts
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
+			}
 
-		if (modelInfo) {
-			return {
-				id: modelId,
-				info: modelInfo,
+			// Yield usage metrics at the end, including cache metrics from providerMetadata
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, providerMetadata as any)
 			}
+		} catch (error) {
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, "HuggingFace")
 		}
+	}
 
-		// Fallback to default values if model not found in cache
-		return {
-			id: modelId,
-			info: {
-				maxTokens: 8192,
-				contextWindow: 131072,
-				supportsImages: false,
-				supportsPromptCache: false,
-			},
-		}
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? HUGGINGFACE_DEFAULT_TEMPERATURE,
+		})
+
+		return text
 	}
 }
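Downstream behavior is unchanged: callers still consume the `ApiStream` yielded by `createMessage`. A rough consumption sketch based on the chunk shapes exercised in the spec above; the import path and option values are assumptions, not part of the commit.

// Import path assumed relative to the repo root.
import { HuggingFaceHandler } from "./src/api/providers/huggingface"

async function run() {
	const handler = new HuggingFaceHandler({
		huggingFaceApiKey: process.env.HUGGINGFACE_API_KEY ?? "",
		huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
	})

	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			// cacheReadTokens / cacheWriteTokens come from providerMetadata or
			// usage.details.cachedInputTokens, as covered by the tests above.
			console.log("\nusage:", chunk.inputTokens, chunk.outputTokens)
		}
	}
}

run().catch(console.error)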