// openai_completion.go
  1. package provider
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "time"
  10. "github.com/openai/openai-go"
  11. "github.com/openai/openai-go/shared"
  12. "github.com/sst/opencode/internal/config"
  13. "github.com/sst/opencode/internal/llm/models"
  14. "github.com/sst/opencode/internal/llm/tools"
  15. "github.com/sst/opencode/internal/message"
  16. "github.com/sst/opencode/internal/status"
  17. )
  18. func (o *openaiClient) convertMessagesToChatCompletionMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
  19. // Add system message first
  20. openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
  21. for _, msg := range messages {
  22. switch msg.Role {
  23. case message.User:
  24. var content []openai.ChatCompletionContentPartUnionParam
  25. textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
  26. content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
  27. for _, binaryContent := range msg.BinaryContent() {
  28. imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
  29. imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
  30. content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
  31. }
  32. openaiMessages = append(openaiMessages, openai.UserMessage(content))
  33. case message.Assistant:
  34. assistantMsg := openai.ChatCompletionAssistantMessageParam{
  35. Role: "assistant",
  36. }
  37. if msg.Content().String() != "" {
  38. assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
  39. OfString: openai.String(msg.Content().String()),
  40. }
  41. }
  42. if len(msg.ToolCalls()) > 0 {
  43. assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
  44. for i, call := range msg.ToolCalls() {
  45. assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
  46. ID: call.ID,
  47. Type: "function",
  48. Function: openai.ChatCompletionMessageToolCallFunctionParam{
  49. Name: call.Name,
  50. Arguments: call.Input,
  51. },
  52. }
  53. }
  54. }
  55. openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
  56. OfAssistant: &assistantMsg,
  57. })
  58. case message.Tool:
  59. for _, result := range msg.ToolResults() {
  60. openaiMessages = append(openaiMessages,
  61. openai.ToolMessage(result.Content, result.ToolCallID),
  62. )
  63. }
  64. }
  65. }
  66. return
  67. }
  68. func (o *openaiClient) convertToChatCompletionTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
  69. openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
  70. for i, tool := range tools {
  71. info := tool.Info()
  72. openaiTools[i] = openai.ChatCompletionToolParam{
  73. Function: openai.FunctionDefinitionParam{
  74. Name: info.Name,
  75. Description: openai.String(info.Description),
  76. Parameters: openai.FunctionParameters{
  77. "type": "object",
  78. "properties": info.Parameters,
  79. "required": info.Required,
  80. },
  81. },
  82. }
  83. }
  84. return openaiTools
  85. }
  86. func (o *openaiClient) preparedChatCompletionParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
  87. params := openai.ChatCompletionNewParams{
  88. Model: openai.ChatModel(o.providerOptions.model.APIModel),
  89. Messages: messages,
  90. Tools: tools,
  91. }
  92. if o.providerOptions.model.CanReason == true {
  93. params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
  94. switch o.options.reasoningEffort {
  95. case "low":
  96. params.ReasoningEffort = shared.ReasoningEffortLow
  97. case "medium":
  98. params.ReasoningEffort = shared.ReasoningEffortMedium
  99. case "high":
  100. params.ReasoningEffort = shared.ReasoningEffortHigh
  101. default:
  102. params.ReasoningEffort = shared.ReasoningEffortMedium
  103. }
  104. } else {
  105. params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
  106. }
  107. if o.providerOptions.model.Provider == models.ProviderOpenRouter {
  108. params.WithExtraFields(map[string]any{
  109. "provider": map[string]any{
  110. "require_parameters": true,
  111. },
  112. })
  113. }
  114. return params
  115. }
  116. func (o *openaiClient) sendChatcompletionMessage(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
  117. params := o.preparedChatCompletionParams(o.convertMessagesToChatCompletionMessages(messages), o.convertToChatCompletionTools(tools))
  118. cfg := config.Get()
  119. if cfg.Debug {
  120. jsonData, _ := json.Marshal(params)
  121. slog.Debug("Prepared messages", "messages", string(jsonData))
  122. }
  123. attempts := 0
  124. for {
  125. attempts++
  126. openaiResponse, err := o.client.Chat.Completions.New(
  127. ctx,
  128. params,
  129. )
  130. // If there is an error we are going to see if we can retry the call
  131. if err != nil {
  132. retry, after, retryErr := o.shouldRetry(attempts, err)
  133. duration := time.Duration(after) * time.Millisecond
  134. if retryErr != nil {
  135. return nil, retryErr
  136. }
  137. if retry {
  138. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  139. select {
  140. case <-ctx.Done():
  141. return nil, ctx.Err()
  142. case <-time.After(duration):
  143. continue
  144. }
  145. }
  146. return nil, retryErr
  147. }
  148. content := ""
  149. if openaiResponse.Choices[0].Message.Content != "" {
  150. content = openaiResponse.Choices[0].Message.Content
  151. }
  152. toolCalls := o.chatCompletionToolCalls(*openaiResponse)
  153. finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
  154. if len(toolCalls) > 0 {
  155. finishReason = message.FinishReasonToolUse
  156. }
  157. return &ProviderResponse{
  158. Content: content,
  159. ToolCalls: toolCalls,
  160. Usage: o.usage(*openaiResponse),
  161. FinishReason: finishReason,
  162. }, nil
  163. }
  164. }
  165. func (o *openaiClient) streamChatCompletionMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
  166. params := o.preparedChatCompletionParams(o.convertMessagesToChatCompletionMessages(messages), o.convertToChatCompletionTools(tools))
  167. params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
  168. IncludeUsage: openai.Bool(true),
  169. }
  170. cfg := config.Get()
  171. if cfg.Debug {
  172. jsonData, _ := json.Marshal(params)
  173. slog.Debug("Prepared messages", "messages", string(jsonData))
  174. }
  175. attempts := 0
  176. eventChan := make(chan ProviderEvent)
  177. go func() {
  178. for {
  179. attempts++
  180. openaiStream := o.client.Chat.Completions.NewStreaming(
  181. ctx,
  182. params,
  183. )
  184. acc := openai.ChatCompletionAccumulator{}
  185. currentContent := ""
  186. toolCalls := make([]message.ToolCall, 0)
  187. for openaiStream.Next() {
  188. chunk := openaiStream.Current()
  189. acc.AddChunk(chunk)
  190. for _, choice := range chunk.Choices {
  191. if choice.Delta.Content != "" {
  192. eventChan <- ProviderEvent{
  193. Type: EventContentDelta,
  194. Content: choice.Delta.Content,
  195. }
  196. currentContent += choice.Delta.Content
  197. }
  198. }
  199. }
  200. err := openaiStream.Err()
  201. if err == nil || errors.Is(err, io.EOF) {
  202. // Stream completed successfully
  203. finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
  204. if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
  205. toolCalls = append(toolCalls, o.chatCompletionToolCalls(acc.ChatCompletion)...)
  206. }
  207. if len(toolCalls) > 0 {
  208. finishReason = message.FinishReasonToolUse
  209. }
  210. eventChan <- ProviderEvent{
  211. Type: EventComplete,
  212. Response: &ProviderResponse{
  213. Content: currentContent,
  214. ToolCalls: toolCalls,
  215. Usage: o.usage(acc.ChatCompletion),
  216. FinishReason: finishReason,
  217. },
  218. }
  219. close(eventChan)
  220. return
  221. }
  222. // If there is an error we are going to see if we can retry the call
  223. retry, after, retryErr := o.shouldRetry(attempts, err)
  224. duration := time.Duration(after) * time.Millisecond
  225. if retryErr != nil {
  226. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  227. close(eventChan)
  228. return
  229. }
  230. if retry {
  231. status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
  232. select {
  233. case <-ctx.Done():
  234. // context cancelled
  235. if ctx.Err() == nil {
  236. eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
  237. }
  238. close(eventChan)
  239. return
  240. case <-time.After(duration):
  241. continue
  242. }
  243. }
  244. eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
  245. close(eventChan)
  246. return
  247. }
  248. }()
  249. return eventChan
  250. }
  251. func (o *openaiClient) chatCompletionToolCalls(completion openai.ChatCompletion) []message.ToolCall {
  252. var toolCalls []message.ToolCall
  253. if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
  254. for _, call := range completion.Choices[0].Message.ToolCalls {
  255. toolCall := message.ToolCall{
  256. ID: call.ID,
  257. Name: call.Function.Name,
  258. Input: call.Function.Arguments,
  259. Type: "function",
  260. Finished: true,
  261. }
  262. toolCalls = append(toolCalls, toolCall)
  263. }
  264. }
  265. return toolCalls
  266. }
  267. func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
  268. cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
  269. inputTokens := completion.Usage.PromptTokens - cachedTokens
  270. return TokenUsage{
  271. InputTokens: inputTokens,
  272. OutputTokens: completion.Usage.CompletionTokens,
  273. CacheCreationTokens: 0, // OpenAI doesn't provide this directly
  274. CacheReadTokens: cachedTokens,
  275. }
  276. }