bedrock.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package provider
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "strings"
  8. "github.com/sst/opencode/internal/llm/tools"
  9. "github.com/sst/opencode/internal/message"
  10. )
  11. type bedrockOptions struct {
  12. // Bedrock specific options can be added here
  13. }
  14. type BedrockOption func(*bedrockOptions)
  15. type bedrockClient struct {
  16. providerOptions providerClientOptions
  17. options bedrockOptions
  18. childProvider ProviderClient
  19. }
  20. type BedrockClient ProviderClient
  21. func newBedrockClient(opts providerClientOptions) BedrockClient {
  22. bedrockOpts := bedrockOptions{}
  23. // Apply bedrock specific options if they are added in the future
  24. // Get AWS region from environment
  25. region := os.Getenv("AWS_REGION")
  26. if region == "" {
  27. region = os.Getenv("AWS_DEFAULT_REGION")
  28. }
  29. if region == "" {
  30. region = "us-east-1" // default region
  31. }
  32. if len(region) < 2 {
  33. return &bedrockClient{
  34. providerOptions: opts,
  35. options: bedrockOpts,
  36. childProvider: nil, // Will cause an error when used
  37. }
  38. }
  39. // Prefix the model name with region
  40. regionPrefix := region[:2]
  41. modelName := opts.model.APIModel
  42. opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
  43. // Determine which provider to use based on the model
  44. if strings.Contains(string(opts.model.APIModel), "anthropic") {
  45. // Create Anthropic client with Bedrock configuration
  46. anthropicOpts := opts
  47. anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
  48. WithAnthropicBedrock(true),
  49. WithAnthropicDisableCache(),
  50. )
  51. return &bedrockClient{
  52. providerOptions: opts,
  53. options: bedrockOpts,
  54. childProvider: newAnthropicClient(anthropicOpts),
  55. }
  56. }
  57. // Return client with nil childProvider if model is not supported
  58. // This will cause an error when used
  59. return &bedrockClient{
  60. providerOptions: opts,
  61. options: bedrockOpts,
  62. childProvider: nil,
  63. }
  64. }
  65. func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
  66. if b.childProvider == nil {
  67. return nil, errors.New("unsupported model for bedrock provider")
  68. }
  69. return b.childProvider.send(ctx, messages, tools)
  70. }
  71. func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  72. eventChan := make(chan ProviderEvent)
  73. if b.childProvider == nil {
  74. go func() {
  75. eventChan <- ProviderEvent{
  76. Type: EventError,
  77. Error: errors.New("unsupported model for bedrock provider"),
  78. }
  79. close(eventChan)
  80. }()
  81. return eventChan
  82. }
  83. return b.childProvider.stream(ctx, messages, tools)
  84. }