// Package provider abstracts access to multiple LLM backends behind a
// single Provider interface.
  1. package provider
  2. import (
  3. "context"
  4. "fmt"
  5. "log/slog"
  6. "github.com/sst/opencode/internal/llm/models"
  7. "github.com/sst/opencode/internal/llm/tools"
  8. "github.com/sst/opencode/internal/message"
  9. )
  10. type EventType string
  11. const maxRetries = 8
  12. const (
  13. EventContentStart EventType = "content_start"
  14. EventToolUseStart EventType = "tool_use_start"
  15. EventToolUseDelta EventType = "tool_use_delta"
  16. EventToolUseStop EventType = "tool_use_stop"
  17. EventContentDelta EventType = "content_delta"
  18. EventThinkingDelta EventType = "thinking_delta"
  19. EventContentStop EventType = "content_stop"
  20. EventComplete EventType = "complete"
  21. EventError EventType = "error"
  22. EventWarning EventType = "warning"
  23. )
  24. type TokenUsage struct {
  25. InputTokens int64
  26. OutputTokens int64
  27. CacheCreationTokens int64
  28. CacheReadTokens int64
  29. }
  30. type ProviderResponse struct {
  31. Content string
  32. ToolCalls []message.ToolCall
  33. Usage TokenUsage
  34. FinishReason message.FinishReason
  35. }
  36. type ProviderEvent struct {
  37. Type EventType
  38. Content string
  39. Thinking string
  40. Response *ProviderResponse
  41. ToolCall *message.ToolCall
  42. Error error
  43. }
  44. type Provider interface {
  45. SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
  46. StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
  47. Model() models.Model
  48. MaxTokens() int64
  49. }
  50. type providerClientOptions struct {
  51. apiKey string
  52. model models.Model
  53. maxTokens int64
  54. systemMessage string
  55. anthropicOptions []AnthropicOption
  56. openaiOptions []OpenAIOption
  57. geminiOptions []GeminiOption
  58. bedrockOptions []BedrockOption
  59. }
  60. type ProviderClientOption func(*providerClientOptions)
  61. type ProviderClient interface {
  62. send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
  63. stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
  64. }
  65. type baseProvider[C ProviderClient] struct {
  66. options providerClientOptions
  67. client C
  68. }
  69. func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
  70. clientOptions := providerClientOptions{}
  71. for _, o := range opts {
  72. o(&clientOptions)
  73. }
  74. switch providerName {
  75. case models.ProviderAnthropic:
  76. return &baseProvider[AnthropicClient]{
  77. options: clientOptions,
  78. client: newAnthropicClient(clientOptions),
  79. }, nil
  80. case models.ProviderOpenAI:
  81. return &baseProvider[OpenAIClient]{
  82. options: clientOptions,
  83. client: newOpenAIClient(clientOptions),
  84. }, nil
  85. case models.ProviderGemini:
  86. return &baseProvider[GeminiClient]{
  87. options: clientOptions,
  88. client: newGeminiClient(clientOptions),
  89. }, nil
  90. case models.ProviderBedrock:
  91. return &baseProvider[BedrockClient]{
  92. options: clientOptions,
  93. client: newBedrockClient(clientOptions),
  94. }, nil
  95. case models.ProviderGROQ:
  96. clientOptions.openaiOptions = append(clientOptions.openaiOptions,
  97. WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  98. )
  99. return &baseProvider[OpenAIClient]{
  100. options: clientOptions,
  101. client: newOpenAIClient(clientOptions),
  102. }, nil
  103. case models.ProviderAzure:
  104. return &baseProvider[AzureClient]{
  105. options: clientOptions,
  106. client: newAzureClient(clientOptions),
  107. }, nil
  108. case models.ProviderVertexAI:
  109. return &baseProvider[VertexAIClient]{
  110. options: clientOptions,
  111. client: newVertexAIClient(clientOptions),
  112. }, nil
  113. case models.ProviderOpenRouter:
  114. clientOptions.openaiOptions = append(clientOptions.openaiOptions,
  115. WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
  116. WithOpenAIExtraHeaders(map[string]string{
  117. "HTTP-Referer": "opencode.ai",
  118. "X-Title": "OpenCode",
  119. }),
  120. )
  121. return &baseProvider[OpenAIClient]{
  122. options: clientOptions,
  123. client: newOpenAIClient(clientOptions),
  124. }, nil
  125. case models.ProviderXAI:
  126. clientOptions.openaiOptions = append(clientOptions.openaiOptions,
  127. WithOpenAIBaseURL("https://api.x.ai/v1"),
  128. )
  129. return &baseProvider[OpenAIClient]{
  130. options: clientOptions,
  131. client: newOpenAIClient(clientOptions),
  132. }, nil
  133. case models.ProviderMock:
  134. // TODO: implement mock client for test
  135. panic("not implemented")
  136. }
  137. return nil, fmt.Errorf("provider not supported: %s", providerName)
  138. }
  139. func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
  140. for _, msg := range messages {
  141. // The message has no content
  142. if len(msg.Parts) == 0 {
  143. continue
  144. }
  145. cleaned = append(cleaned, msg)
  146. }
  147. return
  148. }
  149. func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  150. messages = p.cleanMessages(messages)
  151. response, err := p.client.send(ctx, messages, tools)
  152. if err == nil && response != nil {
  153. slog.Debug("API request token usage",
  154. "model", p.options.model.Name,
  155. "input_tokens", response.Usage.InputTokens,
  156. "output_tokens", response.Usage.OutputTokens,
  157. "cache_creation_tokens", response.Usage.CacheCreationTokens,
  158. "cache_read_tokens", response.Usage.CacheReadTokens,
  159. "total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
  160. }
  161. return response, err
  162. }
  163. func (p *baseProvider[C]) Model() models.Model {
  164. return p.options.model
  165. }
  166. func (p *baseProvider[C]) MaxTokens() int64 {
  167. return p.options.maxTokens
  168. }
  169. func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  170. messages = p.cleanMessages(messages)
  171. eventChan := p.client.stream(ctx, messages, tools)
  172. // Create a new channel to intercept events
  173. wrappedChan := make(chan ProviderEvent)
  174. go func() {
  175. defer close(wrappedChan)
  176. for event := range eventChan {
  177. // Pass the event through
  178. wrappedChan <- event
  179. // Log token usage when we get the complete event
  180. if event.Type == EventComplete && event.Response != nil {
  181. slog.Debug("API streaming request token usage",
  182. "model", p.options.model.Name,
  183. "input_tokens", event.Response.Usage.InputTokens,
  184. "output_tokens", event.Response.Usage.OutputTokens,
  185. "cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
  186. "cache_read_tokens", event.Response.Usage.CacheReadTokens,
  187. "total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
  188. }
  189. }
  190. }()
  191. return wrappedChan
  192. }
  193. func WithAPIKey(apiKey string) ProviderClientOption {
  194. return func(options *providerClientOptions) {
  195. options.apiKey = apiKey
  196. }
  197. }
  198. func WithModel(model models.Model) ProviderClientOption {
  199. return func(options *providerClientOptions) {
  200. options.model = model
  201. }
  202. }
  203. func WithMaxTokens(maxTokens int64) ProviderClientOption {
  204. return func(options *providerClientOptions) {
  205. options.maxTokens = maxTokens
  206. }
  207. }
  208. func WithSystemMessage(systemMessage string) ProviderClientOption {
  209. return func(options *providerClientOptions) {
  210. options.systemMessage = systemMessage
  211. }
  212. }
  213. func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
  214. return func(options *providerClientOptions) {
  215. options.anthropicOptions = anthropicOptions
  216. }
  217. }
  218. func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
  219. return func(options *providerClientOptions) {
  220. options.openaiOptions = openaiOptions
  221. }
  222. }
  223. func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
  224. return func(options *providerClientOptions) {
  225. options.geminiOptions = geminiOptions
  226. }
  227. }
  228. func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
  229. return func(options *providerClientOptions) {
  230. options.bedrockOptions = bedrockOptions
  231. }
  232. }