service.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package ccm
  2. import (
  3. "bytes"
  4. "context"
  5. stdTLS "crypto/tls"
  6. "encoding/json"
  7. "errors"
  8. "io"
  9. "mime"
  10. "net"
  11. "net/http"
  12. "strings"
  13. "sync"
  14. "time"
  15. "github.com/sagernet/sing-box/adapter"
  16. boxService "github.com/sagernet/sing-box/adapter/service"
  17. "github.com/sagernet/sing-box/common/dialer"
  18. "github.com/sagernet/sing-box/common/listener"
  19. "github.com/sagernet/sing-box/common/tls"
  20. C "github.com/sagernet/sing-box/constant"
  21. "github.com/sagernet/sing-box/log"
  22. "github.com/sagernet/sing-box/option"
  23. "github.com/sagernet/sing/common"
  24. "github.com/sagernet/sing/common/buf"
  25. E "github.com/sagernet/sing/common/exceptions"
  26. M "github.com/sagernet/sing/common/metadata"
  27. N "github.com/sagernet/sing/common/network"
  28. "github.com/sagernet/sing/common/ntp"
  29. aTLS "github.com/sagernet/sing/common/tls"
  30. "github.com/anthropics/anthropic-sdk-go"
  31. "github.com/go-chi/chi/v5"
  32. "golang.org/x/net/http2"
  33. )
  // Context-window sizes, in tokens, used when attributing usage, together
  // with the input-token count that separates the two tiers.
  const (
  	// contextWindowStandard is the window size detectContextWindow falls back to.
  	contextWindowStandard = 200_000
  	// contextWindowPremium is the window size reported for large requests
  	// that opted into the context-1m beta feature.
  	contextWindowPremium = 1_000_000
  	// premiumContextThreshold is the input-token count a request must
  	// exceed before the premium window is considered.
  	premiumContextThreshold = 200_000
  )
  // RegisterService registers the CCM service type (C.TypeCCM) with registry,
  // using NewService as the constructor and option.CCMServiceOptions as the
  // options type.
  func RegisterService(registry *boxService.Registry) {
  	boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
  }
  // errorResponse is the JSON envelope that writeJSONError sends to clients.
  type errorResponse struct {
  	// Type is the envelope type; writeJSONError always sets it to "error".
  	Type string `json:"type"`
  	// Error carries the error category and the human-readable message.
  	Error errorDetails `json:"error"`
  	// RequestID echoes the client's "Request-Id" header; omitted when empty.
  	RequestID string `json:"request_id,omitempty"`
  }
  // errorDetails is the nested "error" object of an errorResponse.
  type errorDetails struct {
  	// Type is the error category, e.g. "authentication_error" or "api_error".
  	Type string `json:"type"`
  	// Message is the human-readable description of the failure.
  	Message string `json:"message"`
  }
  51. func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
  52. w.Header().Set("Content-Type", "application/json")
  53. w.WriteHeader(statusCode)
  54. json.NewEncoder(w).Encode(errorResponse{
  55. Type: "error",
  56. Error: errorDetails{
  57. Type: errorType,
  58. Message: message,
  59. },
  60. RequestID: r.Header.Get("Request-Id"),
  61. })
  62. }
  // isHopByHopHeader reports whether header must not be relayed between the
  // client and the upstream connection. The comparison is case-insensitive.
  //
  // The set is the hop-by-hop list from RFC 7230 section 6.1 plus "Host",
  // which is connection-specific for a proxy. The real header name is
  // "Trailer"; "trailers" (the spelling from the original RFC 2616 text) is
  // still accepted so existing behavior is preserved. The non-standard but
  // widely sent "Proxy-Connection" is filtered as well.
  func isHopByHopHeader(header string) bool {
  	switch strings.ToLower(header) {
  	case "connection",
  		"proxy-connection",
  		"keep-alive",
  		"proxy-authenticate",
  		"proxy-authorization",
  		"te",
  		"trailer",
  		"trailers",
  		"transfer-encoding",
  		"upgrade",
  		"host":
  		return true
  	default:
  		return false
  	}
  }
  // Service is the CCM service: an HTTP server that optionally authenticates
  // clients, forwards their "/v1/" requests upstream with an OAuth access
  // token, and optionally aggregates token usage.
  type Service struct {
  	boxService.Adapter
  	// ctx is the context handed to NewService.
  	ctx context.Context
  	// logger receives all diagnostics emitted by the service.
  	logger log.ContextLogger
  	// credentialPath is passed to platformReadCredentials on start and to
  	// platformWriteCredentials after a token refresh.
  	credentialPath string
  	// credentials holds the current OAuth credentials; getAccessToken reads
  	// and replaces it while holding accessMutex.
  	credentials *oauthCredentials
  	// users is the configured user list; when non-empty, ServeHTTP requires
  	// a bearer token accepted by userManager.
  	users []option.CCMUser
  	// httpClient performs upstream requests and token refreshes.
  	httpClient *http.Client
  	// httpHeaders are configured headers that replace same-named headers on
  	// every proxied request.
  	httpHeaders http.Header
  	// listener provides the TCP listener the HTTP server runs on.
  	listener *listener.Listener
  	// tlsConfig is nil unless TLS options were supplied to NewService.
  	tlsConfig tls.ServerConfig
  	// httpServer is created in Start and closed in Close.
  	httpServer *http.Server
  	// userManager maps client tokens to user names.
  	userManager *UserManager
  	// accessMutex guards credentials during token lookup and refresh.
  	accessMutex sync.RWMutex
  	// usageTracker is nil when no usages path was configured.
  	usageTracker *AggregatedUsage
  	// NOTE(review): trackingGroup and shuttingDown are not referenced by any
  	// code visible in this file; confirm whether they are used elsewhere.
  	trackingGroup sync.WaitGroup
  	shuttingDown bool
  }
  89. func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
  90. serviceDialer, err := dialer.NewWithOptions(dialer.Options{
  91. Context: ctx,
  92. Options: option.DialerOptions{
  93. Detour: options.Detour,
  94. },
  95. RemoteIsDomain: true,
  96. })
  97. if err != nil {
  98. return nil, E.Cause(err, "create dialer")
  99. }
  100. httpClient := &http.Client{
  101. Transport: &http.Transport{
  102. ForceAttemptHTTP2: true,
  103. TLSClientConfig: &stdTLS.Config{
  104. RootCAs: adapter.RootPoolFromContext(ctx),
  105. Time: ntp.TimeFuncFromContext(ctx),
  106. },
  107. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  108. return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  109. },
  110. },
  111. }
  112. userManager := &UserManager{
  113. tokenMap: make(map[string]string),
  114. }
  115. var usageTracker *AggregatedUsage
  116. if options.UsagesPath != "" {
  117. usageTracker = &AggregatedUsage{
  118. LastUpdated: time.Now(),
  119. Combinations: make([]CostCombination, 0),
  120. filePath: options.UsagesPath,
  121. logger: logger,
  122. }
  123. }
  124. service := &Service{
  125. Adapter: boxService.NewAdapter(C.TypeCCM, tag),
  126. ctx: ctx,
  127. logger: logger,
  128. credentialPath: options.CredentialPath,
  129. users: options.Users,
  130. httpClient: httpClient,
  131. httpHeaders: options.Headers.Build(),
  132. listener: listener.New(listener.Options{
  133. Context: ctx,
  134. Logger: logger,
  135. Network: []string{N.NetworkTCP},
  136. Listen: options.ListenOptions,
  137. }),
  138. userManager: userManager,
  139. usageTracker: usageTracker,
  140. }
  141. if options.TLS != nil {
  142. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  143. if err != nil {
  144. return nil, err
  145. }
  146. service.tlsConfig = tlsConfig
  147. }
  148. return service, nil
  149. }
  150. func (s *Service) Start(stage adapter.StartStage) error {
  151. if stage != adapter.StartStateStart {
  152. return nil
  153. }
  154. s.userManager.UpdateUsers(s.users)
  155. credentials, err := platformReadCredentials(s.credentialPath)
  156. if err != nil {
  157. return E.Cause(err, "read credentials")
  158. }
  159. s.credentials = credentials
  160. if s.usageTracker != nil {
  161. err = s.usageTracker.Load()
  162. if err != nil {
  163. s.logger.Warn("load usage statistics: ", err)
  164. }
  165. }
  166. router := chi.NewRouter()
  167. router.Mount("/", s)
  168. s.httpServer = &http.Server{Handler: router}
  169. if s.tlsConfig != nil {
  170. err = s.tlsConfig.Start()
  171. if err != nil {
  172. return E.Cause(err, "create TLS config")
  173. }
  174. }
  175. tcpListener, err := s.listener.ListenTCP()
  176. if err != nil {
  177. return err
  178. }
  179. if s.tlsConfig != nil {
  180. if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
  181. s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
  182. }
  183. tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
  184. }
  185. go func() {
  186. serveErr := s.httpServer.Serve(tcpListener)
  187. if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
  188. s.logger.Error("serve error: ", serveErr)
  189. }
  190. }()
  191. return nil
  192. }
  193. func (s *Service) getAccessToken() (string, error) {
  194. s.accessMutex.RLock()
  195. if !s.credentials.needsRefresh() {
  196. token := s.credentials.AccessToken
  197. s.accessMutex.RUnlock()
  198. return token, nil
  199. }
  200. s.accessMutex.RUnlock()
  201. s.accessMutex.Lock()
  202. defer s.accessMutex.Unlock()
  203. if !s.credentials.needsRefresh() {
  204. return s.credentials.AccessToken, nil
  205. }
  206. newCredentials, err := refreshToken(s.httpClient, s.credentials)
  207. if err != nil {
  208. return "", err
  209. }
  210. s.credentials = newCredentials
  211. err = platformWriteCredentials(newCredentials, s.credentialPath)
  212. if err != nil {
  213. s.logger.Warn("persist refreshed token: ", err)
  214. }
  215. return newCredentials.AccessToken, nil
  216. }
  217. func detectContextWindow(betaHeader string, inputTokens int64) int {
  218. if inputTokens > premiumContextThreshold {
  219. features := strings.Split(betaHeader, ",")
  220. for _, feature := range features {
  221. if strings.TrimSpace(feature) == "context-1m" {
  222. return contextWindowPremium
  223. }
  224. }
  225. }
  226. return contextWindowStandard
  227. }
  228. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  229. if !strings.HasPrefix(r.URL.Path, "/v1/") {
  230. writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
  231. return
  232. }
  233. var username string
  234. if len(s.users) > 0 {
  235. authHeader := r.Header.Get("Authorization")
  236. if authHeader == "" {
  237. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
  238. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
  239. return
  240. }
  241. clientToken := strings.TrimPrefix(authHeader, "Bearer ")
  242. if clientToken == authHeader {
  243. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
  244. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
  245. return
  246. }
  247. var ok bool
  248. username, ok = s.userManager.Authenticate(clientToken)
  249. if !ok {
  250. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
  251. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
  252. return
  253. }
  254. }
  255. var requestModel string
  256. var messagesCount int
  257. if s.usageTracker != nil && r.Body != nil {
  258. bodyBytes, err := io.ReadAll(r.Body)
  259. if err == nil {
  260. var request struct {
  261. Model string `json:"model"`
  262. Messages []anthropic.MessageParam `json:"messages"`
  263. }
  264. err := json.Unmarshal(bodyBytes, &request)
  265. if err == nil {
  266. requestModel = request.Model
  267. messagesCount = len(request.Messages)
  268. }
  269. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
  270. }
  271. }
  272. accessToken, err := s.getAccessToken()
  273. if err != nil {
  274. s.logger.Error("get access token: ", err)
  275. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
  276. return
  277. }
  278. proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
  279. proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
  280. if err != nil {
  281. s.logger.Error("create proxy request: ", err)
  282. writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
  283. return
  284. }
  285. for key, values := range r.Header {
  286. if !isHopByHopHeader(key) && key != "Authorization" {
  287. proxyRequest.Header[key] = values
  288. }
  289. }
  290. anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
  291. if anthropicBetaHeader != "" {
  292. proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
  293. } else {
  294. proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
  295. }
  296. for key, values := range s.httpHeaders {
  297. proxyRequest.Header.Del(key)
  298. proxyRequest.Header[key] = values
  299. }
  300. proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
  301. response, err := s.httpClient.Do(proxyRequest)
  302. if err != nil {
  303. writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
  304. return
  305. }
  306. defer response.Body.Close()
  307. for key, values := range response.Header {
  308. if !isHopByHopHeader(key) {
  309. w.Header()[key] = values
  310. }
  311. }
  312. w.WriteHeader(response.StatusCode)
  313. if s.usageTracker != nil && response.StatusCode == http.StatusOK {
  314. s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
  315. } else {
  316. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  317. if err == nil && mediaType != "text/event-stream" {
  318. _, _ = io.Copy(w, response.Body)
  319. return
  320. }
  321. flusher, ok := w.(http.Flusher)
  322. if !ok {
  323. s.logger.Error("streaming not supported")
  324. return
  325. }
  326. buffer := make([]byte, buf.BufferSize)
  327. for {
  328. n, err := response.Body.Read(buffer)
  329. if n > 0 {
  330. _, writeError := w.Write(buffer[:n])
  331. if writeError != nil {
  332. s.logger.Error("write streaming response: ", writeError)
  333. return
  334. }
  335. flusher.Flush()
  336. }
  337. if err != nil {
  338. return
  339. }
  340. }
  341. }
  342. }
  343. func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
  344. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  345. isStreaming := err == nil && mediaType == "text/event-stream"
  346. if !isStreaming {
  347. bodyBytes, err := io.ReadAll(response.Body)
  348. if err != nil {
  349. s.logger.Error("read response body: ", err)
  350. return
  351. }
  352. var message anthropic.Message
  353. var usage anthropic.Usage
  354. var responseModel string
  355. err = json.Unmarshal(bodyBytes, &message)
  356. if err == nil {
  357. responseModel = string(message.Model)
  358. usage = message.Usage
  359. }
  360. if responseModel == "" {
  361. responseModel = requestModel
  362. }
  363. if usage.InputTokens > 0 || usage.OutputTokens > 0 {
  364. if responseModel != "" {
  365. contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens)
  366. s.usageTracker.AddUsage(
  367. responseModel,
  368. contextWindow,
  369. messagesCount,
  370. usage.InputTokens,
  371. usage.OutputTokens,
  372. usage.CacheReadInputTokens,
  373. usage.CacheCreationInputTokens,
  374. username,
  375. )
  376. }
  377. }
  378. _, _ = writer.Write(bodyBytes)
  379. return
  380. }
  381. flusher, ok := writer.(http.Flusher)
  382. if !ok {
  383. s.logger.Error("streaming not supported")
  384. return
  385. }
  386. var accumulatedUsage anthropic.Usage
  387. var responseModel string
  388. buffer := make([]byte, buf.BufferSize)
  389. var leftover []byte
  390. for {
  391. n, err := response.Body.Read(buffer)
  392. if n > 0 {
  393. data := append(leftover, buffer[:n]...)
  394. lines := bytes.Split(data, []byte("\n"))
  395. if err == nil {
  396. leftover = lines[len(lines)-1]
  397. lines = lines[:len(lines)-1]
  398. } else {
  399. leftover = nil
  400. }
  401. for _, line := range lines {
  402. line = bytes.TrimSpace(line)
  403. if len(line) == 0 {
  404. continue
  405. }
  406. if bytes.HasPrefix(line, []byte("data: ")) {
  407. eventData := bytes.TrimPrefix(line, []byte("data: "))
  408. if bytes.Equal(eventData, []byte("[DONE]")) {
  409. continue
  410. }
  411. var event anthropic.MessageStreamEventUnion
  412. err := json.Unmarshal(eventData, &event)
  413. if err != nil {
  414. continue
  415. }
  416. switch event.Type {
  417. case "message_start":
  418. messageStart := event.AsMessageStart()
  419. if messageStart.Message.Model != "" {
  420. responseModel = string(messageStart.Message.Model)
  421. }
  422. if messageStart.Message.Usage.InputTokens > 0 {
  423. accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
  424. accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
  425. accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
  426. }
  427. case "message_delta":
  428. messageDelta := event.AsMessageDelta()
  429. if messageDelta.Usage.OutputTokens > 0 {
  430. accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
  431. }
  432. }
  433. }
  434. }
  435. _, writeError := writer.Write(buffer[:n])
  436. if writeError != nil {
  437. s.logger.Error("write streaming response: ", writeError)
  438. return
  439. }
  440. flusher.Flush()
  441. }
  442. if err != nil {
  443. if responseModel == "" {
  444. responseModel = requestModel
  445. }
  446. if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
  447. if responseModel != "" {
  448. contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens)
  449. s.usageTracker.AddUsage(
  450. responseModel,
  451. contextWindow,
  452. messagesCount,
  453. accumulatedUsage.InputTokens,
  454. accumulatedUsage.OutputTokens,
  455. accumulatedUsage.CacheReadInputTokens,
  456. accumulatedUsage.CacheCreationInputTokens,
  457. username,
  458. )
  459. }
  460. }
  461. return
  462. }
  463. }
  464. }
  465. func (s *Service) Close() error {
  466. err := common.Close(
  467. common.PtrOrNil(s.httpServer),
  468. common.PtrOrNil(s.listener),
  469. s.tlsConfig,
  470. )
  471. if s.usageTracker != nil {
  472. s.usageTracker.cancelPendingSave()
  473. saveErr := s.usageTracker.Save()
  474. if saveErr != nil {
  475. s.logger.Error("save usage statistics: ", saveErr)
  476. }
  477. }
  478. return err
  479. }