transfer.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. package sftpd
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "os"
  7. "sync"
  8. "time"
  9. "github.com/drakkan/sftpgo/dataprovider"
  10. "github.com/drakkan/sftpgo/logger"
  11. "github.com/drakkan/sftpgo/metrics"
  12. "github.com/eikenb/pipeat"
  13. )
  14. const (
  15. transferUpload = iota
  16. transferDownload
  17. )
  18. var (
  19. errTransferClosed = errors.New("transfer already closed")
  20. )
  21. // Transfer contains the transfer details for an upload or a download.
  22. // It implements the io Reader and Writer interface to handle files downloads and uploads
  23. type Transfer struct {
  24. file *os.File
  25. writerAt *pipeat.PipeWriterAt
  26. readerAt *pipeat.PipeReaderAt
  27. cancelFn func()
  28. path string
  29. start time.Time
  30. bytesSent int64
  31. bytesReceived int64
  32. user dataprovider.User
  33. connectionID string
  34. transferType int
  35. lastActivity time.Time
  36. protocol string
  37. transferError error
  38. minWriteOffset int64
  39. expectedSize int64
  40. initialSize int64
  41. lock *sync.Mutex
  42. isNewFile bool
  43. isFinished bool
  44. isExcludedFromQuota bool
  45. }
  46. // TransferError is called if there is an unexpected error.
  47. // For example network or client issues
  48. func (t *Transfer) TransferError(err error) {
  49. t.lock.Lock()
  50. defer t.lock.Unlock()
  51. if t.transferError != nil {
  52. return
  53. }
  54. t.transferError = err
  55. if t.cancelFn != nil {
  56. t.cancelFn()
  57. }
  58. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  59. logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
  60. "bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, elapsed)
  61. }
  62. // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
  63. // It handles download bandwidth throttling too
  64. func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
  65. t.lastActivity = time.Now()
  66. var readed int
  67. var e error
  68. if t.readerAt != nil {
  69. readed, e = t.readerAt.ReadAt(p, off)
  70. } else {
  71. readed, e = t.file.ReadAt(p, off)
  72. }
  73. t.lock.Lock()
  74. t.bytesSent += int64(readed)
  75. t.lock.Unlock()
  76. if e != nil && e != io.EOF {
  77. t.TransferError(e)
  78. return readed, e
  79. }
  80. t.handleThrottle()
  81. return readed, e
  82. }
  83. // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received.
  84. // It handles upload bandwidth throttling too
  85. func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
  86. t.lastActivity = time.Now()
  87. if off < t.minWriteOffset {
  88. err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
  89. t.TransferError(err)
  90. return 0, err
  91. }
  92. var written int
  93. var e error
  94. if t.writerAt != nil {
  95. written, e = t.writerAt.WriteAt(p, off)
  96. } else {
  97. written, e = t.file.WriteAt(p, off)
  98. }
  99. t.lock.Lock()
  100. t.bytesReceived += int64(written)
  101. t.lock.Unlock()
  102. if e != nil {
  103. t.TransferError(e)
  104. return written, e
  105. }
  106. t.handleThrottle()
  107. return written, e
  108. }
  109. // Close it is called when the transfer is completed.
  110. // It closes the underlying file, logs the transfer info, updates the user quota (for uploads)
  111. // and executes any defined action.
  112. // If there is an error no action will be executed and, in atomic mode, we try to delete
  113. // the temporary file
  114. func (t *Transfer) Close() error {
  115. t.lock.Lock()
  116. defer t.lock.Unlock()
  117. if t.isFinished {
  118. return errTransferClosed
  119. }
  120. err := t.closeIO()
  121. defer removeTransfer(t) //nolint:errcheck
  122. t.isFinished = true
  123. numFiles := 0
  124. if t.isNewFile {
  125. numFiles = 1
  126. }
  127. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  128. if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
  129. if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
  130. err = os.Rename(t.file.Name(), t.path)
  131. logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
  132. t.file.Name(), t.path, err)
  133. } else {
  134. err = os.Remove(t.file.Name())
  135. logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
  136. "deletion error: %v", t.transferError, t.file.Name(), err)
  137. if err == nil {
  138. numFiles--
  139. t.bytesReceived = 0
  140. }
  141. }
  142. }
  143. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  144. if t.transferType == transferDownload {
  145. logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
  146. go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
  147. } else {
  148. logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
  149. go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
  150. t.transferError))
  151. }
  152. if t.transferError != nil {
  153. logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
  154. if err == nil {
  155. err = t.transferError
  156. }
  157. }
  158. t.updateQuota(numFiles)
  159. return err
  160. }
  161. func (t *Transfer) closeIO() error {
  162. var err error
  163. if t.writerAt != nil {
  164. err = t.writerAt.Close()
  165. } else if t.readerAt != nil {
  166. err = t.readerAt.Close()
  167. } else {
  168. err = t.file.Close()
  169. }
  170. return err
  171. }
  172. func (t *Transfer) updateQuota(numFiles int) bool {
  173. // S3 uploads are atomic, if there is an error nothing is uploaded
  174. if t.file == nil && t.transferError != nil {
  175. return false
  176. }
  177. if t.isExcludedFromQuota {
  178. return false
  179. }
  180. if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
  181. dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
  182. return true
  183. }
  184. return false
  185. }
  186. func (t *Transfer) handleThrottle() {
  187. var wantedBandwidth int64
  188. var trasferredBytes int64
  189. if t.transferType == transferDownload {
  190. wantedBandwidth = t.user.DownloadBandwidth
  191. trasferredBytes = t.bytesSent
  192. } else {
  193. wantedBandwidth = t.user.UploadBandwidth
  194. trasferredBytes = t.bytesReceived
  195. }
  196. if wantedBandwidth > 0 {
  197. // real and wanted elapsed as milliseconds, bytes as kilobytes
  198. realElapsed := time.Since(t.start).Nanoseconds() / 1000000
  199. // trasferredBytes / 1000 = KB/s, we multiply for 1000 to get milliseconds
  200. wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth
  201. if wantedElapsed > realElapsed {
  202. toSleep := time.Duration(wantedElapsed - realElapsed)
  203. time.Sleep(toSleep * time.Millisecond)
  204. }
  205. }
  206. }
  207. // used for ssh commands.
  208. // It reads from src until EOF so it does not treat an EOF from Read as an error to be reported.
  209. // EOF from Write is reported as error
  210. func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader, maxWriteSize int64) (int64, error) {
  211. var written int64
  212. var err error
  213. if maxWriteSize < 0 {
  214. return 0, errQuotaExceeded
  215. }
  216. buf := make([]byte, 32768)
  217. for {
  218. t.lastActivity = time.Now()
  219. nr, er := src.Read(buf)
  220. if nr > 0 {
  221. nw, ew := dst.Write(buf[0:nr])
  222. if nw > 0 {
  223. written += int64(nw)
  224. if t.transferType == transferDownload {
  225. t.bytesSent = written
  226. } else {
  227. t.bytesReceived = written
  228. }
  229. if maxWriteSize > 0 && written > maxWriteSize {
  230. err = errQuotaExceeded
  231. break
  232. }
  233. }
  234. if ew != nil {
  235. err = ew
  236. break
  237. }
  238. if nr != nw {
  239. err = io.ErrShortWrite
  240. break
  241. }
  242. }
  243. if er != nil {
  244. if er != io.EOF {
  245. err = er
  246. }
  247. break
  248. }
  249. t.handleThrottle()
  250. }
  251. t.transferError = err
  252. if t.bytesSent > 0 || t.bytesReceived > 0 || err != nil {
  253. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  254. }
  255. return written, err
  256. }