transfer.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package sftpd
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "os"
  7. "sync"
  8. "time"
  9. "github.com/eikenb/pipeat"
  10. "github.com/drakkan/sftpgo/dataprovider"
  11. "github.com/drakkan/sftpgo/logger"
  12. "github.com/drakkan/sftpgo/metrics"
  13. "github.com/drakkan/sftpgo/vfs"
  14. )
  // Transfer directions, stored in Transfer.transferType.
  const (
  	transferUpload   = 0 // client writes data: WriteAt, bytesReceived, UploadBandwidth
  	transferDownload = 1 // client reads data: ReadAt, bytesSent, DownloadBandwidth
  )
  19. var (
  20. errTransferClosed = errors.New("transfer already closed")
  21. )
  22. // Transfer contains the transfer details for an upload or a download.
  23. // It implements the io Reader and Writer interface to handle files downloads and uploads
  24. type Transfer struct {
  25. file *os.File
  26. writerAt *vfs.PipeWriter
  27. readerAt *pipeat.PipeReaderAt
  28. cancelFn func()
  29. path string
  30. start time.Time
  31. bytesSent int64
  32. bytesReceived int64
  33. user dataprovider.User
  34. connectionID string
  35. transferType int
  36. lastActivity time.Time
  37. protocol string
  38. transferError error
  39. minWriteOffset int64
  40. initialSize int64
  41. lock *sync.Mutex
  42. isNewFile bool
  43. isFinished bool
  44. requestPath string
  45. }
  46. // TransferError is called if there is an unexpected error.
  47. // For example network or client issues
  48. func (t *Transfer) TransferError(err error) {
  49. t.lock.Lock()
  50. defer t.lock.Unlock()
  51. if t.transferError != nil {
  52. return
  53. }
  54. t.transferError = err
  55. if t.cancelFn != nil {
  56. t.cancelFn()
  57. }
  58. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  59. logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
  60. "bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, elapsed)
  61. }
  62. // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
  63. // It handles download bandwidth throttling too
  64. func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
  65. t.lastActivity = time.Now()
  66. var readed int
  67. var e error
  68. if t.readerAt != nil {
  69. readed, e = t.readerAt.ReadAt(p, off)
  70. } else {
  71. readed, e = t.file.ReadAt(p, off)
  72. }
  73. t.lock.Lock()
  74. t.bytesSent += int64(readed)
  75. t.lock.Unlock()
  76. if e != nil && e != io.EOF {
  77. t.TransferError(e)
  78. return readed, e
  79. }
  80. t.handleThrottle()
  81. return readed, e
  82. }
  83. // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received.
  84. // It handles upload bandwidth throttling too
  85. func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
  86. t.lastActivity = time.Now()
  87. if off < t.minWriteOffset {
  88. err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
  89. t.TransferError(err)
  90. return 0, err
  91. }
  92. var written int
  93. var e error
  94. if t.writerAt != nil {
  95. written, e = t.writerAt.WriteAt(p, off)
  96. } else {
  97. written, e = t.file.WriteAt(p, off)
  98. }
  99. t.lock.Lock()
  100. t.bytesReceived += int64(written)
  101. t.lock.Unlock()
  102. if e != nil {
  103. t.TransferError(e)
  104. return written, e
  105. }
  106. t.handleThrottle()
  107. return written, e
  108. }
  109. // Close it is called when the transfer is completed.
  110. // It closes the underlying file, logs the transfer info, updates the user quota (for uploads)
  111. // and executes any defined action.
  112. // If there is an error no action will be executed and, in atomic mode, we try to delete
  113. // the temporary file
  114. func (t *Transfer) Close() error {
  115. t.lock.Lock()
  116. defer t.lock.Unlock()
  117. if t.isFinished {
  118. return errTransferClosed
  119. }
  120. err := t.closeIO()
  121. defer removeTransfer(t) //nolint:errcheck
  122. t.isFinished = true
  123. numFiles := 0
  124. if t.isNewFile {
  125. numFiles = 1
  126. }
  127. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  128. if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
  129. if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
  130. err = os.Rename(t.file.Name(), t.path)
  131. logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
  132. t.file.Name(), t.path, err)
  133. } else {
  134. err = os.Remove(t.file.Name())
  135. logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
  136. "deletion error: %v", t.transferError, t.file.Name(), err)
  137. if err == nil {
  138. numFiles--
  139. t.bytesReceived = 0
  140. }
  141. }
  142. }
  143. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  144. if t.transferType == transferDownload {
  145. logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
  146. go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
  147. } else {
  148. logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
  149. go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
  150. t.transferError))
  151. }
  152. if t.transferError != nil {
  153. logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
  154. if err == nil {
  155. err = t.transferError
  156. }
  157. }
  158. t.updateQuota(numFiles)
  159. return err
  160. }
  161. func (t *Transfer) closeIO() error {
  162. var err error
  163. if t.writerAt != nil {
  164. err = t.writerAt.Close()
  165. if err != nil {
  166. t.transferError = err
  167. }
  168. } else if t.readerAt != nil {
  169. err = t.readerAt.Close()
  170. } else {
  171. err = t.file.Close()
  172. }
  173. return err
  174. }
  175. func (t *Transfer) updateQuota(numFiles int) bool {
  176. // S3 uploads are atomic, if there is an error nothing is uploaded
  177. if t.file == nil && t.transferError != nil {
  178. return false
  179. }
  180. if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
  181. vfolder, err := t.user.GetVirtualFolderForPath(t.requestPath)
  182. if err == nil {
  183. dataprovider.UpdateVirtualFolderQuota(dataProvider, vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
  184. t.bytesReceived-t.initialSize, false)
  185. if vfolder.IsIncludedInUserQuota() {
  186. dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
  187. }
  188. } else {
  189. dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
  190. }
  191. return true
  192. }
  193. return false
  194. }
  195. func (t *Transfer) handleThrottle() {
  196. var wantedBandwidth int64
  197. var trasferredBytes int64
  198. if t.transferType == transferDownload {
  199. wantedBandwidth = t.user.DownloadBandwidth
  200. trasferredBytes = t.bytesSent
  201. } else {
  202. wantedBandwidth = t.user.UploadBandwidth
  203. trasferredBytes = t.bytesReceived
  204. }
  205. if wantedBandwidth > 0 {
  206. // real and wanted elapsed as milliseconds, bytes as kilobytes
  207. realElapsed := time.Since(t.start).Nanoseconds() / 1000000
  208. // trasferredBytes / 1000 = KB/s, we multiply for 1000 to get milliseconds
  209. wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth
  210. if wantedElapsed > realElapsed {
  211. toSleep := time.Duration(wantedElapsed - realElapsed)
  212. time.Sleep(toSleep * time.Millisecond)
  213. }
  214. }
  215. }
  216. // used for ssh commands.
  217. // It reads from src until EOF so it does not treat an EOF from Read as an error to be reported.
  218. // EOF from Write is reported as error
  219. func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader, maxWriteSize int64) (int64, error) {
  220. var written int64
  221. var err error
  222. if maxWriteSize < 0 {
  223. return 0, errQuotaExceeded
  224. }
  225. buf := make([]byte, 32768)
  226. for {
  227. t.lastActivity = time.Now()
  228. nr, er := src.Read(buf)
  229. if nr > 0 {
  230. nw, ew := dst.Write(buf[0:nr])
  231. if nw > 0 {
  232. written += int64(nw)
  233. if t.transferType == transferDownload {
  234. t.bytesSent = written
  235. } else {
  236. t.bytesReceived = written
  237. }
  238. if maxWriteSize > 0 && written > maxWriteSize {
  239. err = errQuotaExceeded
  240. break
  241. }
  242. }
  243. if ew != nil {
  244. err = ew
  245. break
  246. }
  247. if nr != nw {
  248. err = io.ErrShortWrite
  249. break
  250. }
  251. }
  252. if er != nil {
  253. if er != io.EOF {
  254. err = er
  255. }
  256. break
  257. }
  258. t.handleThrottle()
  259. }
  260. t.transferError = err
  261. if t.bytesSent > 0 || t.bytesReceived > 0 || err != nil {
  262. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  263. }
  264. return written, err
  265. }