ctxreader.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. // Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tailssh
  5. import (
  6. "context"
  7. "io"
  8. "sync"
  9. "tailscale.com/tempfork/gliderlabs/ssh"
  10. )
  // readResult carries the outcome of a single io.Reader.Read call performed
  // on behalf of contextReader and delivered to it over a channel.
  type readResult struct {
  	buf []byte // the bytes actually read; ownership moves with the channel send
  	err error  // whatever error that Read returned
  }
  17. // contextReader wraps an io.Reader, providing a ReadContext method
  18. // that can be aborted before yielding bytes. If it's aborted, subsequent
  19. // reads can get those byte(s) later.
  20. type contextReader struct {
  21. r io.Reader
  22. // buffered is leftover data from a previous read call that wasn't entirely
  23. // consumed.
  24. buffered []byte
  25. // readErr is a previous read error that was seen while filling buffered. It
  26. // should be returned to the caller after bufffered is consumed.
  27. readErr error
  28. mu sync.Mutex // guards ch only
  29. // ch is non-nil if a goroutine had been started and has a result to be
  30. // read. The goroutine may be either still running or done and has
  31. // send to the channel.
  32. ch chan readResult
  33. }
  34. // HasOutstandingRead reports whether there's an oustanding Read call that's
  35. // either currently blocked in a Read or whose result hasn't been consumed.
  36. func (w *contextReader) HasOutstandingRead() bool {
  37. w.mu.Lock()
  38. defer w.mu.Unlock()
  39. return w.ch != nil
  40. }
  41. func (w *contextReader) setChan(c chan readResult) {
  42. w.mu.Lock()
  43. defer w.mu.Unlock()
  44. w.ch = c
  45. }
  // ReadContext is like Read, but takes a context permitting the read to be
  // canceled.
  //
  // Bytes left over from an earlier read are returned first, regardless of
  // ctx. Otherwise, if no read is in flight and ctx is already done,
  // ReadContext returns ctx.Err() without touching the underlying reader. If
  // ctx becomes done while a read is in flight, the underlying Read call
  // continues in the background and its result is given to the next caller
  // of ReadContext.
  //
  // buffered and readErr are accessed without locking, so callers presumably
  // must not invoke ReadContext concurrently with itself.
  func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
  	if len(p) == 0 {
  		return 0, nil
  	}

  	// Serve leftovers first: these bytes were already taken from w.r, so
  	// they must reach a caller before anything read later.
  	if len(w.buffered) > 0 {
  		n = copy(p, w.buffered)
  		w.buffered = w.buffered[n:]
  		if len(w.buffered) == 0 {
  			// Fully drained: release the backing array and deliver the
  			// error that accompanied these bytes, exactly once.
  			w.buffered = nil
  			err, w.readErr = w.readErr, nil
  		}
  		return n, err
  	}

  	if w.ch == nil {
  		// Don't start a background Read nobody is waiting for: with ctx
  		// already done, that goroutine could block on w.r indefinitely.
  		if ctx.Err() != nil {
  			return 0, ctx.Err()
  		}
  		ch := make(chan readResult, 1)
  		w.setChan(ch)
  		rbuf := make([]byte, len(p))
  		go func() {
  			n, err := w.r.Read(rbuf)
  			ch <- readResult{rbuf[:n], err}
  		}()
  	}

  	select {
  	case <-ctx.Done():
  		// The in-flight read keeps going; its result stays in w.ch for the
  		// next caller.
  		return 0, ctx.Err()
  	case rr := <-w.ch:
  		w.setChan(nil)
  		n = copy(p, rr.buf)
  		if n < len(rr.buf) {
  			// p is smaller than the buffer the read was started with (an
  			// earlier caller may have started it). Keep the rest, and its
  			// error, for subsequent calls.
  			w.buffered = rr.buf[n:]
  			w.readErr = rr.err
  			return n, nil
  		}
  		return n, rr.err
  	}
  }
  85. // contextReaderSesssion implements ssh.Session, wrapping another
  86. // ssh.Session but changing its Read method to use contextReader.
  87. type contextReaderSesssion struct {
  88. ssh.Session
  89. cr *contextReader
  90. }
  // Read implements io.Reader for the session.
  //
  // cr may still hold data it took from its underlying reader: a background
  // read that is in flight or not yet collected, or bytes left over from a
  // collected read that did not fit the caller's buffer. In that case the
  // read is served by cr.ReadContext, so that data is neither skipped nor
  // reordered. Otherwise it goes straight to the embedded Session.
  //
  // NOTE(review): this presumes cr wraps this same session's input.
  func (a contextReaderSesssion) Read(p []byte) (n int, err error) {
  	// ReadContext clears cr.ch before stashing leftover bytes in
  	// cr.buffered, so HasOutstandingRead alone would miss them. buffered is
  	// read unlocked here, just as ReadContext itself accesses it.
  	if len(a.cr.buffered) > 0 || a.cr.HasOutstandingRead() {
  		return a.cr.ReadContext(context.Background(), p)
  	}
  	return a.Session.Read(p)
  }