service.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. package ccm
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "io"
  8. "mime"
  9. "net"
  10. "net/http"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/sagernet/sing-box/adapter"
  15. boxService "github.com/sagernet/sing-box/adapter/service"
  16. "github.com/sagernet/sing-box/common/dialer"
  17. "github.com/sagernet/sing-box/common/listener"
  18. "github.com/sagernet/sing-box/common/tls"
  19. C "github.com/sagernet/sing-box/constant"
  20. "github.com/sagernet/sing-box/log"
  21. "github.com/sagernet/sing-box/option"
  22. "github.com/sagernet/sing/common"
  23. "github.com/sagernet/sing/common/buf"
  24. E "github.com/sagernet/sing/common/exceptions"
  25. M "github.com/sagernet/sing/common/metadata"
  26. N "github.com/sagernet/sing/common/network"
  27. aTLS "github.com/sagernet/sing/common/tls"
  28. "github.com/anthropics/anthropic-sdk-go"
  29. "github.com/go-chi/chi/v5"
  30. "golang.org/x/net/http2"
  31. )
  32. const (
  33. contextWindowStandard = 200000
  34. contextWindowPremium = 1000000
  35. premiumContextThreshold = 200000
  36. )
  37. func RegisterService(registry *boxService.Registry) {
  38. boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
  39. }
  40. type errorResponse struct {
  41. Type string `json:"type"`
  42. Error errorDetails `json:"error"`
  43. RequestID string `json:"request_id,omitempty"`
  44. }
  45. type errorDetails struct {
  46. Type string `json:"type"`
  47. Message string `json:"message"`
  48. }
  49. func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
  50. w.Header().Set("Content-Type", "application/json")
  51. w.WriteHeader(statusCode)
  52. json.NewEncoder(w).Encode(errorResponse{
  53. Type: "error",
  54. Error: errorDetails{
  55. Type: errorType,
  56. Message: message,
  57. },
  58. RequestID: r.Header.Get("Request-Id"),
  59. })
  60. }
  61. func isHopByHopHeader(header string) bool {
  62. switch strings.ToLower(header) {
  63. case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
  64. return true
  65. default:
  66. return false
  67. }
  68. }
  69. type Service struct {
  70. boxService.Adapter
  71. ctx context.Context
  72. logger log.ContextLogger
  73. credentialPath string
  74. credentials *oauthCredentials
  75. users []option.CCMUser
  76. httpClient *http.Client
  77. httpHeaders http.Header
  78. listener *listener.Listener
  79. tlsConfig tls.ServerConfig
  80. httpServer *http.Server
  81. userManager *UserManager
  82. accessMutex sync.RWMutex
  83. usageTracker *AggregatedUsage
  84. trackingGroup sync.WaitGroup
  85. shuttingDown bool
  86. }
  87. func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
  88. serviceDialer, err := dialer.NewWithOptions(dialer.Options{
  89. Context: ctx,
  90. Options: option.DialerOptions{
  91. Detour: options.Detour,
  92. },
  93. RemoteIsDomain: true,
  94. })
  95. if err != nil {
  96. return nil, E.Cause(err, "create dialer")
  97. }
  98. httpClient := &http.Client{
  99. Transport: &http.Transport{
  100. ForceAttemptHTTP2: true,
  101. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  102. return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  103. },
  104. },
  105. }
  106. userManager := &UserManager{
  107. tokenMap: make(map[string]string),
  108. }
  109. var usageTracker *AggregatedUsage
  110. if options.UsagesPath != "" {
  111. usageTracker = &AggregatedUsage{
  112. LastUpdated: time.Now(),
  113. Combinations: make([]CostCombination, 0),
  114. filePath: options.UsagesPath,
  115. logger: logger,
  116. }
  117. }
  118. service := &Service{
  119. Adapter: boxService.NewAdapter(C.TypeCCM, tag),
  120. ctx: ctx,
  121. logger: logger,
  122. credentialPath: options.CredentialPath,
  123. users: options.Users,
  124. httpClient: httpClient,
  125. httpHeaders: options.Headers.Build(),
  126. listener: listener.New(listener.Options{
  127. Context: ctx,
  128. Logger: logger,
  129. Network: []string{N.NetworkTCP},
  130. Listen: options.ListenOptions,
  131. }),
  132. userManager: userManager,
  133. usageTracker: usageTracker,
  134. }
  135. if options.TLS != nil {
  136. tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
  137. if err != nil {
  138. return nil, err
  139. }
  140. service.tlsConfig = tlsConfig
  141. }
  142. return service, nil
  143. }
  144. func (s *Service) Start(stage adapter.StartStage) error {
  145. if stage != adapter.StartStateStart {
  146. return nil
  147. }
  148. s.userManager.UpdateUsers(s.users)
  149. credentials, err := platformReadCredentials(s.credentialPath)
  150. if err != nil {
  151. return E.Cause(err, "read credentials")
  152. }
  153. s.credentials = credentials
  154. if s.usageTracker != nil {
  155. err = s.usageTracker.Load()
  156. if err != nil {
  157. s.logger.Warn("load usage statistics: ", err)
  158. }
  159. }
  160. router := chi.NewRouter()
  161. router.Mount("/", s)
  162. s.httpServer = &http.Server{Handler: router}
  163. if s.tlsConfig != nil {
  164. err = s.tlsConfig.Start()
  165. if err != nil {
  166. return E.Cause(err, "create TLS config")
  167. }
  168. }
  169. tcpListener, err := s.listener.ListenTCP()
  170. if err != nil {
  171. return err
  172. }
  173. if s.tlsConfig != nil {
  174. if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
  175. s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
  176. }
  177. tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
  178. }
  179. go func() {
  180. serveErr := s.httpServer.Serve(tcpListener)
  181. if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
  182. s.logger.Error("serve error: ", serveErr)
  183. }
  184. }()
  185. return nil
  186. }
  187. func (s *Service) getAccessToken() (string, error) {
  188. s.accessMutex.RLock()
  189. if !s.credentials.needsRefresh() {
  190. token := s.credentials.AccessToken
  191. s.accessMutex.RUnlock()
  192. return token, nil
  193. }
  194. s.accessMutex.RUnlock()
  195. s.accessMutex.Lock()
  196. defer s.accessMutex.Unlock()
  197. if !s.credentials.needsRefresh() {
  198. return s.credentials.AccessToken, nil
  199. }
  200. newCredentials, err := refreshToken(s.httpClient, s.credentials)
  201. if err != nil {
  202. return "", err
  203. }
  204. s.credentials = newCredentials
  205. err = platformWriteCredentials(newCredentials, s.credentialPath)
  206. if err != nil {
  207. s.logger.Warn("persist refreshed token: ", err)
  208. }
  209. return newCredentials.AccessToken, nil
  210. }
  211. func detectContextWindow(betaHeader string, inputTokens int64) int {
  212. if inputTokens > premiumContextThreshold {
  213. features := strings.Split(betaHeader, ",")
  214. for _, feature := range features {
  215. if strings.TrimSpace(feature) == "context-1m" {
  216. return contextWindowPremium
  217. }
  218. }
  219. }
  220. return contextWindowStandard
  221. }
  222. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  223. if !strings.HasPrefix(r.URL.Path, "/v1/") {
  224. writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
  225. return
  226. }
  227. var username string
  228. if len(s.users) > 0 {
  229. authHeader := r.Header.Get("Authorization")
  230. if authHeader == "" {
  231. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
  232. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
  233. return
  234. }
  235. clientToken := strings.TrimPrefix(authHeader, "Bearer ")
  236. if clientToken == authHeader {
  237. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
  238. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
  239. return
  240. }
  241. var ok bool
  242. username, ok = s.userManager.Authenticate(clientToken)
  243. if !ok {
  244. s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
  245. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
  246. return
  247. }
  248. }
  249. var requestModel string
  250. var messagesCount int
  251. if s.usageTracker != nil && r.Body != nil {
  252. bodyBytes, err := io.ReadAll(r.Body)
  253. if err == nil {
  254. var request struct {
  255. Model string `json:"model"`
  256. Messages []anthropic.MessageParam `json:"messages"`
  257. }
  258. err := json.Unmarshal(bodyBytes, &request)
  259. if err == nil {
  260. requestModel = request.Model
  261. messagesCount = len(request.Messages)
  262. }
  263. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
  264. }
  265. }
  266. accessToken, err := s.getAccessToken()
  267. if err != nil {
  268. s.logger.Error("get access token: ", err)
  269. writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
  270. return
  271. }
  272. proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
  273. proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
  274. if err != nil {
  275. s.logger.Error("create proxy request: ", err)
  276. writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
  277. return
  278. }
  279. for key, values := range r.Header {
  280. if !isHopByHopHeader(key) && key != "Authorization" {
  281. proxyRequest.Header[key] = values
  282. }
  283. }
  284. anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
  285. if anthropicBetaHeader != "" {
  286. proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
  287. } else {
  288. proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
  289. }
  290. for key, values := range s.httpHeaders {
  291. proxyRequest.Header.Del(key)
  292. proxyRequest.Header[key] = values
  293. }
  294. proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
  295. response, err := s.httpClient.Do(proxyRequest)
  296. if err != nil {
  297. writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
  298. return
  299. }
  300. defer response.Body.Close()
  301. for key, values := range response.Header {
  302. if !isHopByHopHeader(key) {
  303. w.Header()[key] = values
  304. }
  305. }
  306. w.WriteHeader(response.StatusCode)
  307. if s.usageTracker != nil && response.StatusCode == http.StatusOK {
  308. s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
  309. } else {
  310. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  311. if err == nil && mediaType != "text/event-stream" {
  312. _, _ = io.Copy(w, response.Body)
  313. return
  314. }
  315. flusher, ok := w.(http.Flusher)
  316. if !ok {
  317. s.logger.Error("streaming not supported")
  318. return
  319. }
  320. buffer := make([]byte, buf.BufferSize)
  321. for {
  322. n, err := response.Body.Read(buffer)
  323. if n > 0 {
  324. _, writeError := w.Write(buffer[:n])
  325. if writeError != nil {
  326. s.logger.Error("write streaming response: ", writeError)
  327. return
  328. }
  329. flusher.Flush()
  330. }
  331. if err != nil {
  332. return
  333. }
  334. }
  335. }
  336. }
  337. func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
  338. mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
  339. isStreaming := err == nil && mediaType == "text/event-stream"
  340. if !isStreaming {
  341. bodyBytes, err := io.ReadAll(response.Body)
  342. if err != nil {
  343. s.logger.Error("read response body: ", err)
  344. return
  345. }
  346. var message anthropic.Message
  347. var usage anthropic.Usage
  348. var responseModel string
  349. err = json.Unmarshal(bodyBytes, &message)
  350. if err == nil {
  351. responseModel = string(message.Model)
  352. usage = message.Usage
  353. }
  354. if responseModel == "" {
  355. responseModel = requestModel
  356. }
  357. if usage.InputTokens > 0 || usage.OutputTokens > 0 {
  358. if responseModel != "" {
  359. contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens)
  360. s.usageTracker.AddUsage(
  361. responseModel,
  362. contextWindow,
  363. messagesCount,
  364. usage.InputTokens,
  365. usage.OutputTokens,
  366. usage.CacheReadInputTokens,
  367. usage.CacheCreationInputTokens,
  368. username,
  369. )
  370. }
  371. }
  372. _, _ = writer.Write(bodyBytes)
  373. return
  374. }
  375. flusher, ok := writer.(http.Flusher)
  376. if !ok {
  377. s.logger.Error("streaming not supported")
  378. return
  379. }
  380. var accumulatedUsage anthropic.Usage
  381. var responseModel string
  382. buffer := make([]byte, buf.BufferSize)
  383. var leftover []byte
  384. for {
  385. n, err := response.Body.Read(buffer)
  386. if n > 0 {
  387. data := append(leftover, buffer[:n]...)
  388. lines := bytes.Split(data, []byte("\n"))
  389. if err == nil {
  390. leftover = lines[len(lines)-1]
  391. lines = lines[:len(lines)-1]
  392. } else {
  393. leftover = nil
  394. }
  395. for _, line := range lines {
  396. line = bytes.TrimSpace(line)
  397. if len(line) == 0 {
  398. continue
  399. }
  400. if bytes.HasPrefix(line, []byte("data: ")) {
  401. eventData := bytes.TrimPrefix(line, []byte("data: "))
  402. if bytes.Equal(eventData, []byte("[DONE]")) {
  403. continue
  404. }
  405. var event anthropic.MessageStreamEventUnion
  406. err := json.Unmarshal(eventData, &event)
  407. if err != nil {
  408. continue
  409. }
  410. switch event.Type {
  411. case "message_start":
  412. messageStart := event.AsMessageStart()
  413. if messageStart.Message.Model != "" {
  414. responseModel = string(messageStart.Message.Model)
  415. }
  416. if messageStart.Message.Usage.InputTokens > 0 {
  417. accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
  418. accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
  419. accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
  420. }
  421. case "message_delta":
  422. messageDelta := event.AsMessageDelta()
  423. if messageDelta.Usage.OutputTokens > 0 {
  424. accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
  425. }
  426. }
  427. }
  428. }
  429. _, writeError := writer.Write(buffer[:n])
  430. if writeError != nil {
  431. s.logger.Error("write streaming response: ", writeError)
  432. return
  433. }
  434. flusher.Flush()
  435. }
  436. if err != nil {
  437. if responseModel == "" {
  438. responseModel = requestModel
  439. }
  440. if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
  441. if responseModel != "" {
  442. contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens)
  443. s.usageTracker.AddUsage(
  444. responseModel,
  445. contextWindow,
  446. messagesCount,
  447. accumulatedUsage.InputTokens,
  448. accumulatedUsage.OutputTokens,
  449. accumulatedUsage.CacheReadInputTokens,
  450. accumulatedUsage.CacheCreationInputTokens,
  451. username,
  452. )
  453. }
  454. }
  455. return
  456. }
  457. }
  458. }
  459. func (s *Service) Close() error {
  460. err := common.Close(
  461. common.PtrOrNil(s.httpServer),
  462. common.PtrOrNil(s.listener),
  463. s.tlsConfig,
  464. )
  465. if s.usageTracker != nil {
  466. s.usageTracker.cancelPendingSave()
  467. saveErr := s.usageTracker.Save()
  468. if saveErr != nil {
  469. s.logger.Error("save usage statistics: ", saveErr)
  470. }
  471. }
  472. return err
  473. }