scp.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. package sftpd
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "math"
  7. "os"
  8. "path"
  9. "path/filepath"
  10. "runtime/debug"
  11. "strconv"
  12. "strings"
  13. "github.com/drakkan/sftpgo/v2/common"
  14. "github.com/drakkan/sftpgo/v2/dataprovider"
  15. "github.com/drakkan/sftpgo/v2/logger"
  16. "github.com/drakkan/sftpgo/v2/util"
  17. "github.com/drakkan/sftpgo/v2/vfs"
  18. )
  // Single byte SCP response codes and the protocol message terminator.
var (
	// okMsg acknowledges a protocol message or a completed data transfer
	okMsg = []byte("\x00")
	// warnMsg must be followed by an optional message and a newline
	warnMsg = []byte("\x01")
	// errMsg must be followed by an optional message and a newline
	errMsg = []byte("\x02")
	// newLine terminates protocol and error messages
	newLine = []byte("\n")
)
  25. type scpCommand struct {
  26. sshCommand
  27. }
  // handle serves one SCP command. It registers the connection for its whole
// lifetime, dispatches on the command type ("-t" means upload, "-f" means
// download) and reports the resulting exit status. Upload/download failures
// are returned without sending an exit status. A panic is recovered, logged and
// turned into common.ErrGenericFailure.
func (c *scpCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle scp command: %#v stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()
	if addErr := common.Connections.Add(c.connection); addErr != nil {
		logger.Info(logSender, "", "unable to add SCP connection: %v", addErr)
		return addErr
	}
	defer common.Connections.Remove(c.connection.GetID())

	destPath := c.getDestPath()
	commandType := c.getCommandType()
	c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",
		c.args, c.connection.User.Username, commandType, destPath)

	switch commandType {
	case "-t":
		// "to": the client uploads, we acknowledge first and then receive
		if err = c.sendConfirmationMessage(); err != nil {
			return err
		}
		if err = c.handleRecursiveUpload(); err != nil {
			return err
		}
	case "-f":
		// "from": the client downloads, it must acknowledge first
		if err = c.readConfirmationMessage(); err != nil {
			return err
		}
		if err = c.handleDownload(destPath); err != nil {
			return err
		}
	default:
		err = fmt.Errorf("scp command not supported, args: %v", c.args)
		c.connection.Log(logger.LevelDebug, "unsupported scp command, args: %v", c.args)
	}
	c.sendExitStatus(err)
	return err
}
  71. func (c *scpCommand) handleRecursiveUpload() error {
  72. numDirs := 0
  73. destPath := c.getDestPath()
  74. for {
  75. fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
  76. if err != nil {
  77. c.connection.Log(logger.LevelError, "error uploading file %#v: %+v", destPath, err)
  78. c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath))
  79. return err
  80. }
  81. command, err := c.getNextUploadProtocolMessage()
  82. if err != nil {
  83. if errors.Is(err, io.EOF) {
  84. return nil
  85. }
  86. c.sendErrorMessage(fs, err)
  87. return err
  88. }
  89. if strings.HasPrefix(command, "E") {
  90. numDirs--
  91. c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
  92. if numDirs < 0 {
  93. err = errors.New("unacceptable end dir command")
  94. c.sendErrorMessage(nil, err)
  95. return err
  96. }
  97. // the destination dir is now the parent directory
  98. destPath = path.Join(destPath, "..")
  99. } else {
  100. sizeToRead, name, err := c.parseUploadMessage(fs, command)
  101. if err != nil {
  102. return err
  103. }
  104. if strings.HasPrefix(command, "D") {
  105. numDirs++
  106. destPath = path.Join(destPath, name)
  107. fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
  108. if err != nil {
  109. c.connection.Log(logger.LevelError, "error uploading file %#v: %+v", destPath, err)
  110. c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath))
  111. return err
  112. }
  113. err = c.handleCreateDir(fs, destPath)
  114. if err != nil {
  115. return err
  116. }
  117. c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %#v", numDirs, destPath)
  118. } else if strings.HasPrefix(command, "C") {
  119. err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead)
  120. if err != nil {
  121. return err
  122. }
  123. }
  124. }
  125. err = c.sendConfirmationMessage()
  126. if err != nil {
  127. return err
  128. }
  129. }
  130. }
  // handleCreateDir makes sure the virtual directory dirPath exists on fs.
// Creation requires the create-dirs permission on the parent directory. That
// check runs before testing whether dirPath already exists, and an existing
// directory is then left untouched. Failures are reported to the client.
func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error {
	c.connection.UpdateLastActivity()

	resolved, err := fs.ResolvePath(dirPath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error creating dir: %#v, invalid file path, err: %v", dirPath, err)
		c.sendErrorMessage(fs, err)
		return err
	}
	if parent := path.Dir(dirPath); !c.connection.User.HasPerm(dataprovider.PermCreateDirs, parent) {
		c.connection.Log(logger.LevelError, "error creating dir: %#v, permission denied", dirPath)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}
	if info, statErr := c.connection.DoStat(dirPath, 1, true); statErr == nil && info.IsDir() {
		// already there, nothing to do
		return nil
	}
	if err = c.createDir(fs, resolved); err != nil {
		return err
	}
	c.connection.Log(logger.LevelDebug, "created dir %#v", dirPath)
	return nil
}
  // getUploadFileData acknowledges the "C" message, copies exactly sizeToRead
// bytes from the channel into transfer and waits for the client's confirmation.
// The transfer is closed on every path:
//   - on success, a close error is reported to the client and returned;
//   - on failure, the error is recorded with TransferError first. The exception
//     is a failed WriteAt, which is not passed to TransferError here.
//     NOTE(review): presumably WriteAt records it itself, confirm.
func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) error {
	if err := c.sendConfirmationMessage(); err != nil {
		transfer.TransferError(err)
		transfer.Close()
		return err
	}
	if sizeToRead > 0 {
		// we could replace this loop with io.CopyN implementing "Write" method in transfer struct
		buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
		offset := int64(0)
		for offset < sizeToRead {
			// never read past the announced size: the client's confirmation byte follows the data
			chunk := buf
			if remaining := sizeToRead - offset; remaining < int64(len(chunk)) {
				chunk = chunk[:remaining]
			}
			n, readErr := c.connection.channel.Read(chunk)
			// per the io.Reader contract, bytes returned along with an error are valid data
			if n > 0 {
				if _, err := transfer.WriteAt(chunk[:n], offset); err != nil {
					c.sendErrorMessage(transfer.Fs, err)
					transfer.Close()
					return err
				}
				offset += int64(n)
			}
			if readErr != nil {
				c.sendErrorMessage(transfer.Fs, readErr)
				transfer.TransferError(readErr)
				transfer.Close()
				return readErr
			}
		}
	}
	if err := c.readConfirmationMessage(); err != nil {
		transfer.TransferError(err)
		transfer.Close()
		return err
	}
	if err := transfer.Close(); err != nil {
		c.sendErrorMessage(transfer.Fs, err)
		return err
	}
	return nil
}
  // handleUploadFile checks quotas and the pre-upload hook, creates filePath on fs
// and receives sizeToRead bytes into it through a new upload transfer.
//
// resolvedPath is the final fs path of the file. filePath is the path actually
// written; it differs from resolvedPath for atomic uploads. requestPath is the
// virtual path requested by the client. For an existing file (isNewFile false),
// fileSize is its current size. Errors are reported to the client before
// being returned.
func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
	// For atomic overwrites handleUpload has already renamed the existing file
	// to filePath. If the upload is refused before the new file is created,
	// move it back so the user does not lose access to it.
	restoreExistingFile := func() {
		if isNewFile || filePath == resolvedPath {
			return
		}
		if err := fs.Rename(filePath, resolvedPath); err != nil {
			c.connection.Log(logger.LevelError, "unable to restore file %#v from %#v after a refused upload: %v",
				resolvedPath, filePath, err)
		}
	}

	diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		err := errors.New("denying file write due to quota limits")
		c.connection.Log(logger.LevelError, "error uploading file: %#v, err: %v", filePath, err)
		restoreExistingFile()
		c.sendErrorMessage(nil, err)
		return err
	}
	err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath,
		fileSize, os.O_TRUNC)
	if err != nil {
		c.connection.Log(logger.LevelDebug, "upload for file %#v denied by pre action: %v", requestPath, err)
		restoreExistingFile()
		err = c.connection.GetPermissionDeniedError()
		c.sendErrorMessage(fs, err)
		return err
	}
	maxWriteSize, _ := c.connection.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported())
	file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
	if err != nil {
		c.connection.Log(logger.LevelError, "error creating file %#v: %v", resolvedPath, err)
		restoreExistingFile()
		c.sendErrorMessage(fs, err)
		return err
	}
	initialSize := int64(0)
	truncatedSize := int64(0) // bytes truncated and not included in quota
	if !isNewFile {
		if vfs.IsLocalOrSFTPFs(fs) {
			// The existing file is being truncated: remove its size from the quota
			// now. It goes against the virtual folder (and the user too, when the
			// folder is included in the user quota), or against the user otherwise.
			vfolder, vfolderErr := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
			if vfolderErr == nil {
				dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
				if vfolder.IsIncludedInUserQuota() {
					dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
				}
			} else {
				dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
			}
		} else {
			// other filesystems: the size is passed to the transfer as truncated size
			initialSize = fileSize
			truncatedSize = initialSize
		}
		if maxWriteSize > 0 {
			// the size of the file being replaced is added on top of the computed limit
			maxWriteSize += fileSize
		}
	}
	vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota)
	t := newTransfer(baseTransfer, w, nil, nil)
	return c.getUploadFileData(sizeToRead, t)
}
  // handleUpload validates an upload to the virtual path uploadFilePath (file
// filters, upload/overwrite permissions, not a directory) and hands the write
// off to handleUploadFile. When atomic uploads are enabled and supported, the
// data is written to a temporary upload path. An existing file is renamed to
// that path first. Errors are reported to the client before being returned.
func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
	c.connection.UpdateLastActivity()

	fs, resolved, err := c.connection.GetFsAndResolvedPath(uploadFilePath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error uploading file: %#v, err: %v", uploadFilePath, err)
		c.sendErrorMessage(nil, err)
		return err
	}
	if allowed, _ := c.connection.User.IsFileAllowed(uploadFilePath); !allowed {
		c.connection.Log(logger.LevelWarn, "writing file %#v is not allowed", uploadFilePath)
		c.sendErrorMessage(fs, c.connection.GetPermissionDeniedError())
		return common.ErrPermissionDenied
	}

	useAtomicUpload := common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported()
	writePath := resolved
	if useAtomicUpload {
		writePath = fs.GetAtomicUploadPath(resolved)
	}

	stat, statErr := fs.Lstat(resolved)
	// a missing file, or a symlink that will be replaced, is handled as a new file
	if isSymlink := statErr == nil && stat.Mode()&os.ModeSymlink != 0; isSymlink || fs.IsNotExist(statErr) {
		if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) {
			c.connection.Log(logger.LevelWarn, "cannot upload file: %#v, permission denied", uploadFilePath)
			c.sendErrorMessage(fs, common.ErrPermissionDenied)
			return common.ErrPermissionDenied
		}
		return c.handleUploadFile(fs, resolved, writePath, sizeToRead, true, 0, uploadFilePath)
	}

	switch {
	case statErr != nil:
		c.connection.Log(logger.LevelError, "error performing file stat %#v: %v", resolved, statErr)
		c.sendErrorMessage(fs, statErr)
		return statErr
	case stat.IsDir():
		c.connection.Log(logger.LevelError, "attempted to open a directory for writing to: %#v", resolved)
		dirErr := fmt.Errorf("attempted to open a directory for writing: %#v", resolved)
		c.sendErrorMessage(fs, dirErr)
		return dirErr
	case !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath):
		c.connection.Log(logger.LevelWarn, "cannot overwrite file: %#v, permission denied", uploadFilePath)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}

	if useAtomicUpload {
		// move the existing file to the temporary path that will be written
		if err = fs.Rename(resolved, writePath); err != nil {
			c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %v",
				resolved, writePath, err)
			c.sendErrorMessage(fs, err)
			return err
		}
	}
	return c.handleUploadFile(fs, resolved, writePath, sizeToRead, false, stat.Size(), uploadFilePath)
}
  306. func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error {
  307. var err error
  308. if c.sendFileTime() {
  309. modTime := stat.ModTime().UnixNano() / 1000000000
  310. tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
  311. err = c.sendProtocolMessage(tCommand)
  312. if err != nil {
  313. return err
  314. }
  315. err = c.readConfirmationMessage()
  316. if err != nil {
  317. return err
  318. }
  319. }
  320. dirName := path.Base(virtualDirPath)
  321. if dirName == "/" || dirName == "." {
  322. dirName = c.connection.User.Username
  323. }
  324. fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName)
  325. err = c.sendProtocolMessage(fileMode)
  326. if err != nil {
  327. return err
  328. }
  329. err = c.readConfirmationMessage()
  330. return err
  331. }
  // handleRecursiveDownload sends the directory dirPath (virtual path virtualPath)
// to the client and requires the -r flag. The steps are:
//  1. send the "D" message;
//  2. send every regular file or symlink in the directory;
//  3. send every subdirectory (each through handleDownload, which comes back
//     here for directories);
//  4. close the directory with "E".
// Entries are filtered by the user's listing filters first.
func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error {
	if !c.isRecursive() {
		err := fmt.Errorf("unable to send directory for non recursive copy")
		c.sendErrorMessage(nil, err)
		return err
	}
	c.connection.Log(logger.LevelDebug, "recursive download, dir path %#v virtual path %#v", dirPath, virtualPath)
	if err := c.sendDownloadProtocolMessages(virtualPath, stat); err != nil {
		return err
	}
	entries, err := fs.ReadDir(dirPath)
	if err != nil {
		c.sendErrorMessage(fs, err)
		return err
	}
	entries = c.connection.User.FilterListDir(entries, fs.GetRelativePath(dirPath))

	var subDirs []string
	for _, entry := range entries {
		entryPath := fs.GetRelativePath(fs.Join(dirPath, entry.Name()))
		switch mode := entry.Mode(); {
		case mode.IsRegular(), mode&os.ModeSymlink != 0:
			if err = c.handleDownload(entryPath); err != nil {
				c.sendErrorMessage(fs, err)
				return err
			}
		case entry.IsDir():
			subDirs = append(subDirs, entryPath)
		}
	}
	for _, subDir := range subDirs {
		if err = c.handleDownload(subDir); err != nil {
			c.sendErrorMessage(fs, err)
			return err
		}
	}
	if err = c.sendProtocolMessage("E\n"); err != nil {
		return err
	}
	return c.readConfirmationMessage()
}
  384. func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {
  385. var err error
  386. if c.sendFileTime() {
  387. modTime := stat.ModTime().UnixNano() / 1000000000
  388. tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
  389. err = c.sendProtocolMessage(tCommand)
  390. if err != nil {
  391. return err
  392. }
  393. err = c.readConfirmationMessage()
  394. if err != nil {
  395. return err
  396. }
  397. }
  398. if vfs.IsCryptOsFs(fs) {
  399. stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat)
  400. }
  401. fileSize := stat.Size()
  402. readed := int64(0)
  403. fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath))
  404. err = c.sendProtocolMessage(fileMode)
  405. if err != nil {
  406. return err
  407. }
  408. err = c.readConfirmationMessage()
  409. if err != nil {
  410. return err
  411. }
  412. // we could replace this method with io.CopyN implementing "Read" method in transfer struct
  413. buf := make([]byte, 32768)
  414. var n int
  415. for {
  416. n, err = transfer.ReadAt(buf, readed)
  417. if err == nil || err == io.EOF {
  418. if n > 0 {
  419. _, err = c.connection.channel.Write(buf[:n])
  420. }
  421. }
  422. readed += int64(n)
  423. if err != nil {
  424. break
  425. }
  426. }
  427. if err != io.EOF {
  428. c.sendErrorMessage(fs, err)
  429. return err
  430. }
  431. err = c.sendConfirmationMessage()
  432. if err != nil {
  433. return err
  434. }
  435. err = c.readConfirmationMessage()
  436. return err
  437. }
  438. func (c *scpCommand) handleDownload(filePath string) error {
  439. c.connection.UpdateLastActivity()
  440. transferQuota := c.connection.GetTransferQuota()
  441. if !transferQuota.HasDownloadSpace() {
  442. c.connection.Log(logger.LevelInfo, "denying file read due to quota limits")
  443. c.sendErrorMessage(nil, c.connection.GetReadQuotaExceededError())
  444. return c.connection.GetReadQuotaExceededError()
  445. }
  446. var err error
  447. fs, err := c.connection.User.GetFilesystemForPath(filePath, c.connection.ID)
  448. if err != nil {
  449. c.connection.Log(logger.LevelError, "error downloading file %#v: %+v", filePath, err)
  450. c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", filePath))
  451. return err
  452. }
  453. p, err := fs.ResolvePath(filePath)
  454. if err != nil {
  455. err := fmt.Errorf("invalid file path %#v", filePath)
  456. c.connection.Log(logger.LevelError, "error downloading file: %#v, invalid file path", filePath)
  457. c.sendErrorMessage(fs, err)
  458. return err
  459. }
  460. var stat os.FileInfo
  461. if stat, err = fs.Stat(p); err != nil {
  462. c.connection.Log(logger.LevelError, "error downloading file: %#v->%#v, err: %v", filePath, p, err)
  463. c.sendErrorMessage(fs, err)
  464. return err
  465. }
  466. if stat.IsDir() {
  467. if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) {
  468. c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath)
  469. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  470. return common.ErrPermissionDenied
  471. }
  472. err = c.handleRecursiveDownload(fs, p, filePath, stat)
  473. return err
  474. }
  475. if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) {
  476. c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath)
  477. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  478. return common.ErrPermissionDenied
  479. }
  480. if ok, policy := c.connection.User.IsFileAllowed(filePath); !ok {
  481. c.connection.Log(logger.LevelWarn, "reading file %#v is not allowed", filePath)
  482. c.sendErrorMessage(fs, c.connection.GetErrorForDeniedFile(policy))
  483. return common.ErrPermissionDenied
  484. }
  485. if err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreDownload, p, filePath, 0, 0); err != nil {
  486. c.connection.Log(logger.LevelDebug, "download for file %#v denied by pre action: %v", filePath, err)
  487. c.sendErrorMessage(fs, common.ErrPermissionDenied)
  488. return common.ErrPermissionDenied
  489. }
  490. file, r, cancelFn, err := fs.Open(p, 0)
  491. if err != nil {
  492. c.connection.Log(logger.LevelError, "could not open file %#v for reading: %v", p, err)
  493. c.sendErrorMessage(fs, err)
  494. return err
  495. }
  496. baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
  497. common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota)
  498. t := newTransfer(baseTransfer, nil, r, nil)
  499. err = c.sendDownloadFileData(fs, p, stat, t)
  500. // we need to call Close anyway and return close error if any and
  501. // if we have no previous error
  502. if err == nil {
  503. err = t.Close()
  504. } else {
  505. t.TransferError(err)
  506. t.Close()
  507. }
  508. return err
  509. }
  510. func (c *scpCommand) getCommandType() string {
  511. return c.args[len(c.args)-2]
  512. }
  513. func (c *scpCommand) sendFileTime() bool {
  514. return util.IsStringInSlice("-p", c.args)
  515. }
  516. func (c *scpCommand) isRecursive() bool {
  517. return util.IsStringInSlice("-r", c.args)
  518. }
  // readConfirmationMessage reads the peer's one byte SCP response.
//   - 0x00 means OK.
//   - 0x01 (warning) and 0x02 (error) are followed by an optional message
//     terminated by a newline.
//   - Any other byte is treated, as OpenSSH does, as the first character of an
//     error message.
// For anything but OK the channel is closed and the (possibly empty) message is
// returned as the error. The channel is also closed when the first read fails.
func (c *scpCommand) readConfirmationMessage() error {
	// cap on the text read after a non-OK response, so a peer that never sends
	// the terminating newline cannot make us buffer an unbounded message
	const maxMessageLen = 2048

	buf := make([]byte, 1)
	n, err := c.connection.channel.Read(buf)
	if err != nil {
		c.connection.channel.Close()
		return err
	}
	if n == 0 || buf[0] == okMsg[0] {
		return nil
	}
	code := buf[0]
	isError := code != warnMsg[0]
	var msg strings.Builder
	if code != warnMsg[0] && code != errMsg[0] {
		// unknown response code: keep it as part of the message
		msg.WriteByte(code)
	}
	for msg.Len() < maxMessageLen {
		n, err = c.connection.channel.Read(buf)
		if err != nil || (n == 1 && buf[0] == newLine[0]) {
			break
		}
		if n > 0 {
			msg.WriteByte(buf[0])
		}
	}
	c.connection.Log(logger.LevelInfo, "scp error message received: %v is error: %v", msg.String(), isError)
	c.connection.channel.Close()
	return errors.New(msg.String())
}
  // readProtocolMessage reads one newline-terminated protocol message and returns
// it without the newline. It reads one byte at a time, so nothing past the
// message is consumed. Messages longer than maxProtocolMessageLen are refused
// rather than buffered without limit. The channel is closed on any error other
// than io.EOF. The partial message read so far is returned together with the error.
func (c *scpCommand) readProtocolMessage() (string, error) {
	// generous bound: a message carries a mode, a size and a single file name
	const maxProtocolMessageLen = 16384

	var command strings.Builder
	var err error
	buf := make([]byte, 1)
	for {
		var n int
		n, err = c.connection.channel.Read(buf)
		if err != nil {
			break
		}
		if n == 0 {
			continue
		}
		if buf[0] == newLine[0] {
			break
		}
		if command.Len() >= maxProtocolMessageLen {
			err = fmt.Errorf("protocol message too long, more than %v bytes", maxProtocolMessageLen)
			break
		}
		command.WriteByte(buf[0])
	}
	if err != nil && !errors.Is(err, io.EOF) {
		c.connection.channel.Close()
	}
	return command.String(), err
}
  571. // send an error message and close the channel
  572. //nolint:errcheck // we don't check write errors here, we have to close the channel anyway
  573. func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) {
  574. c.connection.channel.Write(errMsg)
  575. if fs != nil {
  576. c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error()))
  577. } else {
  578. c.connection.channel.Write([]byte(err.Error()))
  579. }
  580. c.connection.channel.Write(newLine)
  581. c.connection.channel.Close()
  582. }
  583. // send scp confirmation message and close the channel if an error happen
  584. func (c *scpCommand) sendConfirmationMessage() error {
  585. _, err := c.connection.channel.Write(okMsg)
  586. if err != nil {
  587. c.connection.channel.Close()
  588. }
  589. return err
  590. }
  591. // sends a protocol message and close the channel on error
  592. func (c *scpCommand) sendProtocolMessage(message string) error {
  593. _, err := c.connection.channel.Write([]byte(message))
  594. if err != nil {
  595. c.connection.Log(logger.LevelError, "error sending protocol message: %v, err: %v", message, err)
  596. c.connection.channel.Close()
  597. }
  598. return err
  599. }
  600. // get the next upload protocol message ignoring T command if any
  601. func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
  602. var command string
  603. var err error
  604. for {
  605. command, err = c.readProtocolMessage()
  606. if err != nil {
  607. return command, err
  608. }
  609. if strings.HasPrefix(command, "T") {
  610. err = c.sendConfirmationMessage()
  611. if err != nil {
  612. return command, err
  613. }
  614. } else {
  615. break
  616. }
  617. }
  618. return command, err
  619. }
  620. func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error {
  621. err := fs.Mkdir(dirPath)
  622. if err != nil {
  623. c.connection.Log(logger.LevelError, "error creating dir %#v: %v", dirPath, err)
  624. c.sendErrorMessage(fs, err)
  625. return err
  626. }
  627. vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
  628. return err
  629. }
  630. // parse protocol messages such as:
  631. // D0755 0 testdir
  632. // or:
  633. // C0644 6 testfile
  634. // and returns file size and file/directory name
  635. func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) {
  636. var size int64
  637. var name string
  638. var err error
  639. if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") {
  640. err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
  641. command, c.args, c.connection.User.Username)
  642. c.connection.Log(logger.LevelError, "error: %v", err)
  643. c.sendErrorMessage(fs, err)
  644. return size, name, err
  645. }
  646. parts := strings.SplitN(command, " ", 3)
  647. if len(parts) == 3 {
  648. size, err = strconv.ParseInt(parts[1], 10, 64)
  649. if err != nil {
  650. c.connection.Log(logger.LevelError, "error getting size from upload message: %v", err)
  651. c.sendErrorMessage(fs, err)
  652. return size, name, err
  653. }
  654. name = parts[2]
  655. if name == "" {
  656. err = fmt.Errorf("error getting name from upload message, cannot be empty")
  657. c.connection.Log(logger.LevelError, "error: %v", err)
  658. c.sendErrorMessage(fs, err)
  659. return size, name, err
  660. }
  661. } else {
  662. err = fmt.Errorf("unable to split upload message: %#v", command)
  663. c.connection.Log(logger.LevelError, "error: %v", err)
  664. c.sendErrorMessage(fs, err)
  665. return size, name, err
  666. }
  667. return size, name, err
  668. }
  // getFileUploadDestPath returns the virtual path where the file fileName sent
// by the client must be written.
//   - For recursive uploads, or when scpDestPath ends with "/", the file goes
//     inside scpDestPath.
//   - Otherwise scpDestPath is the wanted file name, for example:
//     scp fileName.txt [email protected]:/newFileName.txt
//   - The exception is an existing directory, where the file is placed inside
//     it, as the scp command does, for example:
//     scp fileName.txt [email protected]:/existing_dir
func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string {
	if c.isRecursive() || strings.HasSuffix(scpDestPath, "/") {
		return path.Join(scpDestPath, fileName)
	}
	resolved, err := fs.ResolvePath(scpDestPath)
	if err != nil {
		return scpDestPath
	}
	if info, statErr := fs.Stat(resolved); statErr == nil && info.IsDir() {
		return path.Join(scpDestPath, fileName)
	}
	return scpDestPath
}
  694. func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
  695. var defaultMode string
  696. if isDir {
  697. defaultMode = "0755"
  698. } else {
  699. defaultMode = "0644"
  700. }
  701. if fileMode == 0 {
  702. return defaultMode
  703. }
  704. modeString := []byte(fileMode.String())
  705. nullPerm := []byte("-")
  706. u := 0
  707. g := 0
  708. o := 0
  709. s := 0
  710. lastChar := len(modeString) - 1
  711. if fileMode&os.ModeSticky != 0 {
  712. s++
  713. }
  714. if fileMode&os.ModeSetuid != 0 {
  715. s += 2
  716. }
  717. if fileMode&os.ModeSetgid != 0 {
  718. s += 4
  719. }
  720. if modeString[lastChar-8] != nullPerm[0] {
  721. u += 4
  722. }
  723. if modeString[lastChar-7] != nullPerm[0] {
  724. u += 2
  725. }
  726. if modeString[lastChar-6] != nullPerm[0] {
  727. u++
  728. }
  729. if modeString[lastChar-5] != nullPerm[0] {
  730. g += 4
  731. }
  732. if modeString[lastChar-4] != nullPerm[0] {
  733. g += 2
  734. }
  735. if modeString[lastChar-3] != nullPerm[0] {
  736. g++
  737. }
  738. if modeString[lastChar-2] != nullPerm[0] {
  739. o += 4
  740. }
  741. if modeString[lastChar-1] != nullPerm[0] {
  742. o += 2
  743. }
  744. if modeString[lastChar] != nullPerm[0] {
  745. o++
  746. }
  747. return fmt.Sprintf("%v%v%v%v", s, u, g, o)
  748. }