| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- package provider
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "time"
- "github.com/opencode-ai/opencode/internal/config"
- "github.com/opencode-ai/opencode/internal/llm/tools"
- "github.com/opencode-ai/opencode/internal/logging"
- "github.com/opencode-ai/opencode/internal/message"
- "github.com/openai/openai-go"
- "github.com/openai/openai-go/option"
- "github.com/openai/openai-go/shared"
- )
- type openaiOptions struct {
- baseURL string
- disableCache bool
- reasoningEffort string
- }
- type OpenAIOption func(*openaiOptions)
- type openaiClient struct {
- providerOptions providerClientOptions
- options openaiOptions
- client openai.Client
- }
- type OpenAIClient ProviderClient
- func newOpenAIClient(opts providerClientOptions) OpenAIClient {
- openaiOpts := openaiOptions{
- reasoningEffort: "medium",
- }
- for _, o := range opts.openaiOptions {
- o(&openaiOpts)
- }
- openaiClientOptions := []option.RequestOption{}
- if opts.apiKey != "" {
- openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
- }
- if openaiOpts.baseURL != "" {
- openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
- }
- client := openai.NewClient(openaiClientOptions...)
- return &openaiClient{
- providerOptions: opts,
- options: openaiOpts,
- client: client,
- }
- }
- func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
- // Add system message first
- openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
- for _, msg := range messages {
- switch msg.Role {
- case message.User:
- openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
- case message.Assistant:
- assistantMsg := openai.ChatCompletionAssistantMessageParam{
- Role: "assistant",
- }
- if msg.Content().String() != "" {
- assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
- OfString: openai.String(msg.Content().String()),
- }
- }
- if len(msg.ToolCalls()) > 0 {
- assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
- for i, call := range msg.ToolCalls() {
- assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
- ID: call.ID,
- Type: "function",
- Function: openai.ChatCompletionMessageToolCallFunctionParam{
- Name: call.Name,
- Arguments: call.Input,
- },
- }
- }
- }
- openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
- OfAssistant: &assistantMsg,
- })
- case message.Tool:
- for _, result := range msg.ToolResults() {
- openaiMessages = append(openaiMessages,
- openai.ToolMessage(result.Content, result.ToolCallID),
- )
- }
- }
- }
- return
- }
- func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
- openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
- for i, tool := range tools {
- info := tool.Info()
- openaiTools[i] = openai.ChatCompletionToolParam{
- Function: openai.FunctionDefinitionParam{
- Name: info.Name,
- Description: openai.String(info.Description),
- Parameters: openai.FunctionParameters{
- "type": "object",
- "properties": info.Parameters,
- "required": info.Required,
- },
- },
- }
- }
- return openaiTools
- }
- func (o *openaiClient) finishReason(reason string) message.FinishReason {
- switch reason {
- case "stop":
- return message.FinishReasonEndTurn
- case "length":
- return message.FinishReasonMaxTokens
- case "tool_calls":
- return message.FinishReasonToolUse
- default:
- return message.FinishReasonUnknown
- }
- }
- func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
- params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(o.providerOptions.model.APIModel),
- Messages: messages,
- Tools: tools,
- }
- if o.providerOptions.model.CanReason == true {
- params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
- switch o.options.reasoningEffort {
- case "low":
- params.ReasoningEffort = shared.ReasoningEffortLow
- case "medium":
- params.ReasoningEffort = shared.ReasoningEffortMedium
- case "high":
- params.ReasoningEffort = shared.ReasoningEffortHigh
- default:
- params.ReasoningEffort = shared.ReasoningEffortMedium
- }
- } else {
- params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
- }
- return params
- }
- func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
- cfg := config.Get()
- if cfg.Debug {
- jsonData, _ := json.Marshal(params)
- logging.Debug("Prepared messages", "messages", string(jsonData))
- }
- attempts := 0
- for {
- attempts++
- openaiResponse, err := o.client.Chat.Completions.New(
- ctx,
- params,
- )
- // If there is an error we are going to see if we can retry the call
- if err != nil {
- retry, after, retryErr := o.shouldRetry(attempts, err)
- if retryErr != nil {
- return nil, retryErr
- }
- if retry {
- logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(time.Duration(after) * time.Millisecond):
- continue
- }
- }
- return nil, retryErr
- }
- content := ""
- if openaiResponse.Choices[0].Message.Content != "" {
- content = openaiResponse.Choices[0].Message.Content
- }
- return &ProviderResponse{
- Content: content,
- ToolCalls: o.toolCalls(*openaiResponse),
- Usage: o.usage(*openaiResponse),
- FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
- }, nil
- }
- }
- func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
- params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
- IncludeUsage: openai.Bool(true),
- }
- cfg := config.Get()
- if cfg.Debug {
- jsonData, _ := json.Marshal(params)
- logging.Debug("Prepared messages", "messages", string(jsonData))
- }
- attempts := 0
- eventChan := make(chan ProviderEvent)
- go func() {
- for {
- attempts++
- openaiStream := o.client.Chat.Completions.NewStreaming(
- ctx,
- params,
- )
- acc := openai.ChatCompletionAccumulator{}
- currentContent := ""
- toolCalls := make([]message.ToolCall, 0)
- for openaiStream.Next() {
- chunk := openaiStream.Current()
- acc.AddChunk(chunk)
- if tool, ok := acc.JustFinishedToolCall(); ok {
- toolCalls = append(toolCalls, message.ToolCall{
- ID: tool.Id,
- Name: tool.Name,
- Input: tool.Arguments,
- Type: "function",
- })
- }
- for _, choice := range chunk.Choices {
- if choice.Delta.Content != "" {
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: choice.Delta.Content,
- }
- currentContent += choice.Delta.Content
- }
- }
- }
- err := openaiStream.Err()
- if err == nil || errors.Is(err, io.EOF) {
- // Stream completed successfully
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: o.usage(acc.ChatCompletion),
- FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
- },
- }
- close(eventChan)
- return
- }
- // If there is an error we are going to see if we can retry the call
- retry, after, retryErr := o.shouldRetry(attempts, err)
- if retryErr != nil {
- eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
- close(eventChan)
- return
- }
- if retry {
- logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
- select {
- case <-ctx.Done():
- // context cancelled
- if ctx.Err() == nil {
- eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
- }
- close(eventChan)
- return
- case <-time.After(time.Duration(after) * time.Millisecond):
- continue
- }
- }
- eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
- close(eventChan)
- return
- }
- }()
- return eventChan
- }
- func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- var apierr *openai.Error
- if !errors.As(err, &apierr) {
- return false, 0, err
- }
- if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
- return false, 0, err
- }
- if attempts > maxRetries {
- return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
- }
- retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
- backoffMs := 2000 * (1 << (attempts - 1))
- jitterMs := int(float64(backoffMs) * 0.2)
- retryMs = backoffMs + jitterMs
- if len(retryAfterValues) > 0 {
- if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
- retryMs = retryMs * 1000
- }
- }
- return true, int64(retryMs), nil
- }
- func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
- var toolCalls []message.ToolCall
- if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
- for _, call := range completion.Choices[0].Message.ToolCalls {
- toolCall := message.ToolCall{
- ID: call.ID,
- Name: call.Function.Name,
- Input: call.Function.Arguments,
- Type: "function",
- Finished: true,
- }
- toolCalls = append(toolCalls, toolCall)
- }
- }
- return toolCalls
- }
- func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
- cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
- inputTokens := completion.Usage.PromptTokens - cachedTokens
- return TokenUsage{
- InputTokens: inputTokens,
- OutputTokens: completion.Usage.CompletionTokens,
- CacheCreationTokens: 0, // OpenAI doesn't provide this directly
- CacheReadTokens: cachedTokens,
- }
- }
- func WithOpenAIBaseURL(baseURL string) OpenAIOption {
- return func(options *openaiOptions) {
- options.baseURL = baseURL
- }
- }
- func WithOpenAIDisableCache() OpenAIOption {
- return func(options *openaiOptions) {
- options.disableCache = true
- }
- }
- func WithReasoningEffort(effort string) OpenAIOption {
- return func(options *openaiOptions) {
- defaultReasoningEffort := "medium"
- switch effort {
- case "low", "medium", "high":
- defaultReasoningEffort = effort
- default:
- logging.Warn("Invalid reasoning effort, using default: medium")
- }
- options.reasoningEffort = defaultReasoningEffort
- }
- }
|