provider.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package provider
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/opencode-ai/opencode/internal/llm/models"
  6. "github.com/opencode-ai/opencode/internal/llm/tools"
  7. "github.com/opencode-ai/opencode/internal/message"
  8. )
// EventType identifies the kind of streaming event carried by a ProviderEvent.
type EventType string

// maxRetries bounds retry attempts for provider requests.
// NOTE(review): not referenced in this file — presumably used by the
// per-provider clients; confirm before changing.
const maxRetries = 8

// Streaming event types emitted on the channel returned by StreamResponse.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
// TokenUsage reports token accounting for a single provider call,
// including prompt-cache creation and read counts where the backend
// supports them.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
// ProviderResponse is the complete (non-streaming) result of a provider
// call: the generated text, any tool calls the model requested, token
// usage, and the reason generation stopped.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
// ProviderEvent is one event on a streaming response. Type determines
// which of the remaining fields are meaningful (e.g. Content for content
// deltas, ToolCall for tool-use events, Response on EventComplete, Error
// on EventError) — the zero value is used for the rest.
type ProviderEvent struct {
	Type     EventType
	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the public interface for a configured LLM backend.
type Provider interface {
	// SendMessages performs a blocking request and returns the full response.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// StreamResponse returns a channel of incremental ProviderEvents.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
	// Model returns the model this provider was configured with.
	Model() models.Model
}
// providerClientOptions carries configuration common to every backend
// client, plus per-backend option slices consumed by the matching
// newXxxClient constructor.
type providerClientOptions struct {
	apiKey           string
	model            models.Model
	maxTokens        int64
	systemMessage    string
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}
// ProviderClientOption mutates providerClientOptions; see the With*
// constructors below for the available options.
type ProviderClientOption func(*providerClientOptions)

// ProviderClient is the internal per-backend implementation wrapped by
// baseProvider; send and stream mirror Provider.SendMessages and
// Provider.StreamResponse.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
// baseProvider adapts a concrete ProviderClient to the Provider
// interface, applying shared message cleanup before delegating.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
  67. func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
  68. clientOptions := providerClientOptions{}
  69. for _, o := range opts {
  70. o(&clientOptions)
  71. }
  72. switch providerName {
  73. case models.ProviderAnthropic:
  74. return &baseProvider[AnthropicClient]{
  75. options: clientOptions,
  76. client: newAnthropicClient(clientOptions),
  77. }, nil
  78. case models.ProviderOpenAI:
  79. return &baseProvider[OpenAIClient]{
  80. options: clientOptions,
  81. client: newOpenAIClient(clientOptions),
  82. }, nil
  83. case models.ProviderGemini:
  84. return &baseProvider[GeminiClient]{
  85. options: clientOptions,
  86. client: newGeminiClient(clientOptions),
  87. }, nil
  88. case models.ProviderBedrock:
  89. return &baseProvider[BedrockClient]{
  90. options: clientOptions,
  91. client: newBedrockClient(clientOptions),
  92. }, nil
  93. case models.ProviderGROQ:
  94. clientOptions.openaiOptions = append(clientOptions.openaiOptions,
  95. WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
  96. )
  97. return &baseProvider[OpenAIClient]{
  98. options: clientOptions,
  99. client: newOpenAIClient(clientOptions),
  100. }, nil
  101. case models.ProviderAzure:
  102. return &baseProvider[AzureClient]{
  103. options: clientOptions,
  104. client: newAzureClient(clientOptions),
  105. }, nil
  106. case models.ProviderOpenRouter:
  107. clientOptions.openaiOptions = append(clientOptions.openaiOptions,
  108. WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
  109. WithOpenAIExtraHeaders(map[string]string{
  110. "HTTP-Referer": "opencode.ai",
  111. "X-Title": "OpenCode",
  112. }),
  113. )
  114. return &baseProvider[OpenAIClient]{
  115. options: clientOptions,
  116. client: newOpenAIClient(clientOptions),
  117. }, nil
  118. case models.ProviderMock:
  119. // TODO: implement mock client for test
  120. panic("not implemented")
  121. }
  122. return nil, fmt.Errorf("provider not supported: %s", providerName)
  123. }
  124. func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
  125. for _, msg := range messages {
  126. // The message has no content
  127. if len(msg.Parts) == 0 {
  128. continue
  129. }
  130. cleaned = append(cleaned, msg)
  131. }
  132. return
  133. }
  134. func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  135. messages = p.cleanMessages(messages)
  136. return p.client.send(ctx, messages, tools)
  137. }
// Model returns the model this provider was configured with.
func (p *baseProvider[C]) Model() models.Model {
	return p.options.model
}
  141. func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  142. messages = p.cleanMessages(messages)
  143. return p.client.stream(ctx, messages, tools)
  144. }
  145. func WithAPIKey(apiKey string) ProviderClientOption {
  146. return func(options *providerClientOptions) {
  147. options.apiKey = apiKey
  148. }
  149. }
  150. func WithModel(model models.Model) ProviderClientOption {
  151. return func(options *providerClientOptions) {
  152. options.model = model
  153. }
  154. }
  155. func WithMaxTokens(maxTokens int64) ProviderClientOption {
  156. return func(options *providerClientOptions) {
  157. options.maxTokens = maxTokens
  158. }
  159. }
  160. func WithSystemMessage(systemMessage string) ProviderClientOption {
  161. return func(options *providerClientOptions) {
  162. options.systemMessage = systemMessage
  163. }
  164. }
  165. func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
  166. return func(options *providerClientOptions) {
  167. options.anthropicOptions = anthropicOptions
  168. }
  169. }
  170. func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
  171. return func(options *providerClientOptions) {
  172. options.openaiOptions = openaiOptions
  173. }
  174. }
  175. func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
  176. return func(options *providerClientOptions) {
  177. options.geminiOptions = geminiOptions
  178. }
  179. }
  180. func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
  181. return func(options *providerClientOptions) {
  182. options.bedrockOptions = bedrockOptions
  183. }
  184. }