Browse Source

wip: refactoring tui

adamdottv 9 months ago
parent
commit
ce5b3126d3

+ 0 - 2
internal/tui/app/app.go

@@ -23,7 +23,6 @@ import (
 type App struct {
 	Client   *client.ClientWithResponses
 	Events   *client.Client
-	State    map[string]any
 	Session  *client.SessionInfo
 	Messages []client.MessageInfo
 
@@ -76,7 +75,6 @@ func New(ctx context.Context) (*App, error) {
 	agentBridge := NewAgentServiceBridge(httpClient)
 
 	app := &App{
-		State:             make(map[string]any),
 		Client:            httpClient,
 		Events:            eventClient,
 		Session:           &client.SessionInfo{},

+ 29 - 82
internal/tui/components/chat/messages.go

@@ -1,6 +1,7 @@
 package chat
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/charmbracelet/bubbles/key"
@@ -93,63 +94,9 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case renderFinishedMsg:
 		m.rendering = false
 		m.viewport.GotoBottom()
-
 	case state.StateUpdatedMsg:
 		m.renderView()
 		m.viewport.GotoBottom()
-
-		// case pubsub.Event[message.Message]:
-		// 	needsRerender := false
-		// 	if msg.Type == message.EventMessageCreated {
-		// 		if msg.Payload.SessionID == m.app.CurrentSessionOLD.ID {
-		// 			messageExists := false
-		// 			for _, v := range m.messages {
-		// 				if v.ID == msg.Payload.ID {
-		// 					messageExists = true
-		// 					break
-		// 				}
-		// 			}
-		//
-		// 			if !messageExists {
-		// 				if len(m.messages) > 0 {
-		// 					lastMsgID := m.messages[len(m.messages)-1].ID
-		// 					delete(m.cachedContent, lastMsgID)
-		// 				}
-		//
-		// 				m.messages = append(m.messages, msg.Payload)
-		// 				delete(m.cachedContent, m.currentMsgID)
-		// 				m.currentMsgID = msg.Payload.ID
-		// 				needsRerender = true
-		// 			}
-		// 		}
-		// 		// There are tool calls from the child task
-		// 		for _, v := range m.messages {
-		// 			for _, c := range v.ToolCalls() {
-		// 				if c.ID == msg.Payload.SessionID {
-		// 					delete(m.cachedContent, v.ID)
-		// 					needsRerender = true
-		// 				}
-		// 			}
-		// 		}
-		// 	} else if msg.Type == message.EventMessageUpdated && msg.Payload.SessionID == m.app.CurrentSessionOLD.ID {
-		// 		for i, v := range m.messages {
-		// 			if v.ID == msg.Payload.ID {
-		// 				m.messages[i] = msg.Payload
-		// 				delete(m.cachedContent, msg.Payload.ID)
-		// 				needsRerender = true
-		// 				break
-		// 			}
-		// 		}
-		// 	}
-		// 	if needsRerender {
-		// 		m.renderView()
-		// 		if len(m.messages) > 0 {
-		// 			if (msg.Type == message.EventMessageCreated) ||
-		// 				(msg.Type == message.EventMessageUpdated && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
-		// 				m.viewport.GotoBottom()
-		// 			}
-		// 		}
-		// 	}
 	}
 
 	spinner, cmd := m.spinner.Update(msg)
@@ -190,7 +137,6 @@ func (m *messagesCmp) renderView() {
 	m.viewport.SetContent(
 		styles.ForceReplaceBackgroundWithLipgloss(
 			styles.BaseStyle().
-				Width(m.width).
 				Render(
 					lipgloss.JoinVertical(
 						lipgloss.Top,
@@ -212,11 +158,12 @@ func (m *messagesCmp) View() string {
 				lipgloss.JoinVertical(
 					lipgloss.Top,
 					"Loading...",
-					// m.working(),
+					m.working(),
 					m.help(),
 				),
 			)
 	}
+
 	if len(m.app.Messages) == 0 {
 		content := baseStyle.
 			Width(m.width).
@@ -243,7 +190,7 @@ func (m *messagesCmp) View() string {
 			lipgloss.JoinVertical(
 				lipgloss.Top,
 				m.viewport.View(),
-				// m.working(),
+				m.working(),
 				m.help(),
 			),
 		)
@@ -285,31 +232,31 @@ func hasUnfinishedToolCalls(messages []message.Message) bool {
 	return false
 }
 
-// func (m *messagesCmp) working() string {
-// 	text := ""
-// 	if m.IsAgentWorking() && len(m.app.Messages) > 0 {
-// 		t := theme.CurrentTheme()
-// 		baseStyle := styles.BaseStyle()
-//
-// 		task := "Thinking..."
-// 		lastMessage := m.app.Messages[len(m.app.Messages)-1]
-// 		if hasToolsWithoutResponse(m.app.Messages) {
-// 			task = "Waiting for tool response..."
-// 		} else if hasUnfinishedToolCalls(m.app.Messages) {
-// 			task = "Building tool call..."
-// 		} else if !lastMessage.IsFinished() {
-// 			task = "Generating..."
-// 		}
-// 		if task != "" {
-// 			text += baseStyle.
-// 				Width(m.width).
-// 				Foreground(t.Primary()).
-// 				Bold(true).
-// 				Render(fmt.Sprintf("%s %s ", m.spinner.View(), task))
-// 		}
-// 	}
-// 	return text
-// }
+func (m *messagesCmp) working() string {
+	text := ""
+	if len(m.app.Messages) > 0 {
+		t := theme.CurrentTheme()
+		baseStyle := styles.BaseStyle()
+
+		task := "Working..."
+		// lastMessage := m.app.Messages[len(m.app.Messages)-1]
+		// if hasToolsWithoutResponse(m.app.Messages) {
+		// 	task = "Waiting for tool response..."
+		// } else if hasUnfinishedToolCalls(m.app.Messages) {
+		// 	task = "Building tool call..."
+		// } else if !lastMessage.IsFinished() {
+		// 	task = "Generating..."
+		// }
+		if task != "" {
+			text += baseStyle.
+				Width(m.width).
+				Foreground(t.Primary()).
+				Bold(true).
+				Render(fmt.Sprintf("%s %s ", m.spinner.View(), task))
+		}
+	}
+	return text
+}
 
 func (m *messagesCmp) help() string {
 	t := theme.CurrentTheme()

+ 1 - 1
internal/tui/components/chat/sidebar.go

@@ -86,7 +86,7 @@ func (m *sidebarCmp) sessionSection() string {
 
 	sessionValue := baseStyle.
 		Foreground(t.Text()).
-		Render(fmt.Sprintf(": %s", m.app.CurrentSessionOLD.Title))
+		Render(fmt.Sprintf(": %s", m.app.Session.Title))
 
 	return sessionKey + sessionValue
 }

+ 14 - 73
internal/tui/tui.go

@@ -2,7 +2,6 @@ package tui
 
 import (
 	"context"
-	"encoding/json"
 	"log/slog"
 	"strings"
 
@@ -267,85 +266,27 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			}
 		}
 
-	case client.EventStorageWrite:
-		parts := strings.Split(msg.Key, "/")
-		if len(parts) < 3 {
-			return a, nil
-		}
-
-		if parts[0] == "session" && parts[1] == "info" {
-			sessionId := parts[2]
-			if sessionId == a.app.Session.Id {
-				var sessionInfo client.SessionInfo
-				bytes, _ := json.Marshal(msg.Content)
-				if err := json.Unmarshal(bytes, &sessionInfo); err != nil {
-					status.Error(err.Error())
-					return a, nil
+	case client.EventMessageUpdated:
+		if msg.Properties.Info.Metadata.SessionID == a.app.Session.Id {
+			for i, m := range a.app.Messages {
+				if m.Id == msg.Properties.Info.Id {
+					a.app.Messages[i] = msg.Properties.Info
+					slog.Debug("Updated message", "message", msg.Properties.Info)
+					return a.updateAllPages(state.StateUpdatedMsg{State: nil})
 				}
-
-				a.app.Session = &sessionInfo
 			}
 
-			return a.updateAllPages(state.StateUpdatedMsg{State: a.app.State})
+			a.app.Messages = append(a.app.Messages, msg.Properties.Info)
+			slog.Debug("Appended message", "message", msg.Properties.Info)
+			return a.updateAllPages(state.StateUpdatedMsg{State: nil})
 		}
 
-		if parts[0] == "session" && parts[1] == "message" {
-			sessionId := parts[2]
-			if sessionId == a.app.Session.Id {
-				messageId := parts[3]
-				var message client.MessageInfo
-				bytes, _ := json.Marshal(msg.Content)
-				if err := json.Unmarshal(bytes, &message); err != nil {
-					status.Error(err.Error())
-					return a, nil
-				}
-
-				for i, m := range a.app.Messages {
-					if m.Id == messageId {
-						a.app.Messages[i] = message
-						slog.Debug("Updated message", "message", message)
-						return a.updateAllPages(state.StateUpdatedMsg{State: a.app.State})
-					}
-				}
-
-				a.app.Messages = append(a.app.Messages, message)
-				slog.Debug("Appended message", "message", message)
-
-				// a.app.CurrentSession.MessageCount++
-				// a.app.CurrentSession.PromptTokens += message.PromptTokens
-				// a.app.CurrentSession.CompletionTokens += message.CompletionTokens
-				// a.app.CurrentSession.Cost += message.Cost
-				// a.app.CurrentSession.UpdatedAt = message.CreatedAt
-			}
-
-			return a.updateAllPages(state.StateUpdatedMsg{State: a.app.State})
-		}
-
-		// log key and content
-		slog.Debug("Received SSE event", "key", msg.Key, "content", msg.Content)
-
-		current := a.app.State
-
-		for i, part := range parts {
-			if i == len(parts)-1 {
-				current[part] = msg.Content
-			} else {
-				if _, exists := current[part]; !exists {
-					current[part] = make(map[string]any)
-				}
-
-				nextLevel, ok := current[part].(map[string]any)
-				if !ok {
-					current[part] = make(map[string]any)
-					nextLevel = current[part].(map[string]any)
-				}
-				current = nextLevel
-			}
+	case client.EventSessionUpdated:
+		if msg.Properties.Info.Id == a.app.Session.Id {
+			a.app.Session = &msg.Properties.Info
+			return a.updateAllPages(state.StateUpdatedMsg{State: nil})
 		}
 
-		// Trigger UI update by updating all pages with the new state
-		return a.updateAllPages(state.StateUpdatedMsg{State: a.app.State})
-
 	case dialog.CloseQuitMsg:
 		a.showQuit = false
 		return a, nil

+ 47 - 0
pkg/client/gen/openapi.json

@@ -177,6 +177,50 @@
         }
       }
     },
+    "/session_summarize": {
+      "post": {
+        "responses": {
+          "200": {
+            "description": "Summarize the session",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "boolean"
+                }
+              }
+            }
+          }
+        },
+        "operationId": "postSession_summarize",
+        "parameters": [],
+        "description": "Summarize the session",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "type": "object",
+                "properties": {
+                  "sessionID": {
+                    "type": "string"
+                  },
+                  "providerID": {
+                    "type": "string"
+                  },
+                  "modelID": {
+                    "type": "string"
+                  }
+                },
+                "required": [
+                  "sessionID",
+                  "providerID",
+                  "modelID"
+                ]
+              }
+            }
+          }
+        }
+      }
+    },
     "/session_chat": {
       "post": {
         "responses": {
@@ -411,6 +455,9 @@
                   "cost": {
                     "type": "number"
                   },
+                  "summary": {
+                    "type": "boolean"
+                  },
                   "tokens": {
                     "type": "object",
                     "properties": {

+ 150 - 0
pkg/client/generated-client.go

@@ -71,6 +71,7 @@ type MessageInfo struct {
 			Cost       float32 `json:"cost"`
 			ModelID    string  `json:"modelID"`
 			ProviderID string  `json:"providerID"`
+			Summary    *bool   `json:"summary,omitempty"`
 			Tokens     struct {
 				Input     float32 `json:"input"`
 				Output    float32 `json:"output"`
@@ -221,6 +222,13 @@ type PostSessionShareJSONBody struct {
 	SessionID string `json:"sessionID"`
 }
 
+// PostSessionSummarizeJSONBody defines parameters for PostSessionSummarize.
+type PostSessionSummarizeJSONBody struct {
+	ModelID    string `json:"modelID"`
+	ProviderID string `json:"providerID"`
+	SessionID  string `json:"sessionID"`
+}
+
 // PostSessionAbortJSONRequestBody defines body for PostSessionAbort for application/json ContentType.
 type PostSessionAbortJSONRequestBody PostSessionAbortJSONBody
 
@@ -233,6 +241,9 @@ type PostSessionMessagesJSONRequestBody PostSessionMessagesJSONBody
 // PostSessionShareJSONRequestBody defines body for PostSessionShare for application/json ContentType.
 type PostSessionShareJSONRequestBody PostSessionShareJSONBody
 
+// PostSessionSummarizeJSONRequestBody defines body for PostSessionSummarize for application/json ContentType.
+type PostSessionSummarizeJSONRequestBody PostSessionSummarizeJSONBody
+
 // AsEventStorageWrite returns the union data inside the Event as a EventStorageWrite
 func (t Event) AsEventStorageWrite() (EventStorageWrite, error) {
 	var body EventStorageWrite
@@ -814,6 +825,11 @@ type ClientInterface interface {
 	PostSessionShareWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
 
 	PostSessionShare(ctx context.Context, body PostSessionShareJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error)
+
+	// PostSessionSummarizeWithBody request with any body
+	PostSessionSummarizeWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
+
+	PostSessionSummarize(ctx context.Context, body PostSessionSummarizeJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error)
 }
 
 func (c *Client) GetEvent(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) {
@@ -960,6 +976,30 @@ func (c *Client) PostSessionShare(ctx context.Context, body PostSessionShareJSON
 	return c.Client.Do(req)
 }
 
+func (c *Client) PostSessionSummarizeWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
+	req, err := NewPostSessionSummarizeRequestWithBody(c.Server, contentType, body)
+	if err != nil {
+		return nil, err
+	}
+	req = req.WithContext(ctx)
+	if err := c.applyEditors(ctx, req, reqEditors); err != nil {
+		return nil, err
+	}
+	return c.Client.Do(req)
+}
+
+func (c *Client) PostSessionSummarize(ctx context.Context, body PostSessionSummarizeJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) {
+	req, err := NewPostSessionSummarizeRequest(c.Server, body)
+	if err != nil {
+		return nil, err
+	}
+	req = req.WithContext(ctx)
+	if err := c.applyEditors(ctx, req, reqEditors); err != nil {
+		return nil, err
+	}
+	return c.Client.Do(req)
+}
+
 // NewGetEventRequest generates requests for GetEvent
 func NewGetEventRequest(server string) (*http.Request, error) {
 	var err error
@@ -1228,6 +1268,46 @@ func NewPostSessionShareRequestWithBody(server string, contentType string, body
 	return req, nil
 }
 
+// NewPostSessionSummarizeRequest calls the generic PostSessionSummarize builder with application/json body
+func NewPostSessionSummarizeRequest(server string, body PostSessionSummarizeJSONRequestBody) (*http.Request, error) {
+	var bodyReader io.Reader
+	buf, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	bodyReader = bytes.NewReader(buf)
+	return NewPostSessionSummarizeRequestWithBody(server, "application/json", bodyReader)
+}
+
+// NewPostSessionSummarizeRequestWithBody generates requests for PostSessionSummarize with any type of body
+func NewPostSessionSummarizeRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) {
+	var err error
+
+	serverURL, err := url.Parse(server)
+	if err != nil {
+		return nil, err
+	}
+
+	operationPath := fmt.Sprintf("/session_summarize")
+	if operationPath[0] == '/' {
+		operationPath = "." + operationPath
+	}
+
+	queryURL, err := serverURL.Parse(operationPath)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", queryURL.String(), body)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Add("Content-Type", contentType)
+
+	return req, nil
+}
+
 func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error {
 	for _, r := range c.RequestEditors {
 		if err := r(ctx, req); err != nil {
@@ -1302,6 +1382,11 @@ type ClientWithResponsesInterface interface {
 	PostSessionShareWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionShareResponse, error)
 
 	PostSessionShareWithResponse(ctx context.Context, body PostSessionShareJSONRequestBody, reqEditors ...RequestEditorFn) (*PostSessionShareResponse, error)
+
+	// PostSessionSummarizeWithBodyWithResponse request with any body
+	PostSessionSummarizeWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionSummarizeResponse, error)
+
+	PostSessionSummarizeWithResponse(ctx context.Context, body PostSessionSummarizeJSONRequestBody, reqEditors ...RequestEditorFn) (*PostSessionSummarizeResponse, error)
 }
 
 type GetEventResponse struct {
@@ -1480,6 +1565,28 @@ func (r PostSessionShareResponse) StatusCode() int {
 	return 0
 }
 
+type PostSessionSummarizeResponse struct {
+	Body         []byte
+	HTTPResponse *http.Response
+	JSON200      *bool
+}
+
+// Status returns HTTPResponse.Status
+func (r PostSessionSummarizeResponse) Status() string {
+	if r.HTTPResponse != nil {
+		return r.HTTPResponse.Status
+	}
+	return http.StatusText(0)
+}
+
+// StatusCode returns HTTPResponse.StatusCode
+func (r PostSessionSummarizeResponse) StatusCode() int {
+	if r.HTTPResponse != nil {
+		return r.HTTPResponse.StatusCode
+	}
+	return 0
+}
+
 // GetEventWithResponse request returning *GetEventResponse
 func (c *ClientWithResponses) GetEventWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetEventResponse, error) {
 	rsp, err := c.GetEvent(ctx, reqEditors...)
@@ -1584,6 +1691,23 @@ func (c *ClientWithResponses) PostSessionShareWithResponse(ctx context.Context,
 	return ParsePostSessionShareResponse(rsp)
 }
 
+// PostSessionSummarizeWithBodyWithResponse request with arbitrary body returning *PostSessionSummarizeResponse
+func (c *ClientWithResponses) PostSessionSummarizeWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostSessionSummarizeResponse, error) {
+	rsp, err := c.PostSessionSummarizeWithBody(ctx, contentType, body, reqEditors...)
+	if err != nil {
+		return nil, err
+	}
+	return ParsePostSessionSummarizeResponse(rsp)
+}
+
+func (c *ClientWithResponses) PostSessionSummarizeWithResponse(ctx context.Context, body PostSessionSummarizeJSONRequestBody, reqEditors ...RequestEditorFn) (*PostSessionSummarizeResponse, error) {
+	rsp, err := c.PostSessionSummarize(ctx, body, reqEditors...)
+	if err != nil {
+		return nil, err
+	}
+	return ParsePostSessionSummarizeResponse(rsp)
+}
+
 // ParseGetEventResponse parses an HTTP response from a GetEventWithResponse call
 func ParseGetEventResponse(rsp *http.Response) (*GetEventResponse, error) {
 	bodyBytes, err := io.ReadAll(rsp.Body)
@@ -1791,3 +1915,29 @@ func ParsePostSessionShareResponse(rsp *http.Response) (*PostSessionShareRespons
 
 	return response, nil
 }
+
+// ParsePostSessionSummarizeResponse parses an HTTP response from a PostSessionSummarizeWithResponse call
+func ParsePostSessionSummarizeResponse(rsp *http.Response) (*PostSessionSummarizeResponse, error) {
+	bodyBytes, err := io.ReadAll(rsp.Body)
+	defer func() { _ = rsp.Body.Close() }()
+	if err != nil {
+		return nil, err
+	}
+
+	response := &PostSessionSummarizeResponse{
+		Body:         bodyBytes,
+		HTTPResponse: rsp,
+	}
+
+	switch {
+	case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
+		var dest bool
+		if err := json.Unmarshal(bodyBytes, &dest); err != nil {
+			return nil, err
+		}
+		response.JSON200 = &dest
+
+	}
+
+	return response, nil
+}