openai_response.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/responses"
	"github.com/openai/openai-go/shared"
	"github.com/sst/opencode/internal/config"
	"github.com/sst/opencode/internal/llm/models"
	"github.com/sst/opencode/internal/llm/tools"
	"github.com/sst/opencode/internal/message"
	"github.com/sst/opencode/internal/status"
)
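
// convertMessagesToResponseParams translates opencode's internal message
// history into the input item list expected by the OpenAI Responses API.
// The provider's system message is prepended as the first item; user text,
// attached images, assistant output, tool calls, and tool results are then
// appended in conversation order.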
func (o *openaiClient) convertMessagesToResponseParams(messages []message.Message) responses.ResponseInputParam {
	inputItems := responses.ResponseInputParam{}

	inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
		OfMessage: &responses.EasyInputMessageParam{
			Content: responses.EasyInputMessageContentUnionParam{OfString: openai.String(o.providerOptions.systemMessage)},
			Role:    responses.EasyInputMessageRoleSystem,
		},
	})

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			inputItemContentList := responses.ResponseInputMessageContentListParam{
				responses.ResponseInputContentUnionParam{
					OfInputText: &responses.ResponseInputTextParam{
						Text: msg.Content().String(),
					},
				},
			}
			for _, binaryContent := range msg.BinaryContent() {
				inputItemContentList = append(inputItemContentList, responses.ResponseInputContentUnionParam{
					OfInputImage: &responses.ResponseInputImageParam{
						ImageURL: openai.String(binaryContent.String(models.ProviderOpenAI)),
					},
				})
			}

			userMsg := responses.ResponseInputItemUnionParam{
				OfInputMessage: &responses.ResponseInputItemMessageParam{
					Content: inputItemContentList,
					Role:    string(responses.ResponseInputMessageItemRoleUser),
				},
			}
			inputItems = append(inputItems, userMsg)

		case message.Assistant:
			if msg.Content().String() != "" {
				assistantMsg := responses.ResponseInputItemUnionParam{
					OfOutputMessage: &responses.ResponseOutputMessageParam{
						Content: []responses.ResponseOutputMessageContentUnionParam{{
							OfOutputText: &responses.ResponseOutputTextParam{
								Text: msg.Content().String(),
							},
						}},
					},
				}
				inputItems = append(inputItems, assistantMsg)
			}

			for _, call := range msg.ToolCalls() {
				toolMsg := responses.ResponseInputItemUnionParam{
					OfFunctionCall: &responses.ResponseFunctionToolCallParam{
						CallID:    call.ID,
						Name:      call.Name,
						Arguments: call.Input,
					},
				}
				inputItems = append(inputItems, toolMsg)
			}

		case message.Tool:
			for _, result := range msg.ToolResults() {
				toolMsg := responses.ResponseInputItemUnionParam{
					OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
						Output: result.Content,
						CallID: result.ToolCallID,
					},
				}
				inputItems = append(inputItems, toolMsg)
			}
		}
	}

	return inputItems
}
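
// convertToResponseTools maps opencode's tool definitions onto the Responses
// API function-tool schema, wrapping each tool's parameters in a JSON Schema
// object.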
func (o *openaiClient) convertToResponseTools(tools []tools.BaseTool) []responses.ToolUnionParam {
	outputTools := make([]responses.ToolUnionParam, len(tools))
	for i, tool := range tools {
		info := tool.Info()
		outputTools[i] = responses.ToolUnionParam{
			OfFunction: &responses.FunctionToolParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: map[string]any{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}
	return outputTools
}
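
// preparedResponseParams assembles the request parameters shared by the
// blocking and streaming paths: model, input items, tools, output-token cap,
// reasoning effort (for models that can reason), and an OpenRouter-specific
// extra field.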
func (o *openaiClient) preparedResponseParams(input responses.ResponseInputParam, tools []responses.ToolUnionParam) responses.ResponseNewParams {
	params := responses.ResponseNewParams{
		Model: shared.ResponsesModel(o.providerOptions.model.APIModel),
		Input: responses.ResponseNewParamsInputUnion{OfInputItemList: input},
		Tools: tools,
	}
	params.MaxOutputTokens = openai.Int(o.providerOptions.maxTokens)

	if o.providerOptions.model.CanReason {
		switch o.options.reasoningEffort {
		case "low":
			params.Reasoning.Effort = shared.ReasoningEffortLow
		case "medium":
			params.Reasoning.Effort = shared.ReasoningEffortMedium
		case "high":
			params.Reasoning.Effort = shared.ReasoningEffortHigh
		default:
			params.Reasoning.Effort = shared.ReasoningEffortMedium
		}
	}

	if o.providerOptions.model.Provider == models.ProviderOpenRouter {
		params.WithExtraFields(map[string]any{
			"provider": map[string]any{
				"require_parameters": true,
			},
		})
	}

	return params
}
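
// sendResponseMessages performs a single (non-streaming) Responses API call,
// retrying rate-limited requests with the backoff supplied by shouldRetry.
// A minimal usage sketch, assuming an initialized *openaiClient named o
// (msgs and tools are placeholder names):
//
//	resp, err := o.sendResponseMessages(ctx, msgs, tools)
//	if err != nil {
//		return err
//	}
//	fmt.Println(resp.Content)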
func (o *openaiClient) sendResponseMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedResponseParams(o.convertMessagesToResponseParams(messages), o.convertToResponseTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Responses.New(
			ctx,
			params,
		)
		// If there is an error we are going to see if we can retry the call
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			duration := time.Duration(after) * time.Millisecond
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(duration):
					continue
				}
			}
			// Not retryable: surface the original API error (retryErr is nil here).
			return nil, err
		}

		content := openaiResponse.OutputText()

		toolCalls := o.responseToolCalls(*openaiResponse)
		finishReason := o.finishReason("stop")
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.responseUsage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}
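
// streamResponseMessages is the streaming counterpart of sendResponseMessages.
// It returns immediately with an event channel and fans the SSE stream out as
// ProviderEvents from a goroutine: content and tool-call deltas as they
// arrive, then a final EventComplete (or EventError) before the channel is
// closed. Consumers should range over the channel until it closes, e.g.
// (msgs and tools are placeholder names):
//
//	for event := range o.streamResponseMessages(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventError:
//			return event.Error
//		}
//	}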
func (o *openaiClient) streamResponseMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	eventChan := make(chan ProviderEvent)

	params := o.preparedResponseParams(o.convertMessagesToResponseParams(messages), o.convertToResponseTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	go func() {
		for {
			attempts++
			stream := o.client.Responses.NewStreaming(ctx, params)

			outputText := ""
			currentToolCallID := ""
			for stream.Next() {
				event := stream.Current()
				switch event := event.AsAny().(type) {
				case responses.ResponseCompletedEvent:
					toolCalls := o.responseToolCalls(event.Response)
					finishReason := o.finishReason("stop")
					if len(toolCalls) > 0 {
						finishReason = message.FinishReasonToolUse
					}
					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      outputText,
							ToolCalls:    toolCalls,
							Usage:        o.responseUsage(event.Response),
							FinishReason: finishReason,
						},
					}
					close(eventChan)
					return

				case responses.ResponseTextDeltaEvent:
					outputText += event.Delta
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: event.Delta,
					}

				case responses.ResponseTextDoneEvent:
					// The text is complete, but the response is not: keep
					// reading so the ResponseCompletedEvent above can report
					// usage and tool calls before the channel is closed.
					eventChan <- ProviderEvent{
						Type:    EventContentStop,
						Content: outputText,
					}

				case responses.ResponseOutputItemAddedEvent:
					if event.Item.Type == "function_call" {
						currentToolCallID = event.Item.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.Item.ID,
								Name:     event.Item.Name,
								Finished: false,
							},
						}
					}

				case responses.ResponseFunctionCallArgumentsDeltaEvent:
					if event.ItemID == currentToolCallID {
						eventChan <- ProviderEvent{
							Type: EventToolUseDelta,
							ToolCall: &message.ToolCall{
								ID:       currentToolCallID,
								Finished: false,
								Input:    event.Delta,
							},
						}
					}

				case responses.ResponseFunctionCallArgumentsDoneEvent:
					if event.ItemID == currentToolCallID {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID:    currentToolCallID,
								Input: event.Arguments,
							},
						}
						currentToolCallID = ""
					}

				case responses.ResponseOutputItemDoneEvent:
					if event.Item.Type == "function_call" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID:       event.Item.ID,
								Name:     event.Item.Name,
								Input:    event.Item.Arguments,
								Finished: true,
							},
						}
						currentToolCallID = ""
					}
				}
			}

			err := stream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := o.shouldRetry(attempts, err)
			duration := time.Duration(after) * time.Millisecond
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
				select {
				case <-ctx.Done():
					// Context cancelled: report the cancellation before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(duration):
					continue
				}
			}
			// Not retryable: surface the original stream error (retryErr is nil here).
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}
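
// responseToolCalls extracts completed function calls from a final Response,
// converting them into opencode's message.ToolCall representation.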
func (o *openaiClient) responseToolCalls(response responses.Response) []message.ToolCall {
	var toolCalls []message.ToolCall
	for _, output := range response.Output {
		if output.Type == "function_call" {
			call := output.AsFunctionCall()
			toolCall := message.ToolCall{
				// Use the call_id rather than the output item id: it is what
				// the Responses API expects back in function_call_output items
				// (see convertMessagesToResponseParams above).
				ID:       call.CallID,
				Name:     call.Name,
				Input:    call.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}
	return toolCalls
}
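
// responseUsage converts the API usage block into a TokenUsage. Cached input
// tokens are subtracted from the input count and reported separately as cache
// reads, so the two fields sum back to the API's InputTokens figure.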
func (o *openaiClient) responseUsage(response responses.Response) TokenUsage {
	cachedTokens := response.Usage.InputTokensDetails.CachedTokens
	inputTokens := response.Usage.InputTokens - cachedTokens
	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        response.Usage.OutputTokens,
		CacheCreationTokens: 0, // OpenAI doesn't report cache writes directly
		CacheReadTokens:     cachedTokens,
	}
}