scp.go 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package sftpd
  15. import (
  16. "errors"
  17. "fmt"
  18. "io"
  19. "math"
  20. "os"
  21. "path"
  22. "path/filepath"
  23. "runtime/debug"
  24. "strconv"
  25. "strings"
  26. "github.com/drakkan/sftpgo/v2/internal/common"
  27. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  28. "github.com/drakkan/sftpgo/v2/internal/logger"
  29. "github.com/drakkan/sftpgo/v2/internal/vfs"
  30. )
  31. var (
  32. okMsg = []byte{0x00}
  33. warnMsg = []byte{0x01} // must be followed by an optional message and a newline
  34. errMsg = []byte{0x02} // must be followed by an optional message and a newline
  35. newLine = []byte{0x0A}
  36. )
  37. type scpCommand struct {
  38. sshCommand
  39. }
  40. func (c *scpCommand) handle() (err error) {
  41. defer func() {
  42. if r := recover(); r != nil {
  43. logger.Error(logSender, "", "panic in handle scp command: %q stack trace: %v", r, string(debug.Stack()))
  44. err = common.ErrGenericFailure
  45. }
  46. }()
  47. if err := common.Connections.Add(c.connection); err != nil {
  48. logger.Info(logSender, "", "unable to add SCP connection: %v", err)
  49. return err
  50. }
  51. defer common.Connections.Remove(c.connection.GetID())
  52. destPath := c.getDestPath()
  53. c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %s, dest path: %q",
  54. c.args, c.connection.User.Username, destPath)
  55. if c.hasFlag("t") {
  56. // -t means "to", so upload
  57. err = c.sendConfirmationMessage()
  58. if err != nil {
  59. return err
  60. }
  61. err = c.handleRecursiveUpload()
  62. if err != nil {
  63. return err
  64. }
  65. } else if c.hasFlag("f") {
  66. // -f means "from" so download
  67. err = c.readConfirmationMessage()
  68. if err != nil {
  69. return err
  70. }
  71. err = c.handleDownload(destPath)
  72. if err != nil {
  73. return err
  74. }
  75. } else {
  76. err = fmt.Errorf("scp command not supported, args: %v", c.args)
  77. c.connection.Log(logger.LevelDebug, "unsupported scp command, args: %v", c.args)
  78. }
  79. c.sendExitStatus(err)
  80. return err
  81. }
  82. func (c *scpCommand) handleRecursiveUpload() error {
  83. numDirs := 0
  84. destPath := c.getDestPath()
  85. for {
  86. fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
  87. if err != nil {
  88. c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err)
  89. c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath))
  90. return err
  91. }
  92. command, err := c.getNextUploadProtocolMessage()
  93. if err != nil {
  94. if errors.Is(err, io.EOF) {
  95. return nil
  96. }
  97. c.sendErrorMessage(fs, err)
  98. return err
  99. }
  100. if strings.HasPrefix(command, "E") {
  101. numDirs--
  102. c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
  103. if numDirs < 0 {
  104. err = errors.New("unacceptable end dir command")
  105. c.sendErrorMessage(nil, err)
  106. return err
  107. }
  108. // the destination dir is now the parent directory
  109. destPath = path.Join(destPath, "..")
  110. } else {
  111. sizeToRead, name, err := c.parseUploadMessage(fs, command)
  112. if err != nil {
  113. return err
  114. }
  115. if strings.HasPrefix(command, "D") {
  116. numDirs++
  117. destPath = path.Join(destPath, name)
  118. fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
  119. if err != nil {
  120. c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err)
  121. c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath))
  122. return err
  123. }
  124. err = c.handleCreateDir(fs, destPath)
  125. if err != nil {
  126. return err
  127. }
  128. c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %q", numDirs, destPath)
  129. } else if strings.HasPrefix(command, "C") {
  130. err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead)
  131. if err != nil {
  132. return err
  133. }
  134. }
  135. }
  136. err = c.sendConfirmationMessage()
  137. if err != nil {
  138. return err
  139. }
  140. }
  141. }
  142. func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error {
  143. c.connection.UpdateLastActivity()
  144. p, err := fs.ResolvePath(dirPath)
  145. if err != nil {
  146. c.connection.Log(logger.LevelError, "error creating dir: %q, invalid file path, err: %v", dirPath, err)
  147. c.sendErrorMessage(fs, err)
  148. return err
  149. }
  150. if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(dirPath)) {
  151. c.connection.Log(logger.LevelError, "error creating dir: %q, permission denied", dirPath)
  152. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  153. return common.ErrPermissionDenied
  154. }
  155. info, err := c.connection.DoStat(dirPath, 1, true)
  156. if err == nil && info.IsDir() {
  157. return nil
  158. }
  159. err = c.createDir(fs, p)
  160. if err != nil {
  161. return err
  162. }
  163. c.connection.Log(logger.LevelDebug, "created dir %q", dirPath)
  164. return nil
  165. }
  166. // we need to close the transfer if we have an error
  167. func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) error {
  168. err := c.sendConfirmationMessage()
  169. if err != nil {
  170. transfer.TransferError(err)
  171. transfer.Close()
  172. return err
  173. }
  174. if sizeToRead > 0 {
  175. // we could replace this method with io.CopyN implementing "Write" method in transfer struct
  176. remaining := sizeToRead
  177. buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
  178. for {
  179. n, err := c.connection.channel.Read(buf)
  180. if err != nil {
  181. transfer.TransferError(err)
  182. transfer.Close()
  183. c.sendErrorMessage(transfer.Fs, err)
  184. return err
  185. }
  186. _, err = transfer.WriteAt(buf[:n], sizeToRead-remaining)
  187. if err != nil {
  188. transfer.Close()
  189. c.sendErrorMessage(transfer.Fs, err)
  190. return err
  191. }
  192. remaining -= int64(n)
  193. if remaining <= 0 {
  194. break
  195. }
  196. if remaining < int64(len(buf)) {
  197. buf = make([]byte, remaining)
  198. }
  199. }
  200. }
  201. err = c.readConfirmationMessage()
  202. if err != nil {
  203. transfer.TransferError(err)
  204. transfer.Close()
  205. return err
  206. }
  207. err = transfer.Close()
  208. if err != nil {
  209. c.sendErrorMessage(transfer.Fs, err)
  210. return err
  211. }
  212. return nil
  213. }
  214. func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
  215. diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath)
  216. if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
  217. err := fmt.Errorf("denying file write due to quota limits")
  218. c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", filePath, err)
  219. c.sendErrorMessage(nil, err)
  220. return err
  221. }
  222. _, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath,
  223. fileSize, os.O_TRUNC)
  224. if err != nil {
  225. c.connection.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
  226. err = c.connection.GetPermissionDeniedError()
  227. c.sendErrorMessage(fs, err)
  228. return err
  229. }
  230. maxWriteSize, _ := c.connection.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported())
  231. file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.connection.GetCreateChecks(requestPath, isNewFile, false))
  232. if err != nil {
  233. c.connection.Log(logger.LevelError, "error creating file %q: %v", resolvedPath, err)
  234. c.sendErrorMessage(fs, err)
  235. return err
  236. }
  237. initialSize := int64(0)
  238. truncatedSize := int64(0) // bytes truncated and not included in quota
  239. if !isNewFile {
  240. if vfs.HasTruncateSupport(fs) {
  241. vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
  242. if err == nil {
  243. dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
  244. if vfolder.IsIncludedInUserQuota() {
  245. dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
  246. }
  247. } else {
  248. dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
  249. }
  250. } else {
  251. initialSize = fileSize
  252. truncatedSize = initialSize
  253. }
  254. if maxWriteSize > 0 {
  255. maxWriteSize += fileSize
  256. }
  257. }
  258. vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
  259. baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
  260. common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota)
  261. t := newTransfer(baseTransfer, w, nil, nil)
  262. return c.getUploadFileData(sizeToRead, t)
  263. }
  264. func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
  265. c.connection.UpdateLastActivity()
  266. fs, p, err := c.connection.GetFsAndResolvedPath(uploadFilePath)
  267. if err != nil {
  268. c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", uploadFilePath, err)
  269. c.sendErrorMessage(nil, err)
  270. return err
  271. }
  272. if ok, _ := c.connection.User.IsFileAllowed(uploadFilePath); !ok {
  273. c.connection.Log(logger.LevelWarn, "writing file %q is not allowed", uploadFilePath)
  274. c.sendErrorMessage(fs, c.connection.GetPermissionDeniedError())
  275. return common.ErrPermissionDenied
  276. }
  277. filePath := p
  278. if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
  279. filePath = fs.GetAtomicUploadPath(p)
  280. }
  281. stat, statErr := fs.Lstat(p)
  282. if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
  283. if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) {
  284. c.connection.Log(logger.LevelWarn, "cannot upload file: %q, permission denied", uploadFilePath)
  285. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  286. return common.ErrPermissionDenied
  287. }
  288. return c.handleUploadFile(fs, p, filePath, sizeToRead, true, 0, uploadFilePath)
  289. }
  290. if statErr != nil {
  291. c.connection.Log(logger.LevelError, "error performing file stat %q: %v", p, statErr)
  292. c.sendErrorMessage(fs, statErr)
  293. return statErr
  294. }
  295. if stat.IsDir() {
  296. c.connection.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p)
  297. err = fmt.Errorf("attempted to open a directory for writing: %q", p)
  298. c.sendErrorMessage(fs, err)
  299. return err
  300. }
  301. if !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath) {
  302. c.connection.Log(logger.LevelWarn, "cannot overwrite file: %q, permission denied", uploadFilePath)
  303. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  304. return common.ErrPermissionDenied
  305. }
  306. if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
  307. _, _, err = fs.Rename(p, filePath)
  308. if err != nil {
  309. c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %v",
  310. p, filePath, err)
  311. c.sendErrorMessage(fs, err)
  312. return err
  313. }
  314. }
  315. return c.handleUploadFile(fs, p, filePath, sizeToRead, false, stat.Size(), uploadFilePath)
  316. }
  317. func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error {
  318. var err error
  319. if c.sendFileTime() {
  320. modTime := stat.ModTime().UnixNano() / 1000000000
  321. tCommand := fmt.Sprintf("T%d 0 %d 0\n", modTime, modTime)
  322. err = c.sendProtocolMessage(tCommand)
  323. if err != nil {
  324. return err
  325. }
  326. err = c.readConfirmationMessage()
  327. if err != nil {
  328. return err
  329. }
  330. }
  331. dirName := path.Base(virtualDirPath)
  332. if dirName == "/" || dirName == "." {
  333. dirName = c.connection.User.Username
  334. }
  335. fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName)
  336. err = c.sendProtocolMessage(fileMode)
  337. if err != nil {
  338. return err
  339. }
  340. err = c.readConfirmationMessage()
  341. return err
  342. }
  343. // We send first all the files in the root directory and then the directories.
  344. // For each directory we recursively call this method again
  345. func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error {
  346. var err error
  347. if c.isRecursive() {
  348. c.connection.Log(logger.LevelDebug, "recursive download, dir path %q virtual path %q", dirPath, virtualPath)
  349. err = c.sendDownloadProtocolMessages(virtualPath, stat)
  350. if err != nil {
  351. return err
  352. }
  353. // dirPath is a fs path, not a virtual path
  354. lister, err := fs.ReadDir(dirPath)
  355. if err != nil {
  356. c.sendErrorMessage(fs, err)
  357. return err
  358. }
  359. defer lister.Close()
  360. vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath)
  361. var dirs []string
  362. for {
  363. files, err := lister.Next(vfs.ListerBatchSize)
  364. finished := errors.Is(err, io.EOF)
  365. if err != nil && !finished {
  366. c.sendErrorMessage(fs, err)
  367. return err
  368. }
  369. files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath))
  370. if len(vdirs) > 0 {
  371. files = append(files, vdirs...)
  372. vdirs = nil
  373. }
  374. for _, file := range files {
  375. filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
  376. if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
  377. err = c.handleDownload(filePath)
  378. if err != nil {
  379. c.sendErrorMessage(fs, err)
  380. return err
  381. }
  382. } else if file.IsDir() {
  383. dirs = append(dirs, filePath)
  384. }
  385. }
  386. if finished {
  387. break
  388. }
  389. }
  390. lister.Close()
  391. return c.downloadDirs(fs, dirs)
  392. }
  393. err = errors.New("unable to send directory for non recursive copy")
  394. c.sendErrorMessage(nil, err)
  395. return err
  396. }
  397. func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error {
  398. for _, dir := range dirs {
  399. if err := c.handleDownload(dir); err != nil {
  400. c.sendErrorMessage(fs, err)
  401. return err
  402. }
  403. }
  404. if err := c.sendProtocolMessage("E\n"); err != nil {
  405. return err
  406. }
  407. return c.readConfirmationMessage()
  408. }
  409. func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {
  410. var err error
  411. if c.sendFileTime() {
  412. modTime := stat.ModTime().UnixNano() / 1000000000
  413. tCommand := fmt.Sprintf("T%d 0 %d 0\n", modTime, modTime)
  414. err = c.sendProtocolMessage(tCommand)
  415. if err != nil {
  416. return err
  417. }
  418. err = c.readConfirmationMessage()
  419. if err != nil {
  420. return err
  421. }
  422. }
  423. if vfs.IsCryptOsFs(fs) {
  424. stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat)
  425. }
  426. fileSize := stat.Size()
  427. readed := int64(0)
  428. fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath))
  429. err = c.sendProtocolMessage(fileMode)
  430. if err != nil {
  431. return err
  432. }
  433. err = c.readConfirmationMessage()
  434. if err != nil {
  435. return err
  436. }
  437. // we could replace this method with io.CopyN implementing "Read" method in transfer struct
  438. buf := make([]byte, 32768)
  439. var n int
  440. for {
  441. n, err = transfer.ReadAt(buf, readed)
  442. if err == nil || err == io.EOF {
  443. if n > 0 {
  444. _, err = c.connection.channel.Write(buf[:n])
  445. }
  446. }
  447. readed += int64(n)
  448. if err != nil {
  449. break
  450. }
  451. }
  452. if err != io.EOF {
  453. c.sendErrorMessage(fs, err)
  454. return err
  455. }
  456. err = c.sendConfirmationMessage()
  457. if err != nil {
  458. return err
  459. }
  460. err = c.readConfirmationMessage()
  461. return err
  462. }
  463. func (c *scpCommand) handleDownload(filePath string) error {
  464. c.connection.UpdateLastActivity()
  465. transferQuota := c.connection.GetTransferQuota()
  466. if !transferQuota.HasDownloadSpace() {
  467. c.connection.Log(logger.LevelInfo, "denying file read due to quota limits")
  468. c.sendErrorMessage(nil, c.connection.GetReadQuotaExceededError())
  469. return c.connection.GetReadQuotaExceededError()
  470. }
  471. var err error
  472. fs, p, err := c.connection.GetFsAndResolvedPath(filePath)
  473. if err != nil {
  474. c.connection.Log(logger.LevelError, "error downloading file %q: %+v", filePath, err)
  475. c.sendErrorMessage(nil, fmt.Errorf("unable to download file %q: %w", filePath, err))
  476. return err
  477. }
  478. var stat os.FileInfo
  479. if stat, err = fs.Stat(p); err != nil {
  480. c.connection.Log(logger.LevelError, "error downloading file: %q->%q, err: %v", filePath, p, err)
  481. c.sendErrorMessage(fs, err)
  482. return err
  483. }
  484. if stat.IsDir() {
  485. if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) {
  486. c.connection.Log(logger.LevelWarn, "error downloading dir: %q, permission denied", filePath)
  487. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  488. return common.ErrPermissionDenied
  489. }
  490. err = c.handleRecursiveDownload(fs, p, filePath, stat)
  491. return err
  492. }
  493. if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) {
  494. c.connection.Log(logger.LevelWarn, "error downloading dir: %q, permission denied", filePath)
  495. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  496. return common.ErrPermissionDenied
  497. }
  498. if ok, policy := c.connection.User.IsFileAllowed(filePath); !ok {
  499. c.connection.Log(logger.LevelWarn, "reading file %q is not allowed", filePath)
  500. c.sendErrorMessage(fs, c.connection.GetErrorForDeniedFile(policy))
  501. return common.ErrPermissionDenied
  502. }
  503. if _, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreDownload, p, filePath, 0, 0); err != nil {
  504. c.connection.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", filePath, err)
  505. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  506. return common.ErrPermissionDenied
  507. }
  508. file, r, cancelFn, err := fs.Open(p, 0)
  509. if err != nil {
  510. c.connection.Log(logger.LevelError, "could not open file %q for reading: %v", p, err)
  511. c.sendErrorMessage(fs, err)
  512. return err
  513. }
  514. baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
  515. common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota)
  516. t := newTransfer(baseTransfer, nil, r, nil)
  517. err = c.sendDownloadFileData(fs, p, stat, t)
  518. // we need to call Close anyway and return close error if any and
  519. // if we have no previous error
  520. if err == nil {
  521. err = t.Close()
  522. } else {
  523. t.TransferError(err)
  524. t.Close()
  525. }
  526. return err
  527. }
  528. func (c *scpCommand) sendFileTime() bool {
  529. return c.hasFlag("p")
  530. }
  531. func (c *scpCommand) isRecursive() bool {
  532. return c.hasFlag("r")
  533. }
  534. func (c *scpCommand) hasFlag(flag string) bool {
  535. for idx := 0; idx < len(c.args)-1; idx++ {
  536. arg := c.args[idx]
  537. if !strings.HasPrefix(arg, "--") && strings.HasPrefix(arg, "-") && strings.Contains(arg, flag) {
  538. return true
  539. }
  540. }
  541. return false
  542. }
  543. // read the SCP confirmation message and the optional text message
  544. // the channel will be closed on errors
  545. func (c *scpCommand) readConfirmationMessage() error {
  546. var msg strings.Builder
  547. buf := make([]byte, 1)
  548. n, err := c.connection.channel.Read(buf)
  549. if err != nil {
  550. c.connection.channel.Close()
  551. return err
  552. }
  553. if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
  554. isError := buf[0] == errMsg[0]
  555. for {
  556. n, err = c.connection.channel.Read(buf)
  557. readed := buf[:n]
  558. if err != nil || (n == 1 && readed[0] == newLine[0]) {
  559. break
  560. }
  561. if n > 0 {
  562. msg.Write(readed)
  563. }
  564. }
  565. c.connection.Log(logger.LevelInfo, "scp error message received: %v is error: %v", msg.String(), isError)
  566. err = fmt.Errorf("%v", msg.String())
  567. c.connection.channel.Close()
  568. }
  569. return err
  570. }
  571. // protool messages are newline terminated
  572. func (c *scpCommand) readProtocolMessage() (string, error) {
  573. var command strings.Builder
  574. var err error
  575. buf := make([]byte, 1)
  576. for {
  577. var n int
  578. n, err = c.connection.channel.Read(buf)
  579. if err != nil {
  580. break
  581. }
  582. if n > 0 {
  583. readed := buf[:n]
  584. if n == 1 && readed[0] == newLine[0] {
  585. break
  586. }
  587. command.Write(readed)
  588. }
  589. }
  590. if err != nil && !errors.Is(err, io.EOF) {
  591. c.connection.channel.Close()
  592. }
  593. return command.String(), err
  594. }
  595. // sendErrorMessage sends an error message and close the channel
  596. // we don't check write errors here, we have to close the channel anyway
  597. //
  598. //nolint:errcheck
  599. func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) {
  600. c.connection.channel.Write(errMsg)
  601. if fs != nil {
  602. c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error()))
  603. } else {
  604. c.connection.channel.Write([]byte(err.Error()))
  605. }
  606. c.connection.channel.Write(newLine)
  607. c.connection.channel.Close()
  608. }
  609. // send scp confirmation message and close the channel if an error happen
  610. func (c *scpCommand) sendConfirmationMessage() error {
  611. _, err := c.connection.channel.Write(okMsg)
  612. if err != nil {
  613. c.connection.channel.Close()
  614. }
  615. return err
  616. }
  617. // sends a protocol message and close the channel on error
  618. func (c *scpCommand) sendProtocolMessage(message string) error {
  619. _, err := c.connection.channel.Write([]byte(message))
  620. if err != nil {
  621. c.connection.Log(logger.LevelError, "error sending protocol message: %v, err: %v", message, err)
  622. c.connection.channel.Close()
  623. }
  624. return err
  625. }
  626. // get the next upload protocol message ignoring T command if any
  627. func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
  628. var command string
  629. var err error
  630. for {
  631. command, err = c.readProtocolMessage()
  632. if err != nil {
  633. return command, err
  634. }
  635. if strings.HasPrefix(command, "T") {
  636. err = c.sendConfirmationMessage()
  637. if err != nil {
  638. return command, err
  639. }
  640. } else {
  641. break
  642. }
  643. }
  644. return command, err
  645. }
  646. func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error {
  647. err := fs.Mkdir(dirPath)
  648. if err != nil {
  649. c.connection.Log(logger.LevelError, "error creating dir %q: %v", dirPath, err)
  650. c.sendErrorMessage(fs, err)
  651. return err
  652. }
  653. vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
  654. return err
  655. }
  656. // parse protocol messages such as:
  657. // D0755 0 testdir
  658. // or:
  659. // C0644 6 testfile
  660. // and returns file size and file/directory name
  661. func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) {
  662. var size int64
  663. var name string
  664. var err error
  665. if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") {
  666. err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
  667. command, c.args, c.connection.User.Username)
  668. c.connection.Log(logger.LevelError, "error: %v", err)
  669. c.sendErrorMessage(fs, err)
  670. return size, name, err
  671. }
  672. parts := strings.SplitN(command, " ", 3)
  673. if len(parts) == 3 {
  674. size, err = strconv.ParseInt(parts[1], 10, 64)
  675. if err != nil {
  676. c.connection.Log(logger.LevelError, "error getting size from upload message: %v", err)
  677. c.sendErrorMessage(fs, err)
  678. return size, name, err
  679. }
  680. name = parts[2]
  681. if name == "" {
  682. err = fmt.Errorf("error getting name from upload message, cannot be empty")
  683. c.connection.Log(logger.LevelError, "error: %v", err)
  684. c.sendErrorMessage(fs, err)
  685. return size, name, err
  686. }
  687. } else {
  688. err = fmt.Errorf("unable to split upload message: %q", command)
  689. c.connection.Log(logger.LevelError, "error: %v", err)
  690. c.sendErrorMessage(fs, err)
  691. return size, name, err
  692. }
  693. return size, name, err
  694. }
  695. func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string {
  696. if !c.isRecursive() {
  697. // if the upload is not recursive and the destination path does not end with "/"
  698. // then scpDestPath is the wanted filename, for example:
  699. // scp fileName.txt [email protected]:/newFileName.txt
  700. // or
  701. // scp fileName.txt [email protected]:/fileName.txt
  702. if !strings.HasSuffix(scpDestPath, "/") {
  703. // but if scpDestPath is an existing directory then we put the uploaded file
  704. // inside that directory this is as scp command works, for example:
  705. // scp fileName.txt [email protected]:/existing_dir
  706. if p, err := fs.ResolvePath(scpDestPath); err == nil {
  707. if stat, err := fs.Stat(p); err == nil {
  708. if stat.IsDir() {
  709. return path.Join(scpDestPath, fileName)
  710. }
  711. }
  712. }
  713. return scpDestPath
  714. }
  715. }
  716. // if the upload is recursive or scpDestPath has the "/" suffix then the destination
  717. // file is relative to scpDestPath
  718. return path.Join(scpDestPath, fileName)
  719. }
  720. func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
  721. var defaultMode string
  722. if isDir {
  723. defaultMode = "0755"
  724. } else {
  725. defaultMode = "0644"
  726. }
  727. if fileMode == 0 {
  728. return defaultMode
  729. }
  730. modeString := []byte(fileMode.String())
  731. nullPerm := []byte("-")
  732. u := 0
  733. g := 0
  734. o := 0
  735. s := 0
  736. lastChar := len(modeString) - 1
  737. if fileMode&os.ModeSticky != 0 {
  738. s++
  739. }
  740. if fileMode&os.ModeSetuid != 0 {
  741. s += 2
  742. }
  743. if fileMode&os.ModeSetgid != 0 {
  744. s += 4
  745. }
  746. if modeString[lastChar-8] != nullPerm[0] {
  747. u += 4
  748. }
  749. if modeString[lastChar-7] != nullPerm[0] {
  750. u += 2
  751. }
  752. if modeString[lastChar-6] != nullPerm[0] {
  753. u++
  754. }
  755. if modeString[lastChar-5] != nullPerm[0] {
  756. g += 4
  757. }
  758. if modeString[lastChar-4] != nullPerm[0] {
  759. g += 2
  760. }
  761. if modeString[lastChar-3] != nullPerm[0] {
  762. g++
  763. }
  764. if modeString[lastChar-2] != nullPerm[0] {
  765. o += 4
  766. }
  767. if modeString[lastChar-1] != nullPerm[0] {
  768. o += 2
  769. }
  770. if modeString[lastChar] != nullPerm[0] {
  771. o++
  772. }
  773. return fmt.Sprintf("%v%v%v%v", s, u, g, o)
  774. }