Просмотр исходного кода

Add pass / fail events for evals (#2656)

Chris Estreich 8 месяцев назад
Родитель
Commit
a4d2de4534

+ 24 - 21
evals/apps/cli/src/index.ts

@@ -16,6 +16,7 @@ import {
 	IpcMessageType,
 	TaskCommandName,
 	rooCodeDefaults,
+	EvalEventName,
 } from "@evals/types"
 import {
 	type Run,
@@ -34,7 +35,7 @@ import { IpcServer, IpcClient } from "@evals/ipc"
 import { __dirname, extensionDevelopmentPath, exercisesPath } from "./paths.js"
 import { getExercises } from "./exercises.js"
 
-type TaskResult = { success: boolean; retry: boolean }
+type TaskResult = { success: boolean }
 type TaskPromise = Promise<TaskResult>
 
 const TASK_START_DELAY = 10 * 1_000
@@ -116,24 +117,25 @@ const run = async (toolbox: GluegunToolbox) => {
 
 	const runningPromises: TaskPromise[] = []
 
-	// Retries aren't implemented yet, but the return values are set up to
-	// support them.
 	const processTask = async (task: Task, delay = 0) => {
 		if (task.finishedAt === null) {
 			await new Promise((resolve) => setTimeout(resolve, delay))
-			const { retry } = await runExercise({ run, task, server })
-
-			if (retry) {
-				return { success: false, retry: true }
-			}
+			await runExercise({ run, task, server })
 		}
 
 		if (task.passed === null) {
 			const passed = await runUnitTest({ task })
 			await updateTask(task.id, { passed })
-			return { success: passed, retry: false }
+
+			server.broadcast({
+				type: IpcMessageType.TaskEvent,
+				origin: IpcOrigin.Server,
+				data: { eventName: passed ? EvalEventName.Pass : EvalEventName.Fail, taskId: task.id },
+			})
+
+			return { success: passed }
 		} else {
-			return { success: task.passed, retry: false }
+			return { success: task.passed }
 		}
 	}
 
@@ -200,7 +202,7 @@ const runExercise = async ({ run, task, server }: { run: Run; task: Task; server
 	} catch (error) {
 		console.log(`${Date.now()} [cli#runExercise | ${language} / ${exercise}] unable to connect`)
 		client.disconnect()
-		return { success: false, retry: false }
+		return { success: false }
 	}
 
 	let taskStartedAt = Date.now()
@@ -209,16 +211,15 @@ const runExercise = async ({ run, task, server }: { run: Run; task: Task; server
 	let rooTaskId: string | undefined
 	let isClientDisconnected = false
 
-	const ignoreEvents: RooCodeEventName[] = [
-		RooCodeEventName.Message,
-		RooCodeEventName.TaskTokenUsageUpdated,
-		RooCodeEventName.TaskAskResponded,
-	]
+	const ignoreEvents: Record<"broadcast" | "log", (RooCodeEventName | EvalEventName)[]> = {
+		broadcast: [RooCodeEventName.Message],
+		log: [RooCodeEventName.Message, RooCodeEventName.TaskTokenUsageUpdated, RooCodeEventName.TaskAskResponded],
+	}
 
 	client.on(IpcMessageType.TaskEvent, async (taskEvent) => {
 		const { eventName, payload } = taskEvent
 
-		if (taskEvent.eventName !== RooCodeEventName.Message) {
+		if (!ignoreEvents.broadcast.includes(eventName)) {
 			server.broadcast({
 				type: IpcMessageType.TaskEvent,
 				origin: IpcOrigin.Server,
@@ -227,7 +228,7 @@ const runExercise = async ({ run, task, server }: { run: Run; task: Task; server
 			})
 		}
 
-		if (!ignoreEvents.includes(eventName)) {
+		if (!ignoreEvents.log.includes(eventName)) {
 			console.log(
 				`${Date.now()} [cli#runExercise | ${language} / ${exercise}] taskEvent -> ${eventName}`,
 				payload,
@@ -320,11 +321,10 @@ const runExercise = async ({ run, task, server }: { run: Run; task: Task; server
 				data: { commandName: TaskCommandName.CancelTask, data: rooTaskId },
 			})
 
-			// Give the server some time to cancel the task.
+			// Allow some time for the task to cancel.
 			await new Promise((resolve) => setTimeout(resolve, 5_000))
 		}
 
-		// TODO: Notify clients that the task timed out.
 		await updateTask(task.id, { finishedAt: new Date() })
 	}
 
@@ -336,12 +336,15 @@ const runExercise = async ({ run, task, server }: { run: Run; task: Task; server
 				clientId: client.clientId!,
 				data: { commandName: TaskCommandName.CloseTask, data: rooTaskId },
 			})
+
+			// Allow some time for the window to close.
+			await new Promise((resolve) => setTimeout(resolve, 2_000))
 		}
 
 		client.disconnect()
 	}
 
-	return { success: !!taskFinishedAt, retry: false }
+	return { success: !!taskFinishedAt }
 }
 
 const runUnitTest = async ({ task }: { task: Task }) => {

+ 5 - 5
evals/apps/web/src/hooks/use-run-status.ts

@@ -1,7 +1,7 @@
 import { useState, useCallback, useRef } from "react"
 import { useQuery, keepPreviousData } from "@tanstack/react-query"
 
-import { RooCodeEventName, taskEventSchema, TokenUsage } from "@evals/types"
+import { TokenUsage, taskEventSchema, RooCodeEventName, EvalEventName } from "@evals/types"
 import { Run } from "@evals/db"
 
 import { getTasks } from "@/lib/server/tasks"
@@ -51,10 +51,6 @@ export const useRunStatus = (run: Run) => {
 			case RooCodeEventName.TaskStarted:
 				startTimes.current.set(taskId, Date.now())
 				break
-			case RooCodeEventName.TaskCompleted:
-			case RooCodeEventName.TaskAborted:
-				setTasksUpdatedAt(Date.now())
-				break
 			case RooCodeEventName.TaskTokenUsageUpdated: {
 				const startTime = startTimes.current.get(taskId)
 				const duration = startTime ? Date.now() - startTime : undefined
@@ -62,6 +58,10 @@ export const useRunStatus = (run: Run) => {
 				setUsageUpdatedAt(Date.now())
 				break
 			}
+			case EvalEventName.Pass:
+			case EvalEventName.Fail:
+				setTasksUpdatedAt(Date.now())
+				break
 		}
 	}, [])
 

+ 16 - 5
evals/packages/types/src/ipc.ts

@@ -50,12 +50,12 @@ export type TaskCommand = z.infer<typeof taskCommandSchema>
  * TaskEvent
  */
 
+export enum EvalEventName {
+	Pass = "pass",
+	Fail = "fail",
+}
+
 export const taskEventSchema = z.discriminatedUnion("eventName", [
-	z.object({
-		eventName: z.literal(RooCodeEventName.Connect),
-		payload: z.unknown(),
-		taskId: z.number(),
-	}),
 	z.object({
 		eventName: z.literal(RooCodeEventName.Message),
 		payload: rooCodeEventsSchema.shape[RooCodeEventName.Message],
@@ -111,6 +111,16 @@ export const taskEventSchema = z.discriminatedUnion("eventName", [
 		payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskTokenUsageUpdated],
 		taskId: z.number().optional(),
 	}),
+	z.object({
+		eventName: z.literal(EvalEventName.Pass),
+		payload: z.undefined(),
+		taskId: z.number(),
+	}),
+	z.object({
+		eventName: z.literal(EvalEventName.Fail),
+		payload: z.undefined(),
+		taskId: z.number(),
+	}),
 ])
 
 export type TaskEvent = z.infer<typeof taskEventSchema>
@@ -125,6 +135,7 @@ export enum IpcMessageType {
 	Ack = "Ack",
 	TaskCommand = "TaskCommand",
 	TaskEvent = "TaskEvent",
+	EvalEvent = "EvalEvent",
 }
 
 export enum IpcOrigin {