| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- package provider
- import (
- "context"
- "fmt"
- "log/slog"
- "github.com/sst/opencode/internal/llm/models"
- "github.com/sst/opencode/internal/llm/tools"
- "github.com/sst/opencode/internal/message"
- )
// EventType identifies the kind of event emitted on a provider stream.
type EventType string

// maxRetries caps retry attempts for provider requests.
const maxRetries = 8

// Stream event kinds, in the rough order they occur during a response.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
// TokenUsage holds the per-request token counts reported by a provider,
// including prompt-cache creation and read counts where the backend
// supports caching.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
// ProviderResponse is the complete result of a provider request: the
// generated text, any tool calls the model requested, token accounting,
// and the reason generation stopped.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
// ProviderEvent is a single event delivered on a streaming response
// channel. Which fields are populated depends on Type — e.g. Response is
// set on EventComplete (see StreamResponse's usage logging); presumably
// Error accompanies EventError — confirm against the client implementations.
type ProviderEvent struct {
	Type     EventType
	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the public interface for talking to an LLM backend, either
// synchronously or as a stream of events.
type Provider interface {
	// SendMessages sends the conversation and available tools and blocks
	// until the full response (or an error) is available.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// StreamResponse streams the response as ProviderEvents; the returned
	// channel is closed when the stream ends.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
	// Model reports the model this provider is configured with.
	Model() models.Model
	// MaxTokens reports the configured response token limit.
	MaxTokens() int64
}
// providerClientOptions carries configuration common to all backend
// clients (credentials, model, limits, system prompt) plus per-backend
// option slices consumed by the respective client constructors.
type providerClientOptions struct {
	apiKey           string
	model            models.Model
	maxTokens        int64
	systemMessage    string
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}

// ProviderClientOption mutates providerClientOptions; built via the
// With* helpers below and applied by NewProvider.
type ProviderClientOption func(*providerClientOptions)
// ProviderClient is the internal contract each concrete backend client
// implements; baseProvider wraps it with shared behavior.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
// baseProvider adapts a concrete ProviderClient to the Provider interface,
// adding shared behavior: message cleaning and token-usage logging.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
- func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
- clientOptions := providerClientOptions{}
- for _, o := range opts {
- o(&clientOptions)
- }
- switch providerName {
- case models.ProviderAnthropic:
- return &baseProvider[AnthropicClient]{
- options: clientOptions,
- client: newAnthropicClient(clientOptions),
- }, nil
- case models.ProviderOpenAI:
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderGemini:
- return &baseProvider[GeminiClient]{
- options: clientOptions,
- client: newGeminiClient(clientOptions),
- }, nil
- case models.ProviderBedrock:
- return &baseProvider[BedrockClient]{
- options: clientOptions,
- client: newBedrockClient(clientOptions),
- }, nil
- case models.ProviderGROQ:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderAzure:
- return &baseProvider[AzureClient]{
- options: clientOptions,
- client: newAzureClient(clientOptions),
- }, nil
- case models.ProviderVertexAI:
- return &baseProvider[VertexAIClient]{
- options: clientOptions,
- client: newVertexAIClient(clientOptions),
- }, nil
- case models.ProviderOpenRouter:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
- WithOpenAIExtraHeaders(map[string]string{
- "HTTP-Referer": "opencode.ai",
- "X-Title": "OpenCode",
- }),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderXAI:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.x.ai/v1"),
- )
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
- case models.ProviderMock:
- // TODO: implement mock client for test
- panic("not implemented")
- }
- return nil, fmt.Errorf("provider not supported: %s", providerName)
- }
- func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
- for _, msg := range messages {
- // The message has no content
- if len(msg.Parts) == 0 {
- continue
- }
- cleaned = append(cleaned, msg)
- }
- return
- }
- func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = p.cleanMessages(messages)
- response, err := p.client.send(ctx, messages, tools)
- if err == nil && response != nil {
- slog.Debug("API request token usage",
- "model", p.options.model.Name,
- "input_tokens", response.Usage.InputTokens,
- "output_tokens", response.Usage.OutputTokens,
- "cache_creation_tokens", response.Usage.CacheCreationTokens,
- "cache_read_tokens", response.Usage.CacheReadTokens,
- "total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
- }
- return response, err
- }
// Model returns the model this provider was configured with.
func (p *baseProvider[C]) Model() models.Model {
	return p.options.model
}
// MaxTokens returns the configured response token limit.
func (p *baseProvider[C]) MaxTokens() int64 {
	return p.options.maxTokens
}
- func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- messages = p.cleanMessages(messages)
- eventChan := p.client.stream(ctx, messages, tools)
- // Create a new channel to intercept events
- wrappedChan := make(chan ProviderEvent)
- go func() {
- defer close(wrappedChan)
- for event := range eventChan {
- // Pass the event through
- wrappedChan <- event
- // Log token usage when we get the complete event
- if event.Type == EventComplete && event.Response != nil {
- slog.Debug("API streaming request token usage",
- "model", p.options.model.Name,
- "input_tokens", event.Response.Usage.InputTokens,
- "output_tokens", event.Response.Usage.OutputTokens,
- "cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
- "cache_read_tokens", event.Response.Usage.CacheReadTokens,
- "total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
- }
- }
- }()
- return wrappedChan
- }
- func WithAPIKey(apiKey string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.apiKey = apiKey
- }
- }
- func WithModel(model models.Model) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.model = model
- }
- }
- func WithMaxTokens(maxTokens int64) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.maxTokens = maxTokens
- }
- }
- func WithSystemMessage(systemMessage string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.systemMessage = systemMessage
- }
- }
- func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.anthropicOptions = anthropicOptions
- }
- }
- func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.openaiOptions = openaiOptions
- }
- }
- func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.geminiOptions = geminiOptions
- }
- }
- func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.bedrockOptions = bedrockOptions
- }
- }
|