| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- package sftpd
- import (
- "errors"
- "fmt"
- "io"
- "os"
- "sync"
- "time"
- "github.com/drakkan/sftpgo/dataprovider"
- "github.com/drakkan/sftpgo/logger"
- "github.com/drakkan/sftpgo/metrics"
- "github.com/eikenb/pipeat"
- )
// Transfer directions stored in Transfer.transferType.
const (
	// transferUpload marks a transfer whose progress is tracked in bytesReceived.
	transferUpload = iota
	// transferDownload marks a transfer whose progress is tracked in bytesSent.
	transferDownload
)
// errTransferClosed is returned by Transfer.Close when the transfer has
// already been marked as finished.
var errTransferClosed = errors.New("transfer already closed")
- // Transfer contains the transfer details for an upload or a download.
- // It implements the io Reader and Writer interface to handle files downloads and uploads
- type Transfer struct {
- file *os.File
- writerAt *pipeat.PipeWriterAt
- readerAt *pipeat.PipeReaderAt
- cancelFn func()
- path string
- start time.Time
- bytesSent int64
- bytesReceived int64
- user dataprovider.User
- connectionID string
- transferType int
- lastActivity time.Time
- protocol string
- transferError error
- minWriteOffset int64
- expectedSize int64
- initialSize int64
- lock *sync.Mutex
- isNewFile bool
- isFinished bool
- isExcludedFromQuota bool
- }
- // TransferError is called if there is an unexpected error.
- // For example network or client issues
- func (t *Transfer) TransferError(err error) {
- t.lock.Lock()
- defer t.lock.Unlock()
- if t.transferError != nil {
- return
- }
- t.transferError = err
- if t.cancelFn != nil {
- t.cancelFn()
- }
- elapsed := time.Since(t.start).Nanoseconds() / 1000000
- logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
- "bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, elapsed)
- }
- // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
- // It handles download bandwidth throttling too
- func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
- t.lastActivity = time.Now()
- var readed int
- var e error
- if t.readerAt != nil {
- readed, e = t.readerAt.ReadAt(p, off)
- } else {
- readed, e = t.file.ReadAt(p, off)
- }
- t.lock.Lock()
- t.bytesSent += int64(readed)
- t.lock.Unlock()
- if e != nil && e != io.EOF {
- t.TransferError(e)
- return readed, e
- }
- t.handleThrottle()
- return readed, e
- }
- // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received.
- // It handles upload bandwidth throttling too
- func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
- t.lastActivity = time.Now()
- if off < t.minWriteOffset {
- err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
- t.TransferError(err)
- return 0, err
- }
- var written int
- var e error
- if t.writerAt != nil {
- written, e = t.writerAt.WriteAt(p, off)
- } else {
- written, e = t.file.WriteAt(p, off)
- }
- t.lock.Lock()
- t.bytesReceived += int64(written)
- t.lock.Unlock()
- if e != nil {
- t.TransferError(e)
- return written, e
- }
- t.handleThrottle()
- return written, e
- }
- // Close it is called when the transfer is completed.
- // It closes the underlying file, logs the transfer info, updates the user quota (for uploads)
- // and executes any defined action.
- // If there is an error no action will be executed and, in atomic mode, we try to delete
- // the temporary file
- func (t *Transfer) Close() error {
- t.lock.Lock()
- defer t.lock.Unlock()
- if t.isFinished {
- return errTransferClosed
- }
- err := t.closeIO()
- defer removeTransfer(t) //nolint:errcheck
- t.isFinished = true
- numFiles := 0
- if t.isNewFile {
- numFiles = 1
- }
- metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
- if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
- if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
- err = os.Rename(t.file.Name(), t.path)
- logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
- t.file.Name(), t.path, err)
- } else {
- err = os.Remove(t.file.Name())
- logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
- "deletion error: %v", t.transferError, t.file.Name(), err)
- if err == nil {
- numFiles--
- t.bytesReceived = 0
- }
- }
- }
- elapsed := time.Since(t.start).Nanoseconds() / 1000000
- if t.transferType == transferDownload {
- logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
- go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
- } else {
- logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
- go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
- t.transferError))
- }
- if t.transferError != nil {
- logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
- if err == nil {
- err = t.transferError
- }
- }
- t.updateQuota(numFiles)
- return err
- }
- func (t *Transfer) closeIO() error {
- var err error
- if t.writerAt != nil {
- err = t.writerAt.Close()
- } else if t.readerAt != nil {
- err = t.readerAt.Close()
- } else {
- err = t.file.Close()
- }
- return err
- }
- func (t *Transfer) updateQuota(numFiles int) bool {
- // S3 uploads are atomic, if there is an error nothing is uploaded
- if t.file == nil && t.transferError != nil {
- return false
- }
- if t.isExcludedFromQuota {
- return false
- }
- if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
- dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
- return true
- }
- return false
- }
- func (t *Transfer) handleThrottle() {
- var wantedBandwidth int64
- var trasferredBytes int64
- if t.transferType == transferDownload {
- wantedBandwidth = t.user.DownloadBandwidth
- trasferredBytes = t.bytesSent
- } else {
- wantedBandwidth = t.user.UploadBandwidth
- trasferredBytes = t.bytesReceived
- }
- if wantedBandwidth > 0 {
- // real and wanted elapsed as milliseconds, bytes as kilobytes
- realElapsed := time.Since(t.start).Nanoseconds() / 1000000
- // trasferredBytes / 1000 = KB/s, we multiply for 1000 to get milliseconds
- wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth
- if wantedElapsed > realElapsed {
- toSleep := time.Duration(wantedElapsed - realElapsed)
- time.Sleep(toSleep * time.Millisecond)
- }
- }
- }
- // used for ssh commands.
- // It reads from src until EOF so it does not treat an EOF from Read as an error to be reported.
- // EOF from Write is reported as error
- func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader, maxWriteSize int64) (int64, error) {
- var written int64
- var err error
- if maxWriteSize < 0 {
- return 0, errQuotaExceeded
- }
- buf := make([]byte, 32768)
- for {
- t.lastActivity = time.Now()
- nr, er := src.Read(buf)
- if nr > 0 {
- nw, ew := dst.Write(buf[0:nr])
- if nw > 0 {
- written += int64(nw)
- if t.transferType == transferDownload {
- t.bytesSent = written
- } else {
- t.bytesReceived = written
- }
- if maxWriteSize > 0 && written > maxWriteSize {
- err = errQuotaExceeded
- break
- }
- }
- if ew != nil {
- err = ew
- break
- }
- if nr != nw {
- err = io.ErrShortWrite
- break
- }
- }
- if er != nil {
- if er != io.EOF {
- err = er
- }
- break
- }
- t.handleThrottle()
- }
- t.transferError = err
- if t.bytesSent > 0 || t.bytesReceived > 0 || err != nil {
- metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
- }
- return written, err
- }
|