openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

type openaiOptions struct {
	baseURL         string
	disableCache    bool
	reasoningEffort string
}

type OpenAIOption func(*openaiOptions)

type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

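// newOpenAIClient constructs an OpenAI-backed provider client, applying any
// OpenAI-specific options (defaulting reasoning effort to "medium") and wiring
// the API key and base URL into the underlying SDK client.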
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	openaiOpts := openaiOptions{
		reasoningEffort: "medium",
	}
	for _, o := range opts.openaiOptions {
		o(&openaiOpts)
	}

	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if openaiOpts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
	}

	client := openai.NewClient(openaiClientOptions...)
	return &openaiClient{
		providerOptions: opts,
		options:         openaiOpts,
		client:          client,
	}
}

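// convertMessages translates internal messages into the OpenAI chat format,
// prepending the configured system message and mapping assistant tool calls
// and tool results to their OpenAI message equivalents.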
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

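// convertTools converts the registered tools into OpenAI function-tool
// definitions, exposing each tool's parameters as a JSON Schema object.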
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

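// finishReason maps the finish reason reported by the OpenAI API onto the
// provider-agnostic message.FinishReason values.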
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

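// preparedParams assembles the chat completion request. Reasoning-capable
// models get a reasoning effort and a completion-token budget; other models
// get a plain max-token limit.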
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(o.providerOptions.model.APIModel),
		Messages: messages,
		Tools:    tools,
	}

	if o.providerOptions.model.CanReason {
		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
		switch o.options.reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		}
	} else {
		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
	}

	return params
}

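// send issues a single blocking chat completion request, retrying on
// retryable errors such as rate limits.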
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error we are going to see if we can retry the call
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    o.toolCalls(*openaiResponse),
			Usage:        o.usage(*openaiResponse),
			FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
		}, nil
	}
}

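// stream issues a streaming chat completion request and reports content
// deltas, finished tool calls, errors, and the final accumulated response as
// ProviderEvents on the returned channel.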
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				if tool, ok := acc.JustFinishedToolCall(); ok {
					toolCalls = append(toolCalls, message.ToolCall{
						ID:    tool.Id,
						Name:  tool.Name,
						Input: tool.Arguments,
						Type:  "function",
					})
				}

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Stream completed successfully
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
					},
				}
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

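// shouldRetry decides whether a failed request should be retried and how long
// to wait (in milliseconds), honoring a Retry-After header when present and
// otherwise applying exponential backoff with jitter.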
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *openai.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	retryMs := 0
	retryAfterValues := apierr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs

	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

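// toolCalls extracts any tool calls from the first choice of a completed
// response.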
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

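// usage converts OpenAI token accounting into TokenUsage, counting cached
// prompt tokens as cache reads and excluding them from the input token count.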
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

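// WithOpenAIBaseURL overrides the default API base URL.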
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(options *openaiOptions) {
		options.baseURL = baseURL
	}
}

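// WithOpenAIDisableCache sets the disableCache option on the client.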
func WithOpenAIDisableCache() OpenAIOption {
	return func(options *openaiOptions) {
		options.disableCache = true
	}
}

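// WithReasoningEffort sets the reasoning effort ("low", "medium", or "high");
// invalid values fall back to "medium" with a warning.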
func WithReasoningEffort(effort string) OpenAIOption {
	return func(options *openaiOptions) {
		defaultReasoningEffort := "medium"
		switch effort {
		case "low", "medium", "high":
			defaultReasoningEffort = effort
		default:
			logging.Warn("Invalid reasoning effort, using default: medium")
		}
		options.reasoningEffort = defaultReasoningEffort
	}
}