connections.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package clashapi
  2. import (
  3. "bytes"
  4. "net/http"
  5. "strconv"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
  9. "github.com/sagernet/sing/common/json"
  10. "github.com/sagernet/ws"
  11. "github.com/sagernet/ws/wsutil"
  12. "github.com/go-chi/chi/v5"
  13. "github.com/go-chi/render"
  14. "github.com/gofrs/uuid/v5"
  15. )
  16. func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manager) http.Handler {
  17. r := chi.NewRouter()
  18. r.Get("/", getConnections(trafficManager))
  19. r.Delete("/", closeAllConnections(router, trafficManager))
  20. r.Delete("/{id}", closeConnection(trafficManager))
  21. return r
  22. }
  23. func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
  24. return func(w http.ResponseWriter, r *http.Request) {
  25. if r.Header.Get("Upgrade") != "websocket" {
  26. snapshot := trafficManager.Snapshot()
  27. render.JSON(w, r, snapshot)
  28. return
  29. }
  30. conn, _, _, err := ws.UpgradeHTTP(r, w)
  31. if err != nil {
  32. return
  33. }
  34. intervalStr := r.URL.Query().Get("interval")
  35. interval := 1000
  36. if intervalStr != "" {
  37. t, err := strconv.Atoi(intervalStr)
  38. if err != nil {
  39. render.Status(r, http.StatusBadRequest)
  40. render.JSON(w, r, ErrBadRequest)
  41. return
  42. }
  43. interval = t
  44. }
  45. buf := &bytes.Buffer{}
  46. sendSnapshot := func() error {
  47. buf.Reset()
  48. snapshot := trafficManager.Snapshot()
  49. if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
  50. return err
  51. }
  52. return wsutil.WriteServerText(conn, buf.Bytes())
  53. }
  54. if err = sendSnapshot(); err != nil {
  55. return
  56. }
  57. tick := time.NewTicker(time.Millisecond * time.Duration(interval))
  58. defer tick.Stop()
  59. for range tick.C {
  60. if err = sendSnapshot(); err != nil {
  61. break
  62. }
  63. }
  64. }
  65. }
  66. func closeConnection(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
  67. return func(w http.ResponseWriter, r *http.Request) {
  68. id := uuid.FromStringOrNil(chi.URLParam(r, "id"))
  69. snapshot := trafficManager.Snapshot()
  70. for _, c := range snapshot.Connections {
  71. if id == c.Metadata().ID {
  72. c.Close()
  73. break
  74. }
  75. }
  76. render.NoContent(w, r)
  77. }
  78. }
  79. func closeAllConnections(router adapter.Router, trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
  80. return func(w http.ResponseWriter, r *http.Request) {
  81. snapshot := trafficManager.Snapshot()
  82. for _, c := range snapshot.Connections {
  83. c.Close()
  84. }
  85. router.ResetNetwork()
  86. render.NoContent(w, r)
  87. }
  88. }