Browse Source

better discriminator

Dax Raad 9 months ago
parent
commit
f3b224090c
3 changed files with 94 additions and 57 deletions
  1. 2 13
      js/src/server/message.ts
  2. 23 17
      pkg/client/gen/openapi.json
  3. 69 27
      pkg/client/generated-client.go

+ 2 - 13
js/src/server/message.ts

@@ -38,7 +38,7 @@ const ToolResult = z
   });
 
 const ToolInvocation = z
-  .union([ToolCall, ToolPartialCall, ToolResult])
+  .discriminatedUnion("state", [ToolCall, ToolPartialCall, ToolResult])
   .openapi({
     ref: "Session.Message.ToolInvocation",
   });
@@ -103,25 +103,14 @@ const StepStartPart = z
     ref: "Session.Message.Part.StepStart",
   });
 
-const DataPart = z
-  .object({
-    type: z.custom<`data-${string}`>(),
-    id: z.string().optional(),
-    data: z.unknown(),
-  })
-  .openapi({
-    ref: "Session.Message.Part.Data",
-  });
-
 const Part = z
-  .union([
+  .discriminatedUnion("type", [
     TextPart,
     ReasoningPart,
     ToolInvocationPart,
     SourceUrlPart,
     FilePart,
     StepStartPart,
-    DataPart,
   ])
   .openapi({
     ref: "Session.Message.Part",

+ 23 - 17
pkg/client/gen/openapi.json

@@ -321,7 +321,7 @@
         ]
       },
       "Session.Message.Part": {
-        "anyOf": [
+        "oneOf": [
           {
             "$ref": "#/components/schemas/Session.Message.Part.Text"
           },
@@ -339,11 +339,19 @@
           },
           {
             "$ref": "#/components/schemas/Session.Message.Part.StepStart"
-          },
-          {
-            "$ref": "#/components/schemas/Session.Message.Part.Data"
           }
-        ]
+        ],
+        "discriminator": {
+          "propertyName": "type",
+          "mapping": {
+            "text": "#/components/schemas/Session.Message.Part.Text",
+            "reasoning": "#/components/schemas/Session.Message.Part.Reasoning",
+            "tool-invocation": "#/components/schemas/Session.Message.Part.ToolInvocation",
+            "source-url": "#/components/schemas/Session.Message.Part.SourceUrl",
+            "file": "#/components/schemas/Session.Message.Part.File",
+            "step-start": "#/components/schemas/Session.Message.Part.StepStart"
+          }
+        }
       },
       "Session.Message.Part.Text": {
         "type": "object",
@@ -398,7 +406,7 @@
         ]
       },
       "Session.Message.ToolInvocation": {
-        "anyOf": [
+        "oneOf": [
           {
             "$ref": "#/components/schemas/Session.Message.ToolInvocation.ToolCall"
           },
@@ -408,7 +416,15 @@
           {
             "$ref": "#/components/schemas/Session.Message.ToolInvocation.ToolResult"
           }
-        ]
+        ],
+        "discriminator": {
+          "propertyName": "state",
+          "mapping": {
+            "call": "#/components/schemas/Session.Message.ToolInvocation.ToolCall",
+            "partial-call": "#/components/schemas/Session.Message.ToolInvocation.ToolPartialCall",
+            "result": "#/components/schemas/Session.Message.ToolInvocation.ToolResult"
+          }
+        }
       },
       "Session.Message.ToolInvocation.ToolCall": {
         "type": "object",
@@ -560,16 +576,6 @@
           "type"
         ]
       },
-      "Session.Message.Part.Data": {
-        "type": "object",
-        "properties": {
-          "type": {},
-          "id": {
-            "type": "string"
-          },
-          "data": {}
-        }
-      },
       "Provider.Info": {
         "type": "object",
         "properties": {

+ 69 - 27
pkg/client/generated-client.go

@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -75,13 +76,6 @@ type SessionMessagePart struct {
 	union json.RawMessage
 }
 
-// SessionMessagePartData defines model for Session.Message.Part.Data.
-type SessionMessagePartData struct {
-	Data *interface{} `json:"data,omitempty"`
-	Id   *string      `json:"id,omitempty"`
-	Type *interface{} `json:"type,omitempty"`
-}
-
 // SessionMessagePartFile defines model for Session.Message.Part.File.
 type SessionMessagePartFile struct {
 	Filename  *string `json:"filename,omitempty"`
@@ -192,6 +186,7 @@ func (t SessionMessagePart) AsSessionMessagePartText() (SessionMessagePartText,
 
 // FromSessionMessagePartText overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartText
 func (t *SessionMessagePart) FromSessionMessagePartText(v SessionMessagePartText) error {
+	v.Type = "text"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -199,6 +194,7 @@ func (t *SessionMessagePart) FromSessionMessagePartText(v SessionMessagePartText
 
 // MergeSessionMessagePartText performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartText
 func (t *SessionMessagePart) MergeSessionMessagePartText(v SessionMessagePartText) error {
+	v.Type = "text"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -218,6 +214,7 @@ func (t SessionMessagePart) AsSessionMessagePartReasoning() (SessionMessagePartR
 
 // FromSessionMessagePartReasoning overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartReasoning
 func (t *SessionMessagePart) FromSessionMessagePartReasoning(v SessionMessagePartReasoning) error {
+	v.Type = "reasoning"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -225,6 +222,7 @@ func (t *SessionMessagePart) FromSessionMessagePartReasoning(v SessionMessagePar
 
 // MergeSessionMessagePartReasoning performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartReasoning
 func (t *SessionMessagePart) MergeSessionMessagePartReasoning(v SessionMessagePartReasoning) error {
+	v.Type = "reasoning"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -244,6 +242,7 @@ func (t SessionMessagePart) AsSessionMessagePartToolInvocation() (SessionMessage
 
 // FromSessionMessagePartToolInvocation overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartToolInvocation
 func (t *SessionMessagePart) FromSessionMessagePartToolInvocation(v SessionMessagePartToolInvocation) error {
+	v.Type = "tool-invocation"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -251,6 +250,7 @@ func (t *SessionMessagePart) FromSessionMessagePartToolInvocation(v SessionMessa
 
 // MergeSessionMessagePartToolInvocation performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartToolInvocation
 func (t *SessionMessagePart) MergeSessionMessagePartToolInvocation(v SessionMessagePartToolInvocation) error {
+	v.Type = "tool-invocation"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -270,6 +270,7 @@ func (t SessionMessagePart) AsSessionMessagePartSourceUrl() (SessionMessagePartS
 
 // FromSessionMessagePartSourceUrl overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartSourceUrl
 func (t *SessionMessagePart) FromSessionMessagePartSourceUrl(v SessionMessagePartSourceUrl) error {
+	v.Type = "source-url"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -277,6 +278,7 @@ func (t *SessionMessagePart) FromSessionMessagePartSourceUrl(v SessionMessagePar
 
 // MergeSessionMessagePartSourceUrl performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartSourceUrl
 func (t *SessionMessagePart) MergeSessionMessagePartSourceUrl(v SessionMessagePartSourceUrl) error {
+	v.Type = "source-url"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -296,6 +298,7 @@ func (t SessionMessagePart) AsSessionMessagePartFile() (SessionMessagePartFile,
 
 // FromSessionMessagePartFile overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartFile
 func (t *SessionMessagePart) FromSessionMessagePartFile(v SessionMessagePartFile) error {
+	v.Type = "file"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -303,6 +306,7 @@ func (t *SessionMessagePart) FromSessionMessagePartFile(v SessionMessagePartFile
 
 // MergeSessionMessagePartFile performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartFile
 func (t *SessionMessagePart) MergeSessionMessagePartFile(v SessionMessagePartFile) error {
+	v.Type = "file"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -322,6 +326,7 @@ func (t SessionMessagePart) AsSessionMessagePartStepStart() (SessionMessagePartS
 
 // FromSessionMessagePartStepStart overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartStepStart
 func (t *SessionMessagePart) FromSessionMessagePartStepStart(v SessionMessagePartStepStart) error {
+	v.Type = "step-start"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -329,6 +334,7 @@ func (t *SessionMessagePart) FromSessionMessagePartStepStart(v SessionMessagePar
 
 // MergeSessionMessagePartStepStart performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartStepStart
 func (t *SessionMessagePart) MergeSessionMessagePartStepStart(v SessionMessagePartStepStart) error {
+	v.Type = "step-start"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -339,30 +345,35 @@ func (t *SessionMessagePart) MergeSessionMessagePartStepStart(v SessionMessagePa
 	return err
 }
 
-// AsSessionMessagePartData returns the union data inside the SessionMessagePart as a SessionMessagePartData
-func (t SessionMessagePart) AsSessionMessagePartData() (SessionMessagePartData, error) {
-	var body SessionMessagePartData
-	err := json.Unmarshal(t.union, &body)
-	return body, err
-}
-
-// FromSessionMessagePartData overwrites any union data inside the SessionMessagePart as the provided SessionMessagePartData
-func (t *SessionMessagePart) FromSessionMessagePartData(v SessionMessagePartData) error {
-	b, err := json.Marshal(v)
-	t.union = b
-	return err
+func (t SessionMessagePart) Discriminator() (string, error) {
+	var discriminator struct {
+		Discriminator string `json:"type"`
+	}
+	err := json.Unmarshal(t.union, &discriminator)
+	return discriminator.Discriminator, err
 }
 
-// MergeSessionMessagePartData performs a merge with any union data inside the SessionMessagePart, using the provided SessionMessagePartData
-func (t *SessionMessagePart) MergeSessionMessagePartData(v SessionMessagePartData) error {
-	b, err := json.Marshal(v)
+func (t SessionMessagePart) ValueByDiscriminator() (interface{}, error) {
+	discriminator, err := t.Discriminator()
 	if err != nil {
-		return err
+		return nil, err
+	}
+	switch discriminator {
+	case "file":
+		return t.AsSessionMessagePartFile()
+	case "reasoning":
+		return t.AsSessionMessagePartReasoning()
+	case "source-url":
+		return t.AsSessionMessagePartSourceUrl()
+	case "step-start":
+		return t.AsSessionMessagePartStepStart()
+	case "text":
+		return t.AsSessionMessagePartText()
+	case "tool-invocation":
+		return t.AsSessionMessagePartToolInvocation()
+	default:
+		return nil, errors.New("unknown discriminator value: " + discriminator)
 	}
-
-	merged, err := runtime.JSONMerge(t.union, b)
-	t.union = merged
-	return err
 }
 
 func (t SessionMessagePart) MarshalJSON() ([]byte, error) {
@@ -384,6 +395,7 @@ func (t SessionMessageToolInvocation) AsSessionMessageToolInvocationToolCall() (
 
 // FromSessionMessageToolInvocationToolCall overwrites any union data inside the SessionMessageToolInvocation as the provided SessionMessageToolInvocationToolCall
 func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolCall(v SessionMessageToolInvocationToolCall) error {
+	v.State = "call"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -391,6 +403,7 @@ func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolCall(
 
 // MergeSessionMessageToolInvocationToolCall performs a merge with any union data inside the SessionMessageToolInvocation, using the provided SessionMessageToolInvocationToolCall
 func (t *SessionMessageToolInvocation) MergeSessionMessageToolInvocationToolCall(v SessionMessageToolInvocationToolCall) error {
+	v.State = "call"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -410,6 +423,7 @@ func (t SessionMessageToolInvocation) AsSessionMessageToolInvocationToolPartialC
 
 // FromSessionMessageToolInvocationToolPartialCall overwrites any union data inside the SessionMessageToolInvocation as the provided SessionMessageToolInvocationToolPartialCall
 func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolPartialCall(v SessionMessageToolInvocationToolPartialCall) error {
+	v.State = "partial-call"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -417,6 +431,7 @@ func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolParti
 
 // MergeSessionMessageToolInvocationToolPartialCall performs a merge with any union data inside the SessionMessageToolInvocation, using the provided SessionMessageToolInvocationToolPartialCall
 func (t *SessionMessageToolInvocation) MergeSessionMessageToolInvocationToolPartialCall(v SessionMessageToolInvocationToolPartialCall) error {
+	v.State = "partial-call"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -436,6 +451,7 @@ func (t SessionMessageToolInvocation) AsSessionMessageToolInvocationToolResult()
 
 // FromSessionMessageToolInvocationToolResult overwrites any union data inside the SessionMessageToolInvocation as the provided SessionMessageToolInvocationToolResult
 func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolResult(v SessionMessageToolInvocationToolResult) error {
+	v.State = "result"
 	b, err := json.Marshal(v)
 	t.union = b
 	return err
@@ -443,6 +459,7 @@ func (t *SessionMessageToolInvocation) FromSessionMessageToolInvocationToolResul
 
 // MergeSessionMessageToolInvocationToolResult performs a merge with any union data inside the SessionMessageToolInvocation, using the provided SessionMessageToolInvocationToolResult
 func (t *SessionMessageToolInvocation) MergeSessionMessageToolInvocationToolResult(v SessionMessageToolInvocationToolResult) error {
+	v.State = "result"
 	b, err := json.Marshal(v)
 	if err != nil {
 		return err
@@ -453,6 +470,31 @@ func (t *SessionMessageToolInvocation) MergeSessionMessageToolInvocationToolResu
 	return err
 }
 
+func (t SessionMessageToolInvocation) Discriminator() (string, error) {
+	var discriminator struct {
+		Discriminator string `json:"state"`
+	}
+	err := json.Unmarshal(t.union, &discriminator)
+	return discriminator.Discriminator, err
+}
+
+func (t SessionMessageToolInvocation) ValueByDiscriminator() (interface{}, error) {
+	discriminator, err := t.Discriminator()
+	if err != nil {
+		return nil, err
+	}
+	switch discriminator {
+	case "call":
+		return t.AsSessionMessageToolInvocationToolCall()
+	case "partial-call":
+		return t.AsSessionMessageToolInvocationToolPartialCall()
+	case "result":
+		return t.AsSessionMessageToolInvocationToolResult()
+	default:
+		return nil, errors.New("unknown discriminator value: " + discriminator)
+	}
+}
+
 func (t SessionMessageToolInvocation) MarshalJSON() ([]byte, error) {
 	b, err := t.union.MarshalJSON()
 	return b, err