speedtest_server.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package speedtest
  4. import (
  5. "crypto/rand"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net"
  11. "time"
  12. )
  13. // Serve starts up the server on a given host and port pair. It starts to listen for
  14. // connections and handles each one in a goroutine. Because it runs in an infinite loop,
  15. // this function only returns if any of the speedtests return with errors, or if the
  16. // listener is closed.
  17. func Serve(ln net.Listener) error {
  18. for {
  19. conn, err := ln.Accept()
  20. if errors.Is(err, net.ErrClosed) {
  21. return nil
  22. }
  23. if err != nil {
  24. return err
  25. }
  26. err = handleConnection(conn)
  27. if err != nil {
  28. return err
  29. }
  30. }
  31. }
  32. // handleConnection handles the initial exchange between the server and the client.
  33. // It reads the testconfig message into a config struct. If any errors occur with
  34. // the testconfig (specifically, if there is a version mismatch), it will return those
  35. // errors to the client with a configResponse. After the exchange, it will start
  36. // the speed test.
  37. func handleConnection(conn net.Conn) error {
  38. defer conn.Close()
  39. var conf config
  40. decoder := json.NewDecoder(conn)
  41. err := decoder.Decode(&conf)
  42. encoder := json.NewEncoder(conn)
  43. // Both return and encode errors that occurred before the test started.
  44. if err != nil {
  45. encoder.Encode(configResponse{Error: err.Error()})
  46. return err
  47. }
  48. // The server should always be doing the opposite of what the client is doing.
  49. conf.Direction.Reverse()
  50. if conf.Version != version {
  51. err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
  52. encoder.Encode(configResponse{Error: err.Error()})
  53. return err
  54. }
  55. // Start the test
  56. encoder.Encode(configResponse{})
  57. _, err = doTest(conn, conf)
  58. return err
  59. }
  60. // TODO include code to detect whether the code is direct vs DERP
  61. // doTest contains the code to run both the upload and download speedtest.
  62. // the direction value in the config parameter determines which test to run.
  63. func doTest(conn net.Conn, conf config) ([]Result, error) {
  64. bufferData := make([]byte, blockSize)
  65. intervalBytes := 0
  66. totalBytes := 0
  67. var currentTime time.Time
  68. var results []Result
  69. if conf.Direction == Download {
  70. conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second))
  71. } else {
  72. _, err := rand.Read(bufferData)
  73. if err != nil {
  74. return nil, err
  75. }
  76. }
  77. startTime := time.Now()
  78. lastCalculated := startTime
  79. SpeedTestLoop:
  80. for {
  81. var n int
  82. var err error
  83. if conf.Direction == Download {
  84. n, err = io.ReadFull(conn, bufferData)
  85. switch err {
  86. case io.EOF, io.ErrUnexpectedEOF:
  87. break SpeedTestLoop
  88. case nil:
  89. // successful read
  90. default:
  91. return nil, fmt.Errorf("unexpected error has occurred: %w", err)
  92. }
  93. } else {
  94. n, err = conn.Write(bufferData)
  95. if err != nil {
  96. // If the write failed, there is most likely something wrong with the connection.
  97. return nil, fmt.Errorf("upload failed: %w", err)
  98. }
  99. }
  100. intervalBytes += n
  101. currentTime = time.Now()
  102. // checks if the current time is more or equal to the lastCalculated time plus the increment
  103. if currentTime.Sub(lastCalculated) >= increment {
  104. results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
  105. lastCalculated = currentTime
  106. totalBytes += intervalBytes
  107. intervalBytes = 0
  108. }
  109. if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
  110. break SpeedTestLoop
  111. }
  112. }
  113. // get last segment
  114. if currentTime.Sub(lastCalculated) > minInterval {
  115. results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
  116. }
  117. // get total
  118. totalBytes += intervalBytes
  119. if currentTime.Sub(startTime) > minInterval {
  120. results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
  121. }
  122. return results, nil
  123. }