Dax Raad 9 months ago
parent
commit
ff786d9139
4 changed files with 275 additions and 5 deletions
  1. 39 0
      js/src/server/server.ts
  2. 23 1
      js/src/session/session.ts
  3. 52 2
      pkg/client/gen/openapi.json
  4. 161 2
      pkg/client/generated-client.go

+ 39 - 0
js/src/server/server.ts

@@ -158,8 +158,47 @@ export namespace Server {
           return c.json(sessions);
           return c.json(sessions);
         },
         },
       )
       )
+      .post(
+        "/session_abort",
+        describeRoute({
+          description: "Abort a session",
+          responses: {
+            200: {
+              description: "Aborted session",
+              content: {
+                "application/json": {
+                  schema: resolver(z.boolean()),
+                },
+              },
+            },
+          },
+        }),
+        zValidator(
+          "json",
+          z.object({
+            sessionID: z.string(),
+          }),
+        ),
+        async (c) => {
+          const body = c.req.valid("json");
+          return c.json(Session.abort(body.sessionID));
+        },
+      )
       .post(
       .post(
         "/session_chat",
         "/session_chat",
+        describeRoute({
+          description: "Chat with a model",
+          responses: {
+            200: {
+              description: "Chat with a model",
+              content: {
+                "application/json": {
+                  schema: resolver(SessionMessage),
+                },
+              },
+            },
+          },
+        }),
         zValidator(
         zValidator(
           "json",
           "json",
           z.object({
           z.object({

+ 23 - 1
js/src/session/session.ts

@@ -129,6 +129,16 @@ export namespace Session {
     }
     }
   }
   }
 
 
+  const pending = new Map<string, AbortController>();
+
+  export function abort(sessionID: string) {
+    const controller = pending.get(sessionID);
+    if (!controller) return false;
+    controller.abort();
+    pending.delete(sessionID);
+    return true;
+  }
+
   export async function chat(input: {
   export async function chat(input: {
     sessionID: string;
     sessionID: string;
     providerID: string;
     providerID: string;
@@ -225,6 +235,8 @@ export namespace Session {
         tool: {},
         tool: {},
       },
       },
     };
     };
+    const controller = new AbortController();
+    pending.set(input.sessionID, controller);
     const result = streamText({
     const result = streamText({
       onStepFinish: (step) => {
       onStepFinish: (step) => {
         update(input.sessionID, (draft) => {
         update(input.sessionID, (draft) => {
@@ -240,6 +252,8 @@ export namespace Session {
             .toNumber();
             .toNumber();
         });
         });
       },
       },
+      abortSignal: controller.signal,
+      maxRetries: 6,
       stopWhen: stepCountIs(1000),
       stopWhen: stepCountIs(1000),
       messages: convertToModelMessages(msgs),
       messages: convertToModelMessages(msgs),
       temperature: 0,
       temperature: 0,
@@ -251,7 +265,14 @@ export namespace Session {
     let text: TextUIPart | undefined;
     let text: TextUIPart | undefined;
     const reader = result.toUIMessageStream().getReader();
     const reader = result.toUIMessageStream().getReader();
     while (true) {
     while (true) {
-      const { done, value } = await reader.read();
+      const result = await reader.read().catch((e) => {
+        if (e instanceof DOMException && e.name === "AbortError") {
+          return;
+        }
+        throw e;
+      });
+      if (!result) break;
+      const { done, value } = result;
       if (done) break;
       if (done) break;
       l.info("part", {
       l.info("part", {
         type: value.type,
         type: value.type,
@@ -316,6 +337,7 @@ export namespace Session {
       }
       }
       await write(next);
       await write(next);
     }
     }
+    pending.delete(input.sessionID);
     next.metadata!.time.completed = Date.now();
     next.metadata!.time.completed = Date.now();
     await write(next);
     await write(next);
     return next;
     return next;

+ 52 - 2
pkg/client/gen/openapi.json

@@ -161,11 +161,59 @@
         "description": "List all sessions"
         "description": "List all sessions"
       }
       }
     },
     },
+    "/session_abort": {
+      "post": {
+        "responses": {
+          "200": {
+            "description": "Aborted session",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "boolean"
+                }
+              }
+            }
+          }
+        },
+        "operationId": "postSession_abort",
+        "parameters": [],
+        "description": "Abort a session",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "type": "object",
+                "properties": {
+                  "sessionID": {
+                    "type": "string"
+                  }
+                },
+                "required": [
+                  "sessionID"
+                ]
+              }
+            }
+          }
+        }
+      }
+    },
     "/session_chat": {
     "/session_chat": {
       "post": {
       "post": {
-        "responses": {},
+        "responses": {
+          "200": {
+            "description": "Chat with a model",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/Session.Message"
+                }
+              }
+            }
+          }
+        },
         "operationId": "postSession_chat",
         "operationId": "postSession_chat",
         "parameters": [],
         "parameters": [],
+        "description": "Chat with a model",
         "requestBody": {
         "requestBody": {
           "content": {
           "content": {
             "application/json": {
             "application/json": {
@@ -628,12 +676,14 @@
                 },
                 },
                 "attachment": {
                 "attachment": {
                   "type": "boolean"
                   "type": "boolean"
+                },
+                "reasoning": {
+                  "type": "boolean"
                 }
                 }
               },
               },
               "required": [
               "required": [
                 "cost",
                 "cost",
                 "contextWindow",
                 "contextWindow",
-                "maxTokens",
                 "attachment"
                 "attachment"
               ]
               ]
             }
             }

+ 161 - 2
pkg/client/generated-client.go

@@ -35,8 +35,9 @@ type ProviderInfo struct {
 			Output       float32 `json:"output"`
 			Output       float32 `json:"output"`
 			OutputCached float32 `json:"outputCached"`
 			OutputCached float32 `json:"outputCached"`
 		} `json:"cost"`
 		} `json:"cost"`
-		MaxTokens float32 `json:"maxTokens"`
-		Name      *string `json:"name,omitempty"`
+		MaxTokens *float32 `json:"maxTokens,omitempty"`
+		Name      *string  `json:"name,omitempty"`
+		Reasoning *bool    `json:"reasoning,omitempty"`
 	} `json:"models"`
 	} `json:"models"`
 	Options *map[string]interface{} `json:"options,omitempty"`
 	Options *map[string]interface{} `json:"options,omitempty"`
 }
 }
@@ -151,6 +152,11 @@ type SessionMessageToolInvocationToolResult struct {
 	ToolName   string                 `json:"toolName"`
 	ToolName   string                 `json:"toolName"`
 }
 }
 
 
+// PostSessionAbortJSONBody defines parameters for PostSessionAbort.
+type PostSessionAbortJSONBody struct {
+	SessionID string `json:"sessionID"`
+}
+
 // PostSessionChatJSONBody defines parameters for PostSessionChat.
 // PostSessionChatJSONBody defines parameters for PostSessionChat.
 type PostSessionChatJSONBody struct {
 type PostSessionChatJSONBody struct {
 	ModelID    string               `json:"modelID"`
 	ModelID    string               `json:"modelID"`
@@ -169,6 +175,9 @@ type PostSessionShareJSONBody struct {
 	SessionID string `json:"sessionID"`
 	SessionID string `json:"sessionID"`
 }
 }
 
 
+// PostSessionAbortJSONRequestBody defines body for PostSessionAbort for application/json ContentType.
+type PostSessionAbortJSONRequestBody PostSessionAbortJSONBody
+
 // PostSessionChatJSONRequestBody defines body for PostSessionChat for application/json ContentType.
 // PostSessionChatJSONRequestBody defines body for PostSessionChat for application/json ContentType.
 type PostSessionChatJSONRequestBody PostSessionChatJSONBody
 type PostSessionChatJSONRequestBody PostSessionChatJSONBody
 
 
@@ -582,6 +591,11 @@ type ClientInterface interface {
 	// PostProviderList request
 	// PostProviderList request
 	PostProviderList(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error)
 	PostProviderList(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error)
 
 
+	// PostSessionAbortWithBody request with any body
+	PostSessionAbortWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
+
+	PostSessionAbort(ctx context.Context, body PostSessionAbortJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error)
+
 	// PostSessionChatWithBody request with any body
 	// PostSessionChatWithBody request with any body
 	PostSessionChatWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
 	PostSessionChatWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
 
 
@@ -616,6 +630,30 @@ func (c *Client) PostProviderList(ctx context.Context, reqEditors ...RequestEdit
 	return c.Client.Do(req)
 	return c.Client.Do(req)
 }
 }
 
 
+func (c *Client) PostSessionAbortWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
+	req, err := NewPostSessionAbortRequestWithBody(c.Server, contentType, body)
+	if err != nil {
+		return nil, err
+	}
+	req = req.WithContext(ctx)
+	if err := c.applyEditors(ctx, req, reqEditors); err != nil {
+		return nil, err
+	}
+	return c.Client.Do(req)
+}
+
+func (c *Client) PostSessionAbort(ctx context.Context, body PostSessionAbortJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) {
+	req, err := NewPostSessionAbortRequest(c.Server, body)
+	if err != nil {
+		return nil, err
+	}
+	req = req.WithContext(ctx)
+	if err := c.applyEditors(ctx, req, reqEditors); err != nil {
+		return nil, err
+	}
+	return c.Client.Do(req)
+}
+
 func (c *Client) PostSessionChatWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
 func (c *Client) PostSessionChatWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
 	req, err := NewPostSessionChatRequestWithBody(c.Server, contentType, body)
 	req, err := NewPostSessionChatRequestWithBody(c.Server, contentType, body)
 	if err != nil {
 	if err != nil {
@@ -739,6 +777,46 @@ func NewPostProviderListRequest(server string) (*http.Request, error) {
 	return req, nil
 	return req, nil
 }
 }
 
 
+// NewPostSessionAbortRequest calls the generic PostSessionAbort builder with application/json body
+func NewPostSessionAbortRequest(server string, body PostSessionAbortJSONRequestBody) (*http.Request, error) {
+	var bodyReader io.Reader
+	buf, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	bodyReader = bytes.NewReader(buf)
+	return NewPostSessionAbortRequestWithBody(server, "application/json", bodyReader)
+}
+
+// NewPostSessionAbortRequestWithBody generates requests for PostSessionAbort with any type of body
+func NewPostSessionAbortRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) {
+	var err error
+
+	serverURL, err := url.Parse(server)
+	if err != nil {
+		return nil, err
+	}
+
+	operationPath := fmt.Sprintf("/session_abort")
+	if operationPath[0] == '/' {
+		operationPath = "." + operationPath
+	}
+
+	queryURL, err := serverURL.Parse(operationPath)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", queryURL.String(), body)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Add("Content-Type", contentType)
+
+	return req, nil
+}
+
 // NewPostSessionChatRequest calls the generic PostSessionChat builder with application/json body
 // NewPostSessionChatRequest calls the generic PostSessionChat builder with application/json body
 func NewPostSessionChatRequest(server string, body PostSessionChatJSONRequestBody) (*http.Request, error) {
 func NewPostSessionChatRequest(server string, body PostSessionChatJSONRequestBody) (*http.Request, error) {
 	var bodyReader io.Reader
 	var bodyReader io.Reader
@@ -959,6 +1037,11 @@ type ClientWithResponsesInterface interface {
 	// PostProviderListWithResponse request
 	// PostProviderListWithResponse request
 	PostProviderListWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*PostProviderListResponse, error)
 	PostProviderListWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*PostProviderListResponse, error)
 
 
+	// PostSessionAbortWithBodyWithResponse request with any body
+	PostSessionAbortWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionAbortResponse, error)
+
+	PostSessionAbortWithResponse(ctx context.Context, body PostSessionAbortJSONRequestBody, reqEditors ...RequestEditorFn) (*PostSessionAbortResponse, error)
+
 	// PostSessionChatWithBodyWithResponse request with any body
 	// PostSessionChatWithBodyWithResponse request with any body
 	PostSessionChatWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionChatResponse, error)
 	PostSessionChatWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionChatResponse, error)
 
 
@@ -1003,9 +1086,32 @@ func (r PostProviderListResponse) StatusCode() int {
 	return 0
 	return 0
 }
 }
 
 
+type PostSessionAbortResponse struct {
+	Body         []byte
+	HTTPResponse *http.Response
+	JSON200      *bool
+}
+
+// Status returns HTTPResponse.Status
+func (r PostSessionAbortResponse) Status() string {
+	if r.HTTPResponse != nil {
+		return r.HTTPResponse.Status
+	}
+	return http.StatusText(0)
+}
+
+// StatusCode returns HTTPResponse.StatusCode
+func (r PostSessionAbortResponse) StatusCode() int {
+	if r.HTTPResponse != nil {
+		return r.HTTPResponse.StatusCode
+	}
+	return 0
+}
+
 type PostSessionChatResponse struct {
 type PostSessionChatResponse struct {
 	Body         []byte
 	Body         []byte
 	HTTPResponse *http.Response
 	HTTPResponse *http.Response
+	JSON200      *SessionMessage
 }
 }
 
 
 // Status returns HTTPResponse.Status
 // Status returns HTTPResponse.Status
@@ -1131,6 +1237,23 @@ func (c *ClientWithResponses) PostProviderListWithResponse(ctx context.Context,
 	return ParsePostProviderListResponse(rsp)
 	return ParsePostProviderListResponse(rsp)
 }
 }
 
 
+// PostSessionAbortWithBodyWithResponse request with arbitrary body returning *PostSessionAbortResponse
+func (c *ClientWithResponses) PostSessionAbortWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionAbortResponse, error) {
+	rsp, err := c.PostSessionAbortWithBody(ctx, contentType, body, reqEditors...)
+	if err != nil {
+		return nil, err
+	}
+	return ParsePostSessionAbortResponse(rsp)
+}
+
+func (c *ClientWithResponses) PostSessionAbortWithResponse(ctx context.Context, body PostSessionAbortJSONRequestBody, reqEditors ...RequestEditorFn) (*PostSessionAbortResponse, error) {
+	rsp, err := c.PostSessionAbort(ctx, body, reqEditors...)
+	if err != nil {
+		return nil, err
+	}
+	return ParsePostSessionAbortResponse(rsp)
+}
+
 // PostSessionChatWithBodyWithResponse request with arbitrary body returning *PostSessionChatResponse
 // PostSessionChatWithBodyWithResponse request with arbitrary body returning *PostSessionChatResponse
 func (c *ClientWithResponses) PostSessionChatWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionChatResponse, error) {
 func (c *ClientWithResponses) PostSessionChatWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionChatResponse, error) {
 	rsp, err := c.PostSessionChatWithBody(ctx, contentType, body, reqEditors...)
 	rsp, err := c.PostSessionChatWithBody(ctx, contentType, body, reqEditors...)
@@ -1226,6 +1349,32 @@ func ParsePostProviderListResponse(rsp *http.Response) (*PostProviderListRespons
 	return response, nil
 	return response, nil
 }
 }
 
 
+// ParsePostSessionAbortResponse parses an HTTP response from a PostSessionAbortWithResponse call
+func ParsePostSessionAbortResponse(rsp *http.Response) (*PostSessionAbortResponse, error) {
+	bodyBytes, err := io.ReadAll(rsp.Body)
+	defer func() { _ = rsp.Body.Close() }()
+	if err != nil {
+		return nil, err
+	}
+
+	response := &PostSessionAbortResponse{
+		Body:         bodyBytes,
+		HTTPResponse: rsp,
+	}
+
+	switch {
+	case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
+		var dest bool
+		if err := json.Unmarshal(bodyBytes, &dest); err != nil {
+			return nil, err
+		}
+		response.JSON200 = &dest
+
+	}
+
+	return response, nil
+}
+
 // ParsePostSessionChatResponse parses an HTTP response from a PostSessionChatWithResponse call
 // ParsePostSessionChatResponse parses an HTTP response from a PostSessionChatWithResponse call
 func ParsePostSessionChatResponse(rsp *http.Response) (*PostSessionChatResponse, error) {
 func ParsePostSessionChatResponse(rsp *http.Response) (*PostSessionChatResponse, error) {
 	bodyBytes, err := io.ReadAll(rsp.Body)
 	bodyBytes, err := io.ReadAll(rsp.Body)
@@ -1239,6 +1388,16 @@ func ParsePostSessionChatResponse(rsp *http.Response) (*PostSessionChatResponse,
 		HTTPResponse: rsp,
 		HTTPResponse: rsp,
 	}
 	}
 
 
+	switch {
+	case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
+		var dest SessionMessage
+		if err := json.Unmarshal(bodyBytes, &dest); err != nil {
+			return nil, err
+		}
+		response.JSON200 = &dest
+
+	}
+
 	return response, nil
 	return response, nil
 }
 }