openai.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. package provider
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "time"
  9. "github.com/openai/openai-go"
  10. "github.com/openai/openai-go/option"
  11. "github.com/openai/openai-go/shared"
  12. "github.com/sst/opencode/internal/config"
  13. "github.com/sst/opencode/internal/llm/models"
  14. "github.com/sst/opencode/internal/llm/tools"
  15. "github.com/sst/opencode/internal/message"
  16. "github.com/sst/opencode/internal/status"
  17. "log/slog"
  18. )
// openaiOptions holds provider-specific configuration for the OpenAI client,
// populated via the WithOpenAI* functional options below.
type openaiOptions struct {
	// baseURL overrides the default OpenAI API endpoint when non-empty.
	baseURL string
	// disableCache is set by WithOpenAIDisableCache; consumed elsewhere in the provider.
	disableCache bool
	// reasoningEffort is "low", "medium", or "high"; applied only to models
	// that can reason (see preparedParams).
	reasoningEffort string
	// extraHeaders are additional HTTP headers attached to every request.
	extraHeaders map[string]string
}

// OpenAIOption mutates openaiOptions; passed through providerClientOptions
// into newOpenAIClient.
type OpenAIOption func(*openaiOptions)

// openaiClient implements the provider client backed by the official
// openai-go SDK.
type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

// OpenAIClient is the exported alias for this provider's client interface.
type OpenAIClient ProviderClient
  32. func newOpenAIClient(opts providerClientOptions) OpenAIClient {
  33. openaiOpts := openaiOptions{
  34. reasoningEffort: "medium",
  35. }
  36. for _, o := range opts.openaiOptions {
  37. o(&openaiOpts)
  38. }
  39. openaiClientOptions := []option.RequestOption{}
  40. if opts.apiKey != "" {
  41. openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
  42. }
  43. if openaiOpts.baseURL != "" {
  44. openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
  45. }
  46. if openaiOpts.extraHeaders != nil {
  47. for key, value := range openaiOpts.extraHeaders {
  48. openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
  49. }
  50. }
  51. client := openai.NewClient(openaiClientOptions...)
  52. return &openaiClient{
  53. providerOptions: opts,
  54. options: openaiOpts,
  55. client: client,
  56. }
  57. }
  58. func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
  59. // Add system message first
  60. openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
  61. for _, msg := range messages {
  62. switch msg.Role {
  63. case message.User:
  64. var content []openai.ChatCompletionContentPartUnionParam
  65. textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
  66. content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
  67. for _, binaryContent := range msg.BinaryContent() {
  68. imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
  69. imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
  70. content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
  71. }
  72. openaiMessages = append(openaiMessages, openai.UserMessage(content))
  73. case message.Assistant:
  74. assistantMsg := openai.ChatCompletionAssistantMessageParam{
  75. Role: "assistant",
  76. }
  77. if msg.Content().String() != "" {
  78. assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
  79. OfString: openai.String(msg.Content().String()),
  80. }
  81. }
  82. if len(msg.ToolCalls()) > 0 {
  83. assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
  84. for i, call := range msg.ToolCalls() {
  85. assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
  86. ID: call.ID,
  87. Type: "function",
  88. Function: openai.ChatCompletionMessageToolCallFunctionParam{
  89. Name: call.Name,
  90. Arguments: call.Input,
  91. },
  92. }
  93. }
  94. }
  95. openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
  96. OfAssistant: &assistantMsg,
  97. })
  98. case message.Tool:
  99. for _, result := range msg.ToolResults() {
  100. openaiMessages = append(openaiMessages,
  101. openai.ToolMessage(result.Content, result.ToolCallID),
  102. )
  103. }
  104. }
  105. }
  106. return
  107. }
  108. func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
  109. openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
  110. for i, tool := range tools {
  111. info := tool.Info()
  112. openaiTools[i] = openai.ChatCompletionToolParam{
  113. Function: openai.FunctionDefinitionParam{
  114. Name: info.Name,
  115. Description: openai.String(info.Description),
  116. Parameters: openai.FunctionParameters{
  117. "type": "object",
  118. "properties": info.Parameters,
  119. "required": info.Required,
  120. },
  121. },
  122. }
  123. }
  124. return openaiTools
  125. }
  126. func (o *openaiClient) finishReason(reason string) message.FinishReason {
  127. switch reason {
  128. case "stop":
  129. return message.FinishReasonEndTurn
  130. case "length":
  131. return message.FinishReasonMaxTokens
  132. case "tool_calls":
  133. return message.FinishReasonToolUse
  134. default:
  135. return message.FinishReasonUnknown
  136. }
  137. }
  138. func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
  139. params := openai.ChatCompletionNewParams{
  140. Model: openai.ChatModel(o.providerOptions.model.APIModel),
  141. Messages: messages,
  142. Tools: tools,
  143. }
  144. if o.providerOptions.model.CanReason == true {
  145. params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
  146. switch o.options.reasoningEffort {
  147. case "low":
  148. params.ReasoningEffort = shared.ReasoningEffortLow
  149. case "medium":
  150. params.ReasoningEffort = shared.ReasoningEffortMedium
  151. case "high":
  152. params.ReasoningEffort = shared.ReasoningEffortHigh
  153. default:
  154. params.ReasoningEffort = shared.ReasoningEffortMedium
  155. }
  156. } else {
  157. params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
  158. }
  159. if o.providerOptions.model.Provider == models.ProviderOpenRouter {
  160. params.WithExtraFields(map[string]any{
  161. "provider": map[string]any{
  162. "require_parameters": true,
  163. },
  164. })
  165. }
  166. return params
  167. }
  168. func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
  169. params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
  170. cfg := config.Get()
  171. if cfg.Debug {
  172. jsonData, _ := json.Marshal(params)
  173. slog.Debug("Prepared messages", "messages", string(jsonData))
  174. }
  175. attempts := 0
  176. for {
  177. attempts++
  178. openaiResponse, err := o.client.Chat.Completions.New(
  179. ctx,
  180. params,
  181. )
  182. // If there is an error we are going to see if we can retry the call
  183. if err != nil {
  184. retry, after, retryErr := o.shouldRetry(attempts, err)
  185. duration := time.Duration(after) * time.Millisecond
  186. if retryErr != nil {
  187. return nil, retryErr
  188. }
  189. if retry {
  190. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  191. select {
  192. case <-ctx.Done():
  193. return nil, ctx.Err()
  194. case <-time.After(duration):
  195. continue
  196. }
  197. }
  198. return nil, retryErr
  199. }
  200. content := ""
  201. if openaiResponse.Choices[0].Message.Content != "" {
  202. content = openaiResponse.Choices[0].Message.Content
  203. }
  204. toolCalls := o.toolCalls(*openaiResponse)
  205. finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
  206. if len(toolCalls) > 0 {
  207. finishReason = message.FinishReasonToolUse
  208. }
  209. return &ProviderResponse{
  210. Content: content,
  211. ToolCalls: toolCalls,
  212. Usage: o.usage(*openaiResponse),
  213. FinishReason: finishReason,
  214. }, nil
  215. }
  216. }
  217. func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  218. params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
  219. params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
  220. IncludeUsage: openai.Bool(true),
  221. }
  222. cfg := config.Get()
  223. if cfg.Debug {
  224. jsonData, _ := json.Marshal(params)
  225. slog.Debug("Prepared messages", "messages", string(jsonData))
  226. }
  227. attempts := 0
  228. eventChan := make(chan ProviderEvent)
  229. go func() {
  230. for {
  231. attempts++
  232. openaiStream := o.client.Chat.Completions.NewStreaming(
  233. ctx,
  234. params,
  235. )
  236. acc := openai.ChatCompletionAccumulator{}
  237. currentContent := ""
  238. toolCalls := make([]message.ToolCall, 0)
  239. for openaiStream.Next() {
  240. chunk := openaiStream.Current()
  241. acc.AddChunk(chunk)
  242. for _, choice := range chunk.Choices {
  243. if choice.Delta.Content != "" {
  244. eventChan <- ProviderEvent{
  245. Type: EventContentDelta,
  246. Content: choice.Delta.Content,
  247. }
  248. currentContent += choice.Delta.Content
  249. }
  250. }
  251. }
  252. err := openaiStream.Err()
  253. if err == nil || errors.Is(err, io.EOF) {
  254. // Stream completed successfully
  255. finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
  256. if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
  257. toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
  258. }
  259. if len(toolCalls) > 0 {
  260. finishReason = message.FinishReasonToolUse
  261. }
  262. eventChan <- ProviderEvent{
  263. Type: EventComplete,
  264. Response: &ProviderResponse{
  265. Content: currentContent,
  266. ToolCalls: toolCalls,
  267. Usage: o.usage(acc.ChatCompletion),
  268. FinishReason: finishReason,
  269. },
  270. }
  271. close(eventChan)
  272. return
  273. }
  274. // If there is an error we are going to see if we can retry the call
  275. retry, after, retryErr := o.shouldRetry(attempts, err)
  276. duration := time.Duration(after) * time.Millisecond
  277. if retryErr != nil {
  278. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  279. close(eventChan)
  280. return
  281. }
  282. if retry {
  283. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  284. select {
  285. case <-ctx.Done():
  286. // context cancelled
  287. if ctx.Err() == nil {
  288. eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
  289. }
  290. close(eventChan)
  291. return
  292. case <-time.After(duration):
  293. continue
  294. }
  295. }
  296. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  297. close(eventChan)
  298. return
  299. }
  300. }()
  301. return eventChan
  302. }
  303. func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
  304. var apierr *openai.Error
  305. if !errors.As(err, &apierr) {
  306. return false, 0, err
  307. }
  308. if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
  309. return false, 0, err
  310. }
  311. if attempts > maxRetries {
  312. return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
  313. }
  314. retryMs := 0
  315. retryAfterValues := apierr.Response.Header.Values("Retry-After")
  316. backoffMs := 2000 * (1 << (attempts - 1))
  317. jitterMs := int(float64(backoffMs) * 0.2)
  318. retryMs = backoffMs + jitterMs
  319. if len(retryAfterValues) > 0 {
  320. if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
  321. retryMs = retryMs * 1000
  322. }
  323. }
  324. return true, int64(retryMs), nil
  325. }
  326. func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
  327. var toolCalls []message.ToolCall
  328. if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
  329. for _, call := range completion.Choices[0].Message.ToolCalls {
  330. toolCall := message.ToolCall{
  331. ID: call.ID,
  332. Name: call.Function.Name,
  333. Input: call.Function.Arguments,
  334. Type: "function",
  335. Finished: true,
  336. }
  337. toolCalls = append(toolCalls, toolCall)
  338. }
  339. }
  340. return toolCalls
  341. }
  342. func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
  343. cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
  344. inputTokens := completion.Usage.PromptTokens - cachedTokens
  345. return TokenUsage{
  346. InputTokens: inputTokens,
  347. OutputTokens: completion.Usage.CompletionTokens,
  348. CacheCreationTokens: 0, // OpenAI doesn't provide this directly
  349. CacheReadTokens: cachedTokens,
  350. }
  351. }
  352. func WithOpenAIBaseURL(baseURL string) OpenAIOption {
  353. return func(options *openaiOptions) {
  354. options.baseURL = baseURL
  355. }
  356. }
  357. func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
  358. return func(options *openaiOptions) {
  359. options.extraHeaders = headers
  360. }
  361. }
  362. func WithOpenAIDisableCache() OpenAIOption {
  363. return func(options *openaiOptions) {
  364. options.disableCache = true
  365. }
  366. }
  367. func WithReasoningEffort(effort string) OpenAIOption {
  368. return func(options *openaiOptions) {
  369. defaultReasoningEffort := "medium"
  370. switch effort {
  371. case "low", "medium", "high":
  372. defaultReasoningEffort = effort
  373. default:
  374. slog.Warn("Invalid reasoning effort, using default: medium")
  375. }
  376. options.reasoningEffort = defaultReasoningEffort
  377. }
  378. }