openai.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
package provider

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"strconv"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"

	"github.com/sst/opencode/internal/llm/models"
	"github.com/sst/opencode/internal/llm/tools"
	"github.com/sst/opencode/internal/message"
)
// openaiOptions holds provider-specific configuration collected from
// OpenAIOption functional options before the client is constructed.
type openaiOptions struct {
	// baseURL, when non-empty, overrides the SDK's default API endpoint.
	baseURL string
	// disableCache is set by WithOpenAIDisableCache; consumers elsewhere
	// in the package decide what caching it disables.
	disableCache bool
	// reasoningEffort is one of "low", "medium", or "high";
	// newOpenAIClient defaults it to "medium".
	reasoningEffort string
	// extraHeaders are additional HTTP headers attached to every request.
	extraHeaders map[string]string
}
  19. type OpenAIOption func(*openaiOptions)
// openaiClient implements ProviderClient on top of the official
// openai-go SDK.
type openaiClient struct {
	// providerOptions carries the provider-agnostic settings (model,
	// API key, and the raw OpenAIOption list).
	providerOptions providerClientOptions
	// options is the resolved OpenAI-specific configuration.
	options openaiOptions
	// client is the configured openai-go SDK client.
	client openai.Client
}
  25. type OpenAIClient ProviderClient
  26. func newOpenAIClient(opts providerClientOptions) OpenAIClient {
  27. openaiOpts := openaiOptions{
  28. reasoningEffort: "medium",
  29. }
  30. for _, o := range opts.openaiOptions {
  31. o(&openaiOpts)
  32. }
  33. openaiClientOptions := []option.RequestOption{}
  34. if opts.apiKey != "" {
  35. openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
  36. }
  37. if openaiOpts.baseURL != "" {
  38. openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
  39. }
  40. if openaiOpts.extraHeaders != nil {
  41. for key, value := range openaiOpts.extraHeaders {
  42. openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
  43. }
  44. }
  45. client := openai.NewClient(openaiClientOptions...)
  46. return &openaiClient{
  47. providerOptions: opts,
  48. options: openaiOpts,
  49. client: client,
  50. }
  51. }
  52. func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
  53. if o.providerOptions.model.ID == models.OpenAIModels[models.CodexMini].ID || o.providerOptions.model.ID == models.OpenAIModels[models.O1Pro].ID {
  54. return o.sendResponseMessages(ctx, messages, tools)
  55. }
  56. return o.sendChatcompletionMessage(ctx, messages, tools)
  57. }
  58. func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  59. if o.providerOptions.model.ID == models.OpenAIModels[models.CodexMini].ID || o.providerOptions.model.ID == models.OpenAIModels[models.O1Pro].ID {
  60. return o.streamResponseMessages(ctx, messages, tools)
  61. }
  62. return o.streamChatCompletionMessages(ctx, messages, tools)
  63. }
  64. func (o *openaiClient) finishReason(reason string) message.FinishReason {
  65. switch reason {
  66. case "stop":
  67. return message.FinishReasonEndTurn
  68. case "length":
  69. return message.FinishReasonMaxTokens
  70. case "tool_calls":
  71. return message.FinishReasonToolUse
  72. default:
  73. return message.FinishReasonUnknown
  74. }
  75. }
  76. func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
  77. var apierr *openai.Error
  78. if !errors.As(err, &apierr) {
  79. return false, 0, err
  80. }
  81. if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
  82. return false, 0, err
  83. }
  84. if attempts > maxRetries {
  85. return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
  86. }
  87. retryMs := 0
  88. retryAfterValues := apierr.Response.Header.Values("Retry-After")
  89. backoffMs := 2000 * (1 << (attempts - 1))
  90. jitterMs := int(float64(backoffMs) * 0.2)
  91. retryMs = backoffMs + jitterMs
  92. if len(retryAfterValues) > 0 {
  93. if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
  94. retryMs = retryMs * 1000
  95. }
  96. }
  97. return true, int64(retryMs), nil
  98. }
  99. func WithOpenAIBaseURL(baseURL string) OpenAIOption {
  100. return func(options *openaiOptions) {
  101. options.baseURL = baseURL
  102. }
  103. }
  104. func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
  105. return func(options *openaiOptions) {
  106. options.extraHeaders = headers
  107. }
  108. }
  109. func WithOpenAIDisableCache() OpenAIOption {
  110. return func(options *openaiOptions) {
  111. options.disableCache = true
  112. }
  113. }
  114. func WithReasoningEffort(effort string) OpenAIOption {
  115. return func(options *openaiOptions) {
  116. defaultReasoningEffort := "medium"
  117. switch effort {
  118. case "low", "medium", "high":
  119. defaultReasoningEffort = effort
  120. default:
  121. slog.Warn("Invalid reasoning effort, using default: medium")
  122. }
  123. options.reasoningEffort = defaultReasoningEffort
  124. }
  125. }