agent.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "github.com/kujtimiihoxha/termai/internal/config"
  9. "github.com/kujtimiihoxha/termai/internal/llm/models"
  10. "github.com/kujtimiihoxha/termai/internal/llm/prompt"
  11. "github.com/kujtimiihoxha/termai/internal/llm/provider"
  12. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  13. "github.com/kujtimiihoxha/termai/internal/logging"
  14. "github.com/kujtimiihoxha/termai/internal/message"
  15. "github.com/kujtimiihoxha/termai/internal/session"
  16. )
  // Sentinel errors returned by this package; compare with errors.Is.

// ErrProviderNotEnabled is returned by getAgentProviders (and, wrapped, by
// NewAgent) when the model's provider has no configuration entry or the entry
// is disabled.
var ErrProviderNotEnabled = errors.New("provider is not enabled")

// ErrRequestCancelled is returned by the generation workflow when the
// request's context has been cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Generate when the session already has a
// request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
  // Service generates assistant responses for chat sessions.
type Service interface {
	// Generate starts producing a response to content in the session
	// identified by sessionID.
	Generate(ctx context.Context, sessionID string, content string) error

	// Cancel aborts the in-flight Generate request for sessionID, returning
	// an error if there is none to cancel.
	Cancel(sessionID string) error
}
  28. type agent struct {
  29. sessions session.Service
  30. messages message.Service
  31. model models.Model
  32. tools []tools.BaseTool
  33. agent provider.Provider
  34. titleGenerator provider.Provider
  35. activeRequests sync.Map // map[sessionID]context.CancelFunc
  36. }
  37. // NewAgent creates a new agent instance with the given model and tools
  38. func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
  39. agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
  40. if err != nil {
  41. return nil, fmt.Errorf("failed to initialize providers: %w", err)
  42. }
  43. return &agent{
  44. model: model,
  45. tools: tools,
  46. sessions: sessions,
  47. messages: messages,
  48. agent: agentProvider,
  49. titleGenerator: titleGenerator,
  50. activeRequests: sync.Map{},
  51. }, nil
  52. }
  53. // Cancel cancels an active request by session ID
  54. func (a *agent) Cancel(sessionID string) error {
  55. if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
  56. if cancel, ok := cancelFunc.(context.CancelFunc); ok {
  57. logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
  58. cancel()
  59. return nil
  60. }
  61. }
  62. return errors.New("no active request found for this session")
  63. }
  64. // Generate starts the generation process
  65. func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
  66. // Check if this session already has an active request
  67. if _, busy := a.activeRequests.Load(sessionID); busy {
  68. return ErrSessionBusy
  69. }
  70. // Create a cancellable context
  71. genCtx, cancel := context.WithCancel(ctx)
  72. // Store cancel function to allow user cancellation
  73. a.activeRequests.Store(sessionID, cancel)
  74. // Launch the generation in a goroutine
  75. go func() {
  76. defer func() {
  77. if r := recover(); r != nil {
  78. logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
  79. }
  80. }()
  81. defer a.activeRequests.Delete(sessionID)
  82. defer cancel()
  83. if err := a.generate(genCtx, sessionID, content); err != nil {
  84. if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
  85. // Log the error (avoid logging cancellations as they're expected)
  86. logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
  87. // You may want to create an error message in the chat
  88. bgCtx := context.Background()
  89. errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
  90. _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
  91. Role: message.System,
  92. Parts: []message.ContentPart{
  93. message.TextContent{
  94. Text: errorMsg,
  95. },
  96. },
  97. })
  98. if createErr != nil {
  99. logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
  100. }
  101. }
  102. }
  103. }()
  104. return nil
  105. }
  106. // IsSessionBusy checks if a session currently has an active request
  107. func (a *agent) IsSessionBusy(sessionID string) bool {
  108. _, busy := a.activeRequests.Load(sessionID)
  109. return busy
  110. } // handleTitleGeneration asynchronously generates a title for new sessions
  111. func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
  112. response, err := a.titleGenerator.SendMessages(
  113. ctx,
  114. []message.Message{
  115. {
  116. Role: message.User,
  117. Parts: []message.ContentPart{
  118. message.TextContent{
  119. Text: content,
  120. },
  121. },
  122. },
  123. },
  124. nil,
  125. )
  126. if err != nil {
  127. logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
  128. return
  129. }
  130. session, err := a.sessions.Get(ctx, sessionID)
  131. if err != nil {
  132. logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
  133. return
  134. }
  135. if response.Content != "" {
  136. session.Title = strings.TrimSpace(response.Content)
  137. session.Title = strings.ReplaceAll(session.Title, "\n", " ")
  138. if _, err := a.sessions.Save(ctx, session); err != nil {
  139. logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
  140. }
  141. }
  142. }
  143. // TrackUsage updates token usage statistics for the session
  144. func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
  145. session, err := a.sessions.Get(ctx, sessionID)
  146. if err != nil {
  147. return fmt.Errorf("failed to get session: %w", err)
  148. }
  149. cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
  150. model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
  151. model.CostPer1MIn/1e6*float64(usage.InputTokens) +
  152. model.CostPer1MOut/1e6*float64(usage.OutputTokens)
  153. session.Cost += cost
  154. session.CompletionTokens += usage.OutputTokens
  155. session.PromptTokens += usage.InputTokens
  156. _, err = a.sessions.Save(ctx, session)
  157. if err != nil {
  158. return fmt.Errorf("failed to save session: %w", err)
  159. }
  160. return nil
  161. }
  162. // processEvent handles different types of events during generation
  163. func (a *agent) processEvent(
  164. ctx context.Context,
  165. sessionID string,
  166. assistantMsg *message.Message,
  167. event provider.ProviderEvent,
  168. ) error {
  169. select {
  170. case <-ctx.Done():
  171. return ctx.Err()
  172. default:
  173. // Continue processing
  174. }
  175. switch event.Type {
  176. case provider.EventThinkingDelta:
  177. assistantMsg.AppendReasoningContent(event.Content)
  178. return a.messages.Update(ctx, *assistantMsg)
  179. case provider.EventContentDelta:
  180. assistantMsg.AppendContent(event.Content)
  181. return a.messages.Update(ctx, *assistantMsg)
  182. case provider.EventError:
  183. if errors.Is(event.Error, context.Canceled) {
  184. logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
  185. return context.Canceled
  186. }
  187. logging.ErrorPersist(event.Error.Error())
  188. return event.Error
  189. case provider.EventWarning:
  190. logging.WarnPersist(event.Info)
  191. case provider.EventInfo:
  192. logging.InfoPersist(event.Info)
  193. case provider.EventComplete:
  194. assistantMsg.SetToolCalls(event.Response.ToolCalls)
  195. assistantMsg.AddFinish(event.Response.FinishReason)
  196. if err := a.messages.Update(ctx, *assistantMsg); err != nil {
  197. return fmt.Errorf("failed to update message: %w", err)
  198. }
  199. return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
  200. }
  201. return nil
  202. }
  203. // ExecuteTools runs all tool calls sequentially and returns the results
  204. func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
  205. toolResults := make([]message.ToolResult, len(toolCalls))
  206. // Create a child context that can be canceled
  207. ctx, cancel := context.WithCancel(ctx)
  208. defer cancel()
  209. // Check if already canceled before starting any execution
  210. if ctx.Err() != nil {
  211. // Mark all tools as canceled
  212. for i, toolCall := range toolCalls {
  213. toolResults[i] = message.ToolResult{
  214. ToolCallID: toolCall.ID,
  215. Content: "Tool execution canceled by user",
  216. IsError: true,
  217. }
  218. }
  219. return toolResults, ctx.Err()
  220. }
  221. for i, toolCall := range toolCalls {
  222. // Check for cancellation before executing each tool
  223. select {
  224. case <-ctx.Done():
  225. // Mark this and all remaining tools as canceled
  226. for j := i; j < len(toolCalls); j++ {
  227. toolResults[j] = message.ToolResult{
  228. ToolCallID: toolCalls[j].ID,
  229. Content: "Tool execution canceled by user",
  230. IsError: true,
  231. }
  232. }
  233. return toolResults, ctx.Err()
  234. default:
  235. // Continue processing
  236. }
  237. response := ""
  238. isError := false
  239. found := false
  240. // Find and execute the appropriate tool
  241. for _, tool := range tls {
  242. if tool.Info().Name == toolCall.Name {
  243. found = true
  244. toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
  245. ID: toolCall.ID,
  246. Name: toolCall.Name,
  247. Input: toolCall.Input,
  248. })
  249. if toolErr != nil {
  250. if errors.Is(toolErr, context.Canceled) {
  251. response = "Tool execution canceled by user"
  252. } else {
  253. response = fmt.Sprintf("Error running tool: %s", toolErr)
  254. }
  255. isError = true
  256. } else {
  257. response = toolResult.Content
  258. isError = toolResult.IsError
  259. }
  260. break
  261. }
  262. }
  263. if !found {
  264. response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
  265. isError = true
  266. }
  267. toolResults[i] = message.ToolResult{
  268. ToolCallID: toolCall.ID,
  269. Content: response,
  270. IsError: isError,
  271. }
  272. }
  273. return toolResults, nil
  274. }
  275. // handleToolExecution processes tool calls and creates tool result messages
  276. func (a *agent) handleToolExecution(
  277. ctx context.Context,
  278. assistantMsg message.Message,
  279. ) (*message.Message, error) {
  280. select {
  281. case <-ctx.Done():
  282. // If cancelled, create tool results that indicate cancellation
  283. if len(assistantMsg.ToolCalls()) > 0 {
  284. toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
  285. for _, tc := range assistantMsg.ToolCalls() {
  286. toolResults = append(toolResults, message.ToolResult{
  287. ToolCallID: tc.ID,
  288. Content: "Tool execution canceled by user",
  289. IsError: true,
  290. })
  291. }
  292. // Use background context to ensure the message is created even if original context is cancelled
  293. bgCtx := context.Background()
  294. parts := make([]message.ContentPart, 0)
  295. for _, toolResult := range toolResults {
  296. parts = append(parts, toolResult)
  297. }
  298. msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
  299. Role: message.Tool,
  300. Parts: parts,
  301. })
  302. if err != nil {
  303. return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
  304. }
  305. return &msg, ctx.Err()
  306. }
  307. return nil, ctx.Err()
  308. default:
  309. // Continue processing
  310. }
  311. if len(assistantMsg.ToolCalls()) == 0 {
  312. return nil, nil
  313. }
  314. toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
  315. if err != nil {
  316. // If error is from cancellation, still return the partial results we have
  317. if errors.Is(err, context.Canceled) {
  318. // Use background context to ensure the message is created even if original context is cancelled
  319. bgCtx := context.Background()
  320. parts := make([]message.ContentPart, 0)
  321. for _, toolResult := range toolResults {
  322. parts = append(parts, toolResult)
  323. }
  324. msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
  325. Role: message.Tool,
  326. Parts: parts,
  327. })
  328. if createErr != nil {
  329. logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
  330. return nil, err
  331. }
  332. return &msg, err
  333. }
  334. return nil, err
  335. }
  336. parts := make([]message.ContentPart, 0, len(toolResults))
  337. for _, toolResult := range toolResults {
  338. parts = append(parts, toolResult)
  339. }
  340. msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
  341. Role: message.Tool,
  342. Parts: parts,
  343. })
  344. if err != nil {
  345. return nil, fmt.Errorf("failed to create tool message: %w", err)
  346. }
  347. return &msg, nil
  348. }
  349. // generate handles the main generation workflow
  350. func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
  351. ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
  352. // Handle context cancellation at any point
  353. if err := ctx.Err(); err != nil {
  354. return ErrRequestCancelled
  355. }
  356. messages, err := a.messages.List(ctx, sessionID)
  357. if err != nil {
  358. return fmt.Errorf("failed to list messages: %w", err)
  359. }
  360. if len(messages) == 0 {
  361. titleCtx := context.Background()
  362. go a.handleTitleGeneration(titleCtx, sessionID, content)
  363. }
  364. userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  365. Role: message.User,
  366. Parts: []message.ContentPart{
  367. message.TextContent{
  368. Text: content,
  369. },
  370. },
  371. })
  372. if err != nil {
  373. return fmt.Errorf("failed to create user message: %w", err)
  374. }
  375. messages = append(messages, userMsg)
  376. for {
  377. // Check for cancellation before each iteration
  378. select {
  379. case <-ctx.Done():
  380. return ErrRequestCancelled
  381. default:
  382. // Continue processing
  383. }
  384. eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
  385. if err != nil {
  386. if errors.Is(err, context.Canceled) {
  387. return ErrRequestCancelled
  388. }
  389. return fmt.Errorf("failed to stream response: %w", err)
  390. }
  391. assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
  392. Role: message.Assistant,
  393. Parts: []message.ContentPart{},
  394. Model: a.model.ID,
  395. })
  396. if err != nil {
  397. return fmt.Errorf("failed to create assistant message: %w", err)
  398. }
  399. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
  400. // Process events from the LLM provider
  401. for event := range eventChan {
  402. if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
  403. if errors.Is(err, context.Canceled) {
  404. // Mark as canceled but don't create separate message
  405. assistantMsg.AddFinish("canceled")
  406. _ = a.messages.Update(context.Background(), assistantMsg)
  407. return ErrRequestCancelled
  408. }
  409. assistantMsg.AddFinish("error:" + err.Error())
  410. _ = a.messages.Update(ctx, assistantMsg)
  411. return fmt.Errorf("event processing error: %w", err)
  412. }
  413. // Check for cancellation during event processing
  414. select {
  415. case <-ctx.Done():
  416. // Mark as canceled
  417. assistantMsg.AddFinish("canceled")
  418. _ = a.messages.Update(context.Background(), assistantMsg)
  419. return ErrRequestCancelled
  420. default:
  421. }
  422. }
  423. // Check for cancellation before tool execution
  424. select {
  425. case <-ctx.Done():
  426. assistantMsg.AddFinish("canceled_by_user")
  427. _ = a.messages.Update(context.Background(), assistantMsg)
  428. return ErrRequestCancelled
  429. default:
  430. }
  431. // Execute any tool calls
  432. toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
  433. if err != nil {
  434. if errors.Is(err, context.Canceled) {
  435. assistantMsg.AddFinish("canceled_by_user")
  436. _ = a.messages.Update(context.Background(), assistantMsg)
  437. return ErrRequestCancelled
  438. }
  439. return fmt.Errorf("tool execution error: %w", err)
  440. }
  441. if err := a.messages.Update(ctx, assistantMsg); err != nil {
  442. return fmt.Errorf("failed to update assistant message: %w", err)
  443. }
  444. // If no tool calls, we're done
  445. if len(assistantMsg.ToolCalls()) == 0 {
  446. break
  447. }
  448. // Add messages for next iteration
  449. messages = append(messages, assistantMsg)
  450. if toolMsg != nil {
  451. messages = append(messages, *toolMsg)
  452. }
  453. // Check for cancellation after tool execution
  454. select {
  455. case <-ctx.Done():
  456. return ErrRequestCancelled
  457. default:
  458. }
  459. }
  460. return nil
  461. }
  462. // getAgentProviders initializes the LLM providers based on the chosen model
  463. func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
  464. maxTokens := config.Get().Model.CoderMaxTokens
  465. providerConfig, ok := config.Get().Providers[model.Provider]
  466. if !ok || providerConfig.Disabled {
  467. return nil, nil, ErrProviderNotEnabled
  468. }
  469. var agentProvider provider.Provider
  470. var titleGenerator provider.Provider
  471. var err error
  472. switch model.Provider {
  473. case models.ProviderOpenAI:
  474. agentProvider, err = provider.NewOpenAIProvider(
  475. provider.WithOpenAISystemMessage(
  476. prompt.CoderOpenAISystemPrompt(),
  477. ),
  478. provider.WithOpenAIMaxTokens(maxTokens),
  479. provider.WithOpenAIModel(model),
  480. provider.WithOpenAIKey(providerConfig.APIKey),
  481. )
  482. if err != nil {
  483. return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
  484. }
  485. titleGenerator, err = provider.NewOpenAIProvider(
  486. provider.WithOpenAISystemMessage(
  487. prompt.TitlePrompt(),
  488. ),
  489. provider.WithOpenAIMaxTokens(80),
  490. provider.WithOpenAIModel(model),
  491. provider.WithOpenAIKey(providerConfig.APIKey),
  492. )
  493. if err != nil {
  494. return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
  495. }
  496. case models.ProviderAnthropic:
  497. agentProvider, err = provider.NewAnthropicProvider(
  498. provider.WithAnthropicSystemMessage(
  499. prompt.CoderAnthropicSystemPrompt(),
  500. ),
  501. provider.WithAnthropicMaxTokens(maxTokens),
  502. provider.WithAnthropicKey(providerConfig.APIKey),
  503. provider.WithAnthropicModel(model),
  504. )
  505. if err != nil {
  506. return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
  507. }
  508. titleGenerator, err = provider.NewAnthropicProvider(
  509. provider.WithAnthropicSystemMessage(
  510. prompt.TitlePrompt(),
  511. ),
  512. provider.WithAnthropicMaxTokens(80),
  513. provider.WithAnthropicKey(providerConfig.APIKey),
  514. provider.WithAnthropicModel(model),
  515. )
  516. if err != nil {
  517. return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
  518. }
  519. case models.ProviderGemini:
  520. agentProvider, err = provider.NewGeminiProvider(
  521. ctx,
  522. provider.WithGeminiSystemMessage(
  523. prompt.CoderOpenAISystemPrompt(),
  524. ),
  525. provider.WithGeminiMaxTokens(int32(maxTokens)),
  526. provider.WithGeminiKey(providerConfig.APIKey),
  527. provider.WithGeminiModel(model),
  528. )
  529. if err != nil {
  530. return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
  531. }
  532. titleGenerator, err = provider.NewGeminiProvider(
  533. ctx,
  534. provider.WithGeminiSystemMessage(
  535. prompt.TitlePrompt(),
  536. ),
  537. provider.WithGeminiMaxTokens(80),
  538. provider.WithGeminiKey(providerConfig.APIKey),
  539. provider.WithGeminiModel(model),
  540. )
  541. if err != nil {
  542. return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
  543. }
  544. case models.ProviderGROQ:
  545. agentProvider, err = provider.NewOpenAIProvider(
  546. provider.WithOpenAISystemMessage(
  547. prompt.CoderAnthropicSystemPrompt(),
  548. ),
  549. provider.WithOpenAIMaxTokens(maxTokens),
  550. provider.WithOpenAIModel(model),
  551. provider.WithOpenAIKey(providerConfig.APIKey),
  552. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  553. )
  554. if err != nil {
  555. return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
  556. }
  557. titleGenerator, err = provider.NewOpenAIProvider(
  558. provider.WithOpenAISystemMessage(
  559. prompt.TitlePrompt(),
  560. ),
  561. provider.WithOpenAIMaxTokens(80),
  562. provider.WithOpenAIModel(model),
  563. provider.WithOpenAIKey(providerConfig.APIKey),
  564. provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  565. )
  566. if err != nil {
  567. return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
  568. }
  569. case models.ProviderBedrock:
  570. agentProvider, err = provider.NewBedrockProvider(
  571. provider.WithBedrockSystemMessage(
  572. prompt.CoderAnthropicSystemPrompt(),
  573. ),
  574. provider.WithBedrockMaxTokens(maxTokens),
  575. provider.WithBedrockModel(model),
  576. )
  577. if err != nil {
  578. return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
  579. }
  580. titleGenerator, err = provider.NewBedrockProvider(
  581. provider.WithBedrockSystemMessage(
  582. prompt.TitlePrompt(),
  583. ),
  584. provider.WithBedrockMaxTokens(80),
  585. provider.WithBedrockModel(model),
  586. )
  587. if err != nil {
  588. return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
  589. }
  590. default:
  591. return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
  592. }
  593. return agentProvider, titleGenerator, nil
  594. }