소스 검색

Rate limit when starting a subtask (#4453)

Co-authored-by: Daniel Riccio <[email protected]>
Co-authored-by: Matt Rubens <[email protected]>
Olwer Altuve 6 달 전
부모
커밋
4b2b05f262
3개의 변경된 파일410개의 추가작업 그리고 35개의 파일을 삭제
  1. 2 1
      .vscode/settings.json
  2. 16 6
      src/core/task/Task.ts
  3. 392 28
      src/core/task/__tests__/Task.spec.ts

+ 2 - 1
.vscode/settings.json

@@ -9,5 +9,6 @@
 		"dist": true // set this to false to include "dist" folder in search results
 	},
 	// Turn off tsc task auto detection since we have the necessary tasks as npm scripts
-	"typescript.tsc.autoDetect": "off"
+	"typescript.tsc.autoDetect": "off",
+	"vitest.disableWorkspaceWarning": true
 }

+ 16 - 6
src/core/task/Task.ts

@@ -140,9 +140,17 @@ export class Task extends EventEmitter<ClineEvents> {
 	// API
 	readonly apiConfiguration: ProviderSettings
 	api: ApiHandler
-	private lastApiRequestTime?: number
+	private static lastGlobalApiRequestTime?: number
 	private consecutiveAutoApprovedRequestsCount: number = 0
 
+	/**
+	 * Reset the global API request timestamp. This should only be used for testing.
+	 * @internal
+	 */
+	static resetGlobalApiRequestTime(): void {
+		Task.lastGlobalApiRequestTime = undefined
+	}
+
 	toolRepetitionDetector: ToolRepetitionDetector
 	rooIgnoreController?: RooIgnoreController
 	rooProtectedController?: RooProtectedController
@@ -1657,10 +1665,11 @@ export class Task extends EventEmitter<ClineEvents> {
 
 		let rateLimitDelay = 0
 
-		// Only apply rate limiting if this isn't the first request
-		if (this.lastApiRequestTime) {
+		// Use the shared timestamp so that subtasks respect the same rate-limit
+		// window as their parent tasks.
+		if (Task.lastGlobalApiRequestTime) {
 			const now = Date.now()
-			const timeSinceLastRequest = now - this.lastApiRequestTime
+			const timeSinceLastRequest = now - Task.lastGlobalApiRequestTime
 			const rateLimit = apiConfiguration?.rateLimitSeconds || 0
 			rateLimitDelay = Math.ceil(Math.max(0, rateLimit * 1000 - timeSinceLastRequest) / 1000)
 		}
@@ -1675,8 +1684,9 @@ export class Task extends EventEmitter<ClineEvents> {
 			}
 		}
 
-		// Update last request time before making the request
-		this.lastApiRequestTime = Date.now()
+		// Update last request time before making the request so that subsequent
+		// requests — even from new subtasks — will honour the provider's rate-limit.
+		Task.lastGlobalApiRequestTime = Date.now()
 
 		const systemPrompt = await this.getSystemPrompt()
 		const { contextTokens } = this.getTokenUsage()

+ 392 - 28
src/core/task/__tests__/Task.spec.ts

@@ -18,38 +18,55 @@ import { MultiSearchReplaceDiffStrategy } from "../../diff/strategies/multi-sear
 import { MultiFileSearchReplaceDiffStrategy } from "../../diff/strategies/multi-file-search-replace"
 import { EXPERIMENT_IDS } from "../../../shared/experiments"
 
+// Mock delay before any imports that might use it
+vi.mock("delay", () => ({
+	__esModule: true,
+	default: vi.fn().mockResolvedValue(undefined),
+}))
+
+import delay from "delay"
+
 vi.mock("execa", () => ({
 	execa: vi.fn(),
 }))
 
-vi.mock("fs/promises", () => ({
-	mkdir: vi.fn().mockResolvedValue(undefined),
-	writeFile: vi.fn().mockResolvedValue(undefined),
-	readFile: vi.fn().mockImplementation((filePath) => {
-		if (filePath.includes("ui_messages.json")) {
-			return Promise.resolve(JSON.stringify(mockMessages))
-		}
-		if (filePath.includes("api_conversation_history.json")) {
-			return Promise.resolve(
-				JSON.stringify([
-					{
-						role: "user",
-						content: [{ type: "text", text: "historical task" }],
-						ts: Date.now(),
-					},
-					{
-						role: "assistant",
-						content: [{ type: "text", text: "I'll help you with that task." }],
-						ts: Date.now(),
-					},
-				]),
-			)
-		}
-		return Promise.resolve("[]")
-	}),
-	unlink: vi.fn().mockResolvedValue(undefined),
-	rmdir: vi.fn().mockResolvedValue(undefined),
-}))
+vi.mock("fs/promises", async (importOriginal) => {
+	const actual = (await importOriginal()) as Record<string, any>
+	const mockFunctions = {
+		mkdir: vi.fn().mockResolvedValue(undefined),
+		writeFile: vi.fn().mockResolvedValue(undefined),
+		readFile: vi.fn().mockImplementation((filePath) => {
+			if (filePath.includes("ui_messages.json")) {
+				return Promise.resolve(JSON.stringify(mockMessages))
+			}
+			if (filePath.includes("api_conversation_history.json")) {
+				return Promise.resolve(
+					JSON.stringify([
+						{
+							role: "user",
+							content: [{ type: "text", text: "historical task" }],
+							ts: Date.now(),
+						},
+						{
+							role: "assistant",
+							content: [{ type: "text", text: "I'll help you with that task." }],
+							ts: Date.now(),
+						},
+					]),
+				)
+			}
+			return Promise.resolve("[]")
+		}),
+		unlink: vi.fn().mockResolvedValue(undefined),
+		rmdir: vi.fn().mockResolvedValue(undefined),
+	}
+
+	return {
+		...actual,
+		...mockFunctions,
+		default: mockFunctions,
+	}
+})
 
 vi.mock("p-wait-for", () => ({
 	default: vi.fn().mockImplementation(async () => Promise.resolve()),
@@ -869,6 +886,353 @@ describe("Cline", () => {
 			})
 		})
 
+		describe("Subtask Rate Limiting", () => {
+			let mockProvider: any
+			let mockApiConfig: any
+			let mockDelay: ReturnType<typeof vi.fn>
+
+			beforeEach(() => {
+				vi.clearAllMocks()
+				// Reset the global timestamp before each test
+				Task.resetGlobalApiRequestTime()
+
+				mockApiConfig = {
+					apiProvider: "anthropic",
+					apiKey: "test-key",
+					rateLimitSeconds: 5,
+				}
+
+				mockProvider = {
+					context: {
+						globalStorageUri: { fsPath: "/test/storage" },
+					},
+					getState: vi.fn().mockResolvedValue({
+						apiConfiguration: mockApiConfig,
+					}),
+					say: vi.fn(),
+					postStateToWebview: vi.fn().mockResolvedValue(undefined),
+					postMessageToWebview: vi.fn().mockResolvedValue(undefined),
+					updateTaskHistory: vi.fn().mockResolvedValue(undefined),
+				}
+
+				// Get the mocked delay function
+				mockDelay = delay as ReturnType<typeof vi.fn>
+				mockDelay.mockClear()
+			})
+
+			afterEach(() => {
+				// Clean up the global state after each test
+				Task.resetGlobalApiRequestTime()
+			})
+
+			it("should enforce rate limiting across parent and subtask", async () => {
+				// Add a spy to track getState calls
+				const getStateSpy = vi.spyOn(mockProvider, "getState")
+
+				// Create parent task
+				const parent = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "parent task",
+					startTask: false,
+				})
+
+				// Mock the API stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "parent response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "parent response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(parent.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the parent task
+				const parentIterator = parent.attemptApiRequest(0)
+				await parentIterator.next()
+
+				// Verify no delay was applied for the first request
+				expect(mockDelay).not.toHaveBeenCalled()
+
+				// Create a subtask immediately after
+				const child = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "child task",
+					parentTask: parent,
+					rootTask: parent,
+					startTask: false,
+				})
+
+				// Mock the child's API stream
+				const childMockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "child response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "child response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(child.api, "createMessage").mockReturnValue(childMockStream)
+
+				// Make an API request with the child task
+				const childIterator = child.attemptApiRequest(0)
+				await childIterator.next()
+
+				// Verify rate limiting was applied
+				expect(mockDelay).toHaveBeenCalledTimes(mockApiConfig.rateLimitSeconds)
+				expect(mockDelay).toHaveBeenCalledWith(1000)
+			}, 10000) // Increase timeout to 10 seconds
+
+			it("should not apply rate limiting if enough time has passed", async () => {
+				// Create parent task
+				const parent = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "parent task",
+					startTask: false,
+				})
+
+				// Mock the API stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(parent.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the parent task
+				const parentIterator = parent.attemptApiRequest(0)
+				await parentIterator.next()
+
+				// Simulate time passing (more than rate limit)
+				const originalDateNow = Date.now
+				const mockTime = Date.now() + (mockApiConfig.rateLimitSeconds + 1) * 1000
+				Date.now = vi.fn(() => mockTime)
+
+				// Create a subtask after time has passed
+				const child = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "child task",
+					parentTask: parent,
+					rootTask: parent,
+					startTask: false,
+				})
+
+				vi.spyOn(child.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the child task
+				const childIterator = child.attemptApiRequest(0)
+				await childIterator.next()
+
+				// Verify no rate limiting was applied
+				expect(mockDelay).not.toHaveBeenCalled()
+
+				// Restore Date.now
+				Date.now = originalDateNow
+			})
+
+			it("should share rate limiting across multiple subtasks", async () => {
+				// Create parent task
+				const parent = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "parent task",
+					startTask: false,
+				})
+
+				// Mock the API stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(parent.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the parent task
+				const parentIterator = parent.attemptApiRequest(0)
+				await parentIterator.next()
+
+				// Create first subtask
+				const child1 = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "child task 1",
+					parentTask: parent,
+					rootTask: parent,
+					startTask: false,
+				})
+
+				vi.spyOn(child1.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the first child task
+				const child1Iterator = child1.attemptApiRequest(0)
+				await child1Iterator.next()
+
+				// Verify rate limiting was applied
+				const firstDelayCount = mockDelay.mock.calls.length
+				expect(firstDelayCount).toBe(mockApiConfig.rateLimitSeconds)
+
+				// Clear the mock to count new delays
+				mockDelay.mockClear()
+
+				// Create second subtask immediately after
+				const child2 = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "child task 2",
+					parentTask: parent,
+					rootTask: parent,
+					startTask: false,
+				})
+
+				vi.spyOn(child2.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the second child task
+				const child2Iterator = child2.attemptApiRequest(0)
+				await child2Iterator.next()
+
+				// Verify rate limiting was applied again
+				expect(mockDelay).toHaveBeenCalledTimes(mockApiConfig.rateLimitSeconds)
+			}, 15000) // Increase timeout to 15 seconds
+
+			it("should handle rate limiting with zero rate limit", async () => {
+				// Update config to have zero rate limit
+				mockApiConfig.rateLimitSeconds = 0
+				mockProvider.getState.mockResolvedValue({
+					apiConfiguration: mockApiConfig,
+				})
+
+				// Create parent task
+				const parent = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "parent task",
+					startTask: false,
+				})
+
+				// Mock the API stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(parent.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the parent task
+				const parentIterator = parent.attemptApiRequest(0)
+				await parentIterator.next()
+
+				// Create a subtask
+				const child = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "child task",
+					parentTask: parent,
+					rootTask: parent,
+					startTask: false,
+				})
+
+				vi.spyOn(child.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request with the child task
+				const childIterator = child.attemptApiRequest(0)
+				await childIterator.next()
+
+				// Verify no delay was applied
+				expect(mockDelay).not.toHaveBeenCalled()
+			})
+
+			it("should update global timestamp even when no rate limiting is needed", async () => {
+				// Create task
+				const task = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "test task",
+					startTask: false,
+				})
+
+				// Mock the API stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield { type: "text", text: "response" }
+					},
+					async next() {
+						return { done: true, value: { type: "text", text: "response" } }
+					},
+					async return() {
+						return { done: true, value: undefined }
+					},
+					async throw(e: any) {
+						throw e
+					},
+					[Symbol.asyncDispose]: async () => {},
+				} as AsyncGenerator<ApiStreamChunk>
+
+				vi.spyOn(task.api, "createMessage").mockReturnValue(mockStream)
+
+				// Make an API request
+				const iterator = task.attemptApiRequest(0)
+				await iterator.next()
+
+				// Access the private static property via reflection for testing
+				const globalTimestamp = (Task as any).lastGlobalApiRequestTime
+				expect(globalTimestamp).toBeDefined()
+				expect(globalTimestamp).toBeGreaterThan(0)
+			})
+		})
+
 		describe("Dynamic Strategy Selection", () => {
 			let mockProvider: any
 			let mockApiConfig: any