|
|
@@ -2,21 +2,14 @@ package provider
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
- "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
- "time"
|
|
|
-
|
|
|
+ "log/slog"
|
|
|
"github.com/openai/openai-go"
|
|
|
"github.com/openai/openai-go/option"
|
|
|
- "github.com/openai/openai-go/shared"
|
|
|
- "github.com/sst/opencode/internal/config"
|
|
|
"github.com/sst/opencode/internal/llm/models"
|
|
|
"github.com/sst/opencode/internal/llm/tools"
|
|
|
"github.com/sst/opencode/internal/message"
|
|
|
- "github.com/sst/opencode/internal/status"
|
|
|
- "log/slog"
|
|
|
)
|
|
|
|
|
|
type openaiOptions struct {
|
|
|
@@ -66,87 +59,21 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
|
|
|
- // Add system message first
|
|
|
- openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
|
|
|
-
|
|
|
- for _, msg := range messages {
|
|
|
- switch msg.Role {
|
|
|
- case message.User:
|
|
|
- var content []openai.ChatCompletionContentPartUnionParam
|
|
|
- textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
|
|
|
- content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
|
|
|
- for _, binaryContent := range msg.BinaryContent() {
|
|
|
- imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
|
|
|
- imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
|
|
|
-
|
|
|
- content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
|
|
|
- }
|
|
|
-
|
|
|
- openaiMessages = append(openaiMessages, openai.UserMessage(content))
|
|
|
-
|
|
|
- case message.Assistant:
|
|
|
- assistantMsg := openai.ChatCompletionAssistantMessageParam{
|
|
|
- Role: "assistant",
|
|
|
- }
|
|
|
-
|
|
|
- if msg.Content().String() != "" {
|
|
|
- assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
|
|
|
- OfString: openai.String(msg.Content().String()),
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if len(msg.ToolCalls()) > 0 {
|
|
|
- assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
|
|
|
- for i, call := range msg.ToolCalls() {
|
|
|
- assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
|
|
|
- ID: call.ID,
|
|
|
- Type: "function",
|
|
|
- Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
|
|
- Name: call.Name,
|
|
|
- Arguments: call.Input,
|
|
|
- },
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
|
|
|
- OfAssistant: &assistantMsg,
|
|
|
- })
|
|
|
-
|
|
|
- case message.Tool:
|
|
|
- for _, result := range msg.ToolResults() {
|
|
|
- openaiMessages = append(openaiMessages,
|
|
|
- openai.ToolMessage(result.Content, result.ToolCallID),
|
|
|
- )
|
|
|
- }
|
|
|
- }
|
|
|
+func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
|
|
|
+ if o.providerOptions.model.ID == models.OpenAIModels[models.CodexMini].ID || o.providerOptions.model.ID == models.OpenAIModels[models.O1Pro].ID {
|
|
|
+ return o.sendResponseMessages(ctx, messages, tools)
|
|
|
}
|
|
|
-
|
|
|
- return
|
|
|
+ return o.sendChatcompletionMessage(ctx, messages, tools)
|
|
|
}
|
|
|
|
|
|
-func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
|
|
|
- openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
|
|
|
-
|
|
|
- for i, tool := range tools {
|
|
|
- info := tool.Info()
|
|
|
- openaiTools[i] = openai.ChatCompletionToolParam{
|
|
|
- Function: openai.FunctionDefinitionParam{
|
|
|
- Name: info.Name,
|
|
|
- Description: openai.String(info.Description),
|
|
|
- Parameters: openai.FunctionParameters{
|
|
|
- "type": "object",
|
|
|
- "properties": info.Parameters,
|
|
|
- "required": info.Required,
|
|
|
- },
|
|
|
- },
|
|
|
- }
|
|
|
+func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
|
|
+ if o.providerOptions.model.ID == models.OpenAIModels[models.CodexMini].ID || o.providerOptions.model.ID == models.OpenAIModels[models.O1Pro].ID {
|
|
|
+ return o.streamResponseMessages(ctx, messages, tools)
|
|
|
}
|
|
|
-
|
|
|
- return openaiTools
|
|
|
+ return o.streamChatCompletionMessages(ctx, messages, tools)
|
|
|
}
|
|
|
|
|
|
+
|
|
|
func (o *openaiClient) finishReason(reason string) message.FinishReason {
|
|
|
switch reason {
|
|
|
case "stop":
|
|
|
@@ -160,190 +87,6 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
|
|
|
- params := openai.ChatCompletionNewParams{
|
|
|
- Model: openai.ChatModel(o.providerOptions.model.APIModel),
|
|
|
- Messages: messages,
|
|
|
- Tools: tools,
|
|
|
- }
|
|
|
-
|
|
|
- if o.providerOptions.model.CanReason == true {
|
|
|
- params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
|
|
|
- switch o.options.reasoningEffort {
|
|
|
- case "low":
|
|
|
- params.ReasoningEffort = shared.ReasoningEffortLow
|
|
|
- case "medium":
|
|
|
- params.ReasoningEffort = shared.ReasoningEffortMedium
|
|
|
- case "high":
|
|
|
- params.ReasoningEffort = shared.ReasoningEffortHigh
|
|
|
- default:
|
|
|
- params.ReasoningEffort = shared.ReasoningEffortMedium
|
|
|
- }
|
|
|
- } else {
|
|
|
- params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
|
|
|
- }
|
|
|
-
|
|
|
- if o.providerOptions.model.Provider == models.ProviderOpenRouter {
|
|
|
- params.WithExtraFields(map[string]any{
|
|
|
- "provider": map[string]any{
|
|
|
- "require_parameters": true,
|
|
|
- },
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
- return params
|
|
|
-}
|
|
|
-
|
|
|
-func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
|
|
|
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
|
|
|
- cfg := config.Get()
|
|
|
- if cfg.Debug {
|
|
|
- jsonData, _ := json.Marshal(params)
|
|
|
- slog.Debug("Prepared messages", "messages", string(jsonData))
|
|
|
- }
|
|
|
- attempts := 0
|
|
|
- for {
|
|
|
- attempts++
|
|
|
- openaiResponse, err := o.client.Chat.Completions.New(
|
|
|
- ctx,
|
|
|
- params,
|
|
|
- )
|
|
|
- // If there is an error we are going to see if we can retry the call
|
|
|
- if err != nil {
|
|
|
- retry, after, retryErr := o.shouldRetry(attempts, err)
|
|
|
- duration := time.Duration(after) * time.Millisecond
|
|
|
- if retryErr != nil {
|
|
|
- return nil, retryErr
|
|
|
- }
|
|
|
- if retry {
|
|
|
- status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- return nil, ctx.Err()
|
|
|
- case <-time.After(duration):
|
|
|
- continue
|
|
|
- }
|
|
|
- }
|
|
|
- return nil, retryErr
|
|
|
- }
|
|
|
-
|
|
|
- content := ""
|
|
|
- if openaiResponse.Choices[0].Message.Content != "" {
|
|
|
- content = openaiResponse.Choices[0].Message.Content
|
|
|
- }
|
|
|
-
|
|
|
- toolCalls := o.toolCalls(*openaiResponse)
|
|
|
- finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
|
|
|
-
|
|
|
- if len(toolCalls) > 0 {
|
|
|
- finishReason = message.FinishReasonToolUse
|
|
|
- }
|
|
|
-
|
|
|
- return &ProviderResponse{
|
|
|
- Content: content,
|
|
|
- ToolCalls: toolCalls,
|
|
|
- Usage: o.usage(*openaiResponse),
|
|
|
- FinishReason: finishReason,
|
|
|
- }, nil
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
|
|
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
|
|
|
- params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
|
|
|
- IncludeUsage: openai.Bool(true),
|
|
|
- }
|
|
|
-
|
|
|
- cfg := config.Get()
|
|
|
- if cfg.Debug {
|
|
|
- jsonData, _ := json.Marshal(params)
|
|
|
- slog.Debug("Prepared messages", "messages", string(jsonData))
|
|
|
- }
|
|
|
-
|
|
|
- attempts := 0
|
|
|
- eventChan := make(chan ProviderEvent)
|
|
|
-
|
|
|
- go func() {
|
|
|
- for {
|
|
|
- attempts++
|
|
|
- openaiStream := o.client.Chat.Completions.NewStreaming(
|
|
|
- ctx,
|
|
|
- params,
|
|
|
- )
|
|
|
-
|
|
|
- acc := openai.ChatCompletionAccumulator{}
|
|
|
- currentContent := ""
|
|
|
- toolCalls := make([]message.ToolCall, 0)
|
|
|
-
|
|
|
- for openaiStream.Next() {
|
|
|
- chunk := openaiStream.Current()
|
|
|
- acc.AddChunk(chunk)
|
|
|
-
|
|
|
- for _, choice := range chunk.Choices {
|
|
|
- if choice.Delta.Content != "" {
|
|
|
- eventChan <- ProviderEvent{
|
|
|
- Type: EventContentDelta,
|
|
|
- Content: choice.Delta.Content,
|
|
|
- }
|
|
|
- currentContent += choice.Delta.Content
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- err := openaiStream.Err()
|
|
|
- if err == nil || errors.Is(err, io.EOF) {
|
|
|
- // Stream completed successfully
|
|
|
- finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
|
|
|
- if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
|
|
|
- toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
|
|
|
- }
|
|
|
- if len(toolCalls) > 0 {
|
|
|
- finishReason = message.FinishReasonToolUse
|
|
|
- }
|
|
|
-
|
|
|
- eventChan <- ProviderEvent{
|
|
|
- Type: EventComplete,
|
|
|
- Response: &ProviderResponse{
|
|
|
- Content: currentContent,
|
|
|
- ToolCalls: toolCalls,
|
|
|
- Usage: o.usage(acc.ChatCompletion),
|
|
|
- FinishReason: finishReason,
|
|
|
- },
|
|
|
- }
|
|
|
- close(eventChan)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // If there is an error we are going to see if we can retry the call
|
|
|
- retry, after, retryErr := o.shouldRetry(attempts, err)
|
|
|
- duration := time.Duration(after) * time.Millisecond
|
|
|
- if retryErr != nil {
|
|
|
- eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
|
|
|
- close(eventChan)
|
|
|
- return
|
|
|
- }
|
|
|
- if retry {
|
|
|
- status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
|
|
|
- select {
|
|
|
- case <-ctx.Done():
|
|
|
- // context cancelled
|
|
|
- if ctx.Err() == nil {
|
|
|
- eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
|
|
|
- }
|
|
|
- close(eventChan)
|
|
|
- return
|
|
|
- case <-time.After(duration):
|
|
|
- continue
|
|
|
- }
|
|
|
- }
|
|
|
- eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
|
|
|
- close(eventChan)
|
|
|
- return
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- return eventChan
|
|
|
-}
|
|
|
|
|
|
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
|
|
|
var apierr *openai.Error
|
|
|
@@ -373,36 +116,6 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
|
|
|
return true, int64(retryMs), nil
|
|
|
}
|
|
|
|
|
|
-func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
|
|
|
- var toolCalls []message.ToolCall
|
|
|
-
|
|
|
- if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
|
|
|
- for _, call := range completion.Choices[0].Message.ToolCalls {
|
|
|
- toolCall := message.ToolCall{
|
|
|
- ID: call.ID,
|
|
|
- Name: call.Function.Name,
|
|
|
- Input: call.Function.Arguments,
|
|
|
- Type: "function",
|
|
|
- Finished: true,
|
|
|
- }
|
|
|
- toolCalls = append(toolCalls, toolCall)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return toolCalls
|
|
|
-}
|
|
|
-
|
|
|
-func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
|
|
|
- cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
|
|
|
- inputTokens := completion.Usage.PromptTokens - cachedTokens
|
|
|
-
|
|
|
- return TokenUsage{
|
|
|
- InputTokens: inputTokens,
|
|
|
- OutputTokens: completion.Usage.CompletionTokens,
|
|
|
- CacheCreationTokens: 0, // OpenAI doesn't provide this directly
|
|
|
- CacheReadTokens: cachedTokens,
|
|
|
- }
|
|
|
-}
|
|
|
|
|
|
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
|
|
|
return func(options *openaiOptions) {
|