service.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. package ocm
  2. import (
  3. "bytes"
  4. "context"
  5. stdTLS "crypto/tls"
  6. "encoding/json"
  7. "errors"
  8. "io"
  9. "mime"
  10. "net"
  11. "net/http"
  12. "strings"
  13. "sync"
  14. "time"
  15. "github.com/sagernet/sing-box/adapter"
  16. boxService "github.com/sagernet/sing-box/adapter/service"
  17. "github.com/sagernet/sing-box/common/dialer"
  18. "github.com/sagernet/sing-box/common/listener"
  19. "github.com/sagernet/sing-box/common/tls"
  20. C "github.com/sagernet/sing-box/constant"
  21. "github.com/sagernet/sing-box/log"
  22. "github.com/sagernet/sing-box/option"
  23. "github.com/sagernet/sing/common"
  24. "github.com/sagernet/sing/common/buf"
  25. E "github.com/sagernet/sing/common/exceptions"
  26. M "github.com/sagernet/sing/common/metadata"
  27. N "github.com/sagernet/sing/common/network"
  28. "github.com/sagernet/sing/common/ntp"
  29. aTLS "github.com/sagernet/sing/common/tls"
  30. "github.com/go-chi/chi/v5"
  31. "github.com/openai/openai-go/v3"
  32. "github.com/openai/openai-go/v3/responses"
  33. "golang.org/x/net/http2"
  34. )
  35. func RegisterService(registry *boxService.Registry) {
  36. boxService.Register[option.OCMServiceOptions](registry, C.TypeOCM, NewService)
  37. }
  38. type errorResponse struct {
  39. Error errorDetails `json:"error"`
  40. }
  41. type errorDetails struct {
  42. Type string `json:"type"`
  43. Code string `json:"code,omitempty"`
  44. Message string `json:"message"`
  45. }
  46. func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
  47. w.Header().Set("Content-Type", "application/json")
  48. w.WriteHeader(statusCode)
  49. json.NewEncoder(w).Encode(errorResponse{
  50. Error: errorDetails{
  51. Type: errorType,
  52. Message: message,
  53. },
  54. })
  55. }
  56. func isHopByHopHeader(header string) bool {
  57. switch strings.ToLower(header) {
  58. case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
  59. return true
  60. default:
  61. return false
  62. }
  63. }
  64. type Service struct {
  65. boxService.Adapter
  66. ctx context.Context
  67. logger log.ContextLogger
  68. credentialPath string
  69. credentials *oauthCredentials
  70. users []option.OCMUser
  71. httpClient *http.Client
  72. httpHeaders http.Header
  73. listener *listener.Listener
  74. tlsConfig tls.ServerConfig
  75. httpServer *http.Server
  76. userManager *UserManager
  77. accessMutex sync.RWMutex
  78. usageTracker *AggregatedUsage
  79. trackingGroup sync.WaitGroup
  80. shuttingDown bool
  81. }
  82. func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
  83. serviceDialer, err := dialer.NewWithOptions(dialer.Options{
  84. Context: ctx,
  85. Options: option.DialerOptions{
  86. Detour: options.Detour,
  87. },
  88. RemoteIsDomain: true,
  89. })
  90. if err != nil {
  91. return nil, E.Cause(err, "create dialer")
  92. }
  93. httpClient := &http.Client{
  94. Transport: &http.Transport{
  95. ForceAttemptHTTP2: true,
  96. TLSClientConfig: &stdTLS.Config{
  97. RootCAs: adapter.RootPoolFromContext(ctx),
  98. Time: ntp.TimeFuncFromContext(ctx),
  99. },
  100. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  101. return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  102. },
  103. },
  104. }
  105. userManager := &UserManager{
  106. tokenMap: make(map[string]string),
  107. }
  108. var usageTracker *AggregatedUsage
  109. if options.UsagesPath != "" {
  110. usageTracker = &AggregatedUsage{
  111. LastUpdated: time.Now(),
  112. Combinations: make([]CostCombination, 0),
  113. filePath: options.UsagesPath,
  114. logger: logger,
  115. }
  116. }
  117. service := &Service{
  118. Adapter: boxService.NewAdapter(C.TypeOCM, tag),
  119. ctx: ctx,
  120. logger: logger,
  121. credentialPath: options.CredentialPath,
  122. users: options.Users,
  123. httpClient: httpClient,
  124. httpHeaders: options.Headers.Build(),
  125. listener: listener.New(listener.Options{
  126. Context: ctx,
  127. Logger: logger,
  128. Network: []string{N.NetworkTCP},
  129. Listen: options.ListenOptions,
  130. }),
  131. userManager: userManager,
  132. usageTracker: usageTracker,
  133. }
  134. if options.TLS != nil {
  135. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  136. if err != nil {
  137. return nil, err
  138. }
  139. service.tlsConfig = tlsConfig
  140. }
  141. return service, nil
  142. }
  143. func (s *Service) Start(stage adapter.StartStage) error {
  144. if stage != adapter.StartStateStart {
  145. return nil
  146. }
  147. s.userManager.UpdateUsers(s.users)
  148. credentials, err := platformReadCredentials(s.credentialPath)
  149. if err != nil {
  150. return E.Cause(err, "read credentials")
  151. }
  152. s.credentials = credentials
  153. if s.usageTracker != nil {
  154. err = s.usageTracker.Load()
  155. if err != nil {
  156. s.logger.Warn("load usage statistics: ", err)
  157. }
  158. }
  159. router := chi.NewRouter()
  160. router.Mount("/", s)
  161. s.httpServer = &http.Server{Handler: router}
  162. if s.tlsConfig != nil {
  163. err = s.tlsConfig.Start()
  164. if err != nil {
  165. return E.Cause(err, "create TLS config")
  166. }
  167. }
  168. tcpListener, err := s.listener.ListenTCP()
  169. if err != nil {
  170. return err
  171. }
  172. if s.tlsConfig != nil {
  173. if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
  174. s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
  175. }
  176. tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
  177. }
  178. go func() {
  179. serveErr := s.httpServer.Serve(tcpListener)
  180. if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
  181. s.logger.Error("serve error: ", serveErr)
  182. }
  183. }()
  184. return nil
  185. }
  186. func (s *Service) getAccessToken() (string, error) {
  187. s.accessMutex.RLock()
  188. if !s.credentials.needsRefresh() {
  189. token := s.credentials.getAccessToken()
  190. s.accessMutex.RUnlock()
  191. return token, nil
  192. }
  193. s.accessMutex.RUnlock()
  194. s.accessMutex.Lock()
  195. defer s.accessMutex.Unlock()
  196. if !s.credentials.needsRefresh() {
  197. return s.credentials.getAccessToken(), nil
  198. }
  199. newCredentials, err := refreshToken(s.httpClient, s.credentials)
  200. if err != nil {
  201. return "", err
  202. }
  203. s.credentials = newCredentials
  204. err = platformWriteCredentials(newCredentials, s.credentialPath)
  205. if err != nil {
  206. s.logger.Warn("persist refreshed token: ", err)
  207. }
  208. return newCredentials.getAccessToken(), nil
  209. }
  210. func (s *Service) getAccountID() string {
  211. s.accessMutex.RLock()
  212. defer s.accessMutex.RUnlock()
  213. return s.credentials.getAccountID()
  214. }
  215. func (s *Service) isAPIKeyMode() bool {
  216. s.accessMutex.RLock()
  217. defer s.accessMutex.RUnlock()
  218. return s.credentials.isAPIKeyMode()
  219. }
  220. func (s *Service) getBaseURL() string {
  221. if s.isAPIKeyMode() {
  222. return openaiAPIBaseURL
  223. }
  224. return chatGPTBackendURL
  225. }
  226. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  227. path := r.URL.Path
  228. if !strings.HasPrefix(path, "/v1/") {
  229. writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
  230. return
  231. }
  232. var proxyPath string
  233. if s.isAPIKeyMode() {
  234. proxyPath = path
  235. } else {
  236. if path == "/v1/chat/completions" {
  237. writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
  238. "chat completions endpoint is only available in API key mode")
  239. return
  240. }
  241. proxyPath = strings.TrimPrefix(path, "/v1")
  242. }
  243. var username string
  244. if len(s.users) > 0 {
  245. authHeader := r.Header.Get("Authorization")
  246. if authHeader == "" {
  247. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
  248. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
  249. return
  250. }
  251. clientToken := strings.TrimPrefix(authHeader, "Bearer ")
  252. if clientToken == authHeader {
  253. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
  254. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
  255. return
  256. }
  257. var ok bool
  258. username, ok = s.userManager.Authenticate(clientToken)
  259. if !ok {
  260. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
  261. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
  262. return
  263. }
  264. }
  265. var requestModel string
  266. if s.usageTracker != nil && r.Body != nil {
  267. bodyBytes, err := io.ReadAll(r.Body)
  268. if err == nil {
  269. var request struct {
  270. Model string `json:"model"`
  271. }
  272. err := json.Unmarshal(bodyBytes, &request)
  273. if err == nil {
  274. requestModel = request.Model
  275. }
  276. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
  277. }
  278. }
  279. accessToken, err := s.getAccessToken()
  280. if err != nil {
  281. s.logger.Error("get access token: ", err)
  282. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
  283. return
  284. }
  285. proxyURL := s.getBaseURL() + proxyPath
  286. if r.URL.RawQuery != "" {
  287. proxyURL += "?" + r.URL.RawQuery
  288. }
  289. proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
  290. if err != nil {
  291. s.logger.Error("create proxy request: ", err)
  292. writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
  293. return
  294. }
  295. for key, values := range r.Header {
  296. if !isHopByHopHeader(key) && key != "Authorization" {
  297. proxyRequest.Header[key] = values
  298. }
  299. }
  300. for key, values := range s.httpHeaders {
  301. proxyRequest.Header.Del(key)
  302. proxyRequest.Header[key] = values
  303. }
  304. proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
  305. if accountID := s.getAccountID(); accountID != "" {
  306. proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
  307. }
  308. response, err := s.httpClient.Do(proxyRequest)
  309. if err != nil {
  310. writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
  311. return
  312. }
  313. defer response.Body.Close()
  314. for key, values := range response.Header {
  315. if !isHopByHopHeader(key) {
  316. w.Header()[key] = values
  317. }
  318. }
  319. w.WriteHeader(response.StatusCode)
  320. trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
  321. (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
  322. if trackUsage {
  323. s.handleResponseWithTracking(w, response, path, requestModel, username)
  324. } else {
  325. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  326. if err == nil && mediaType != "text/event-stream" {
  327. _, _ = io.Copy(w, response.Body)
  328. return
  329. }
  330. flusher, ok := w.(http.Flusher)
  331. if !ok {
  332. s.logger.Error("streaming not supported")
  333. return
  334. }
  335. buffer := make([]byte, buf.BufferSize)
  336. for {
  337. n, err := response.Body.Read(buffer)
  338. if n > 0 {
  339. _, writeError := w.Write(buffer[:n])
  340. if writeError != nil {
  341. s.logger.Error("write streaming response: ", writeError)
  342. return
  343. }
  344. flusher.Flush()
  345. }
  346. if err != nil {
  347. return
  348. }
  349. }
  350. }
  351. }
  352. func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
  353. isChatCompletions := path == "/v1/chat/completions"
  354. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  355. isStreaming := err == nil && mediaType == "text/event-stream"
  356. if !isStreaming {
  357. bodyBytes, err := io.ReadAll(response.Body)
  358. if err != nil {
  359. s.logger.Error("read response body: ", err)
  360. return
  361. }
  362. var responseModel string
  363. var inputTokens, outputTokens, cachedTokens int64
  364. if isChatCompletions {
  365. var chatCompletion openai.ChatCompletion
  366. if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
  367. responseModel = chatCompletion.Model
  368. inputTokens = chatCompletion.Usage.PromptTokens
  369. outputTokens = chatCompletion.Usage.CompletionTokens
  370. cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
  371. }
  372. } else {
  373. var responsesResponse responses.Response
  374. if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
  375. responseModel = string(responsesResponse.Model)
  376. inputTokens = responsesResponse.Usage.InputTokens
  377. outputTokens = responsesResponse.Usage.OutputTokens
  378. cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
  379. }
  380. }
  381. if inputTokens > 0 || outputTokens > 0 {
  382. if responseModel == "" {
  383. responseModel = requestModel
  384. }
  385. if responseModel != "" {
  386. s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
  387. }
  388. }
  389. _, _ = writer.Write(bodyBytes)
  390. return
  391. }
  392. flusher, ok := writer.(http.Flusher)
  393. if !ok {
  394. s.logger.Error("streaming not supported")
  395. return
  396. }
  397. var inputTokens, outputTokens, cachedTokens int64
  398. var responseModel string
  399. buffer := make([]byte, buf.BufferSize)
  400. var leftover []byte
  401. for {
  402. n, err := response.Body.Read(buffer)
  403. if n > 0 {
  404. data := append(leftover, buffer[:n]...)
  405. lines := bytes.Split(data, []byte("\n"))
  406. if err == nil {
  407. leftover = lines[len(lines)-1]
  408. lines = lines[:len(lines)-1]
  409. } else {
  410. leftover = nil
  411. }
  412. for _, line := range lines {
  413. line = bytes.TrimSpace(line)
  414. if len(line) == 0 {
  415. continue
  416. }
  417. if bytes.HasPrefix(line, []byte("data: ")) {
  418. eventData := bytes.TrimPrefix(line, []byte("data: "))
  419. if bytes.Equal(eventData, []byte("[DONE]")) {
  420. continue
  421. }
  422. if isChatCompletions {
  423. var chatChunk openai.ChatCompletionChunk
  424. if json.Unmarshal(eventData, &chatChunk) == nil {
  425. if chatChunk.Model != "" {
  426. responseModel = chatChunk.Model
  427. }
  428. if chatChunk.Usage.PromptTokens > 0 {
  429. inputTokens = chatChunk.Usage.PromptTokens
  430. cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
  431. }
  432. if chatChunk.Usage.CompletionTokens > 0 {
  433. outputTokens = chatChunk.Usage.CompletionTokens
  434. }
  435. }
  436. } else {
  437. var streamEvent responses.ResponseStreamEventUnion
  438. if json.Unmarshal(eventData, &streamEvent) == nil {
  439. if streamEvent.Type == "response.completed" {
  440. completedEvent := streamEvent.AsResponseCompleted()
  441. if string(completedEvent.Response.Model) != "" {
  442. responseModel = string(completedEvent.Response.Model)
  443. }
  444. if completedEvent.Response.Usage.InputTokens > 0 {
  445. inputTokens = completedEvent.Response.Usage.InputTokens
  446. cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
  447. }
  448. if completedEvent.Response.Usage.OutputTokens > 0 {
  449. outputTokens = completedEvent.Response.Usage.OutputTokens
  450. }
  451. }
  452. }
  453. }
  454. }
  455. }
  456. _, writeError := writer.Write(buffer[:n])
  457. if writeError != nil {
  458. s.logger.Error("write streaming response: ", writeError)
  459. return
  460. }
  461. flusher.Flush()
  462. }
  463. if err != nil {
  464. if responseModel == "" {
  465. responseModel = requestModel
  466. }
  467. if inputTokens > 0 || outputTokens > 0 {
  468. if responseModel != "" {
  469. s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
  470. }
  471. }
  472. return
  473. }
  474. }
  475. }
  476. func (s *Service) Close() error {
  477. err := common.Close(
  478. common.PtrOrNil(s.httpServer),
  479. common.PtrOrNil(s.listener),
  480. s.tlsConfig,
  481. )
  482. if s.usageTracker != nil {
  483. s.usageTracker.cancelPendingSave()
  484. saveErr := s.usageTracker.Save()
  485. if saveErr != nil {
  486. s.logger.Error("save usage statistics: ", saveErr)
  487. }
  488. }
  489. return err
  490. }