transfer.go 8.4 KB


  1. package sftpd
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "os"
  7. "path"
  8. "sync"
  9. "time"
  10. "github.com/eikenb/pipeat"
  11. "github.com/drakkan/sftpgo/dataprovider"
  12. "github.com/drakkan/sftpgo/logger"
  13. "github.com/drakkan/sftpgo/metrics"
  14. "github.com/drakkan/sftpgo/vfs"
  15. )
  16. const (
  17. transferUpload = iota
  18. transferDownload
  19. )
  20. var (
  21. errTransferClosed = errors.New("transfer already closed")
  22. )
// Transfer contains the transfer details for an upload or a download.
// It exposes ReadAt and WriteAt (the io.ReaderAt and io.WriterAt method sets)
// to handle file downloads and uploads.
type Transfer struct {
	// file is the local file backing the transfer. ReadAt/WriteAt fall back to
	// it when readerAt/writerAt are nil.
	file *os.File
	// writerAt, when set, receives uploaded data instead of file.
	writerAt *vfs.PipeWriter
	// readerAt, when set, provides downloaded data instead of file.
	readerAt *pipeat.PipeReaderAt
	// cancelFn, if not nil, is invoked by TransferError when the first error
	// is recorded.
	cancelFn func()
	// path is the filesystem path of the transferred file. For uploads written
	// to a differently named temporary file it is the rename target used by Close.
	path string
	// start is the time the transfer began; elapsed time and throttling are
	// computed from it.
	start time.Time
	// bytesSent counts the bytes read by the client (downloads).
	bytesSent int64
	// bytesReceived counts the bytes written by the client (uploads).
	bytesReceived int64
	// user owns the transfer; it supplies the bandwidth limits and is the
	// target of quota updates.
	user dataprovider.User
	// connectionID identifies the connection in log entries.
	connectionID string
	// transferType is transferUpload or transferDownload.
	transferType int
	// lastActivity is refreshed on every ReadAt, WriteAt and copy iteration.
	lastActivity time.Time
	// protocol is reported in the transfer log.
	protocol string
	// transferError holds the error recorded for the transfer, nil if none.
	transferError error
	// minWriteOffset is the lowest offset WriteAt accepts; it is also added to
	// bytesReceived for the size reported in the upload notification.
	minWriteOffset int64
	// initialSize is subtracted from bytesReceived when the quota is updated.
	initialSize int64
	// lock is taken by TransferError, ReadAt, WriteAt and Close around their
	// updates of the counters, transferError and isFinished.
	lock *sync.Mutex
	// isNewFile makes Close account for one additional file in the quota.
	isNewFile bool
	// isFinished is set by Close; a further Close returns errTransferClosed.
	isFinished bool
	// requestPath is used to look up the virtual folder for quota updates.
	requestPath string
	// maxWriteSize, when greater than zero, is the maximum number of bytes that
	// can be received before errQuotaExceeded is reported. A negative value
	// makes copyFromReaderToWriter fail immediately with errQuotaExceeded.
	maxWriteSize int64
}
  48. // TransferError is called if there is an unexpected error.
  49. // For example network or client issues
  50. func (t *Transfer) TransferError(err error) {
  51. t.lock.Lock()
  52. defer t.lock.Unlock()
  53. if t.transferError != nil {
  54. return
  55. }
  56. t.transferError = err
  57. if t.cancelFn != nil {
  58. t.cancelFn()
  59. }
  60. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  61. logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
  62. "bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, elapsed)
  63. }
  64. // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
  65. // It handles download bandwidth throttling too
  66. func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
  67. t.lastActivity = time.Now()
  68. var readed int
  69. var e error
  70. if t.readerAt != nil {
  71. readed, e = t.readerAt.ReadAt(p, off)
  72. } else {
  73. readed, e = t.file.ReadAt(p, off)
  74. }
  75. t.lock.Lock()
  76. t.bytesSent += int64(readed)
  77. t.lock.Unlock()
  78. if e != nil && e != io.EOF {
  79. t.TransferError(e)
  80. return readed, e
  81. }
  82. t.handleThrottle()
  83. return readed, e
  84. }
  85. // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received.
  86. // It handles upload bandwidth throttling too
  87. func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
  88. t.lastActivity = time.Now()
  89. if off < t.minWriteOffset {
  90. err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
  91. t.TransferError(err)
  92. return 0, err
  93. }
  94. var written int
  95. var e error
  96. if t.writerAt != nil {
  97. written, e = t.writerAt.WriteAt(p, off)
  98. } else {
  99. written, e = t.file.WriteAt(p, off)
  100. }
  101. t.lock.Lock()
  102. t.bytesReceived += int64(written)
  103. if e == nil && t.maxWriteSize > 0 && t.bytesReceived > t.maxWriteSize {
  104. e = errQuotaExceeded
  105. }
  106. t.lock.Unlock()
  107. if e != nil {
  108. t.TransferError(e)
  109. return written, e
  110. }
  111. t.handleThrottle()
  112. return written, e
  113. }
  114. // Close it is called when the transfer is completed.
  115. // It closes the underlying file, logs the transfer info, updates the user quota (for uploads)
  116. // and executes any defined action.
  117. // If there is an error no action will be executed and, in atomic mode, we try to delete
  118. // the temporary file
  119. func (t *Transfer) Close() error {
  120. t.lock.Lock()
  121. defer t.lock.Unlock()
  122. if t.isFinished {
  123. return errTransferClosed
  124. }
  125. err := t.closeIO()
  126. defer removeTransfer(t) //nolint:errcheck
  127. t.isFinished = true
  128. numFiles := 0
  129. if t.isNewFile {
  130. numFiles = 1
  131. }
  132. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  133. if t.transferError == errQuotaExceeded && t.file != nil {
  134. // if quota is exceeded we try to remove the partial file for uploads to local filesystem
  135. err = os.Remove(t.file.Name())
  136. if err == nil {
  137. numFiles--
  138. t.bytesReceived = 0
  139. t.minWriteOffset = 0
  140. }
  141. logger.Warn(logSender, t.connectionID, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
  142. t.file.Name(), err)
  143. } else if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
  144. if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
  145. err = os.Rename(t.file.Name(), t.path)
  146. logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
  147. t.file.Name(), t.path, err)
  148. } else {
  149. err = os.Remove(t.file.Name())
  150. logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
  151. "deletion error: %v", t.transferError, t.file.Name(), err)
  152. if err == nil {
  153. numFiles--
  154. t.bytesReceived = 0
  155. t.minWriteOffset = 0
  156. }
  157. }
  158. }
  159. elapsed := time.Since(t.start).Nanoseconds() / 1000000
  160. if t.transferType == transferDownload {
  161. logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
  162. go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
  163. } else {
  164. logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
  165. go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
  166. t.transferError))
  167. }
  168. if t.transferError != nil {
  169. logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
  170. if err == nil {
  171. err = t.transferError
  172. }
  173. }
  174. t.updateQuota(numFiles)
  175. return err
  176. }
  177. func (t *Transfer) closeIO() error {
  178. var err error
  179. if t.writerAt != nil {
  180. err = t.writerAt.Close()
  181. if err != nil {
  182. t.transferError = err
  183. }
  184. } else if t.readerAt != nil {
  185. err = t.readerAt.Close()
  186. } else {
  187. err = t.file.Close()
  188. }
  189. return err
  190. }
  191. func (t *Transfer) updateQuota(numFiles int) bool {
  192. // S3 uploads are atomic, if there is an error nothing is uploaded
  193. if t.file == nil && t.transferError != nil {
  194. return false
  195. }
  196. if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
  197. vfolder, err := t.user.GetVirtualFolderForPath(path.Dir(t.requestPath))
  198. if err == nil {
  199. dataprovider.UpdateVirtualFolderQuota(dataProvider, vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
  200. t.bytesReceived-t.initialSize, false)
  201. if vfolder.IsIncludedInUserQuota() {
  202. dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
  203. }
  204. } else {
  205. dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
  206. }
  207. return true
  208. }
  209. return false
  210. }
  211. func (t *Transfer) handleThrottle() {
  212. var wantedBandwidth int64
  213. var trasferredBytes int64
  214. if t.transferType == transferDownload {
  215. wantedBandwidth = t.user.DownloadBandwidth
  216. trasferredBytes = t.bytesSent
  217. } else {
  218. wantedBandwidth = t.user.UploadBandwidth
  219. trasferredBytes = t.bytesReceived
  220. }
  221. if wantedBandwidth > 0 {
  222. // real and wanted elapsed as milliseconds, bytes as kilobytes
  223. realElapsed := time.Since(t.start).Nanoseconds() / 1000000
  224. // trasferredBytes / 1000 = KB/s, we multiply for 1000 to get milliseconds
  225. wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth
  226. if wantedElapsed > realElapsed {
  227. toSleep := time.Duration(wantedElapsed - realElapsed)
  228. time.Sleep(toSleep * time.Millisecond)
  229. }
  230. }
  231. }
  232. // used for ssh commands.
  233. // It reads from src until EOF so it does not treat an EOF from Read as an error to be reported.
  234. // EOF from Write is reported as error
  235. func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, error) {
  236. var written int64
  237. var err error
  238. if t.maxWriteSize < 0 {
  239. return 0, errQuotaExceeded
  240. }
  241. buf := make([]byte, 32768)
  242. for {
  243. t.lastActivity = time.Now()
  244. nr, er := src.Read(buf)
  245. if nr > 0 {
  246. nw, ew := dst.Write(buf[0:nr])
  247. if nw > 0 {
  248. written += int64(nw)
  249. if t.transferType == transferDownload {
  250. t.bytesSent = written
  251. } else {
  252. t.bytesReceived = written
  253. }
  254. if t.maxWriteSize > 0 && written > t.maxWriteSize {
  255. err = errQuotaExceeded
  256. break
  257. }
  258. }
  259. if ew != nil {
  260. err = ew
  261. break
  262. }
  263. if nr != nw {
  264. err = io.ErrShortWrite
  265. break
  266. }
  267. }
  268. if er != nil {
  269. if er != io.EOF {
  270. err = er
  271. }
  272. break
  273. }
  274. t.handleThrottle()
  275. }
  276. t.transferError = err
  277. if t.bytesSent > 0 || t.bytesReceived > 0 || err != nil {
  278. metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
  279. }
  280. return written, err
  281. }