internal_test.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. package sftpd
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "os"
  9. "runtime"
  10. "testing"
  11. "time"
  12. "github.com/drakkan/sftpgo/dataprovider"
  13. "github.com/drakkan/sftpgo/utils"
  14. "github.com/pkg/sftp"
  15. )
  16. type MockChannel struct {
  17. Buffer *bytes.Buffer
  18. StdErrBuffer *bytes.Buffer
  19. ReadError error
  20. WriteError error
  21. }
  22. func (c *MockChannel) Read(data []byte) (int, error) {
  23. if c.ReadError != nil {
  24. return 0, c.ReadError
  25. }
  26. return c.Buffer.Read(data)
  27. }
  28. func (c *MockChannel) Write(data []byte) (int, error) {
  29. if c.WriteError != nil {
  30. return 0, c.WriteError
  31. }
  32. return c.Buffer.Write(data)
  33. }
  34. func (c *MockChannel) Close() error {
  35. return nil
  36. }
  37. func (c *MockChannel) CloseWrite() error {
  38. return nil
  39. }
  40. func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
  41. return true, nil
  42. }
  43. func (c *MockChannel) Stderr() io.ReadWriter {
  44. return c.StdErrBuffer
  45. }
  // TestWrongActions checks that executeAction reports an error for a command
// that does not exist and for a malformed notification URL, and that an
// operation not listed in ExecuteOn returns no error.
func TestWrongActions(t *testing.T) {
	actionsCopy := actions
	// Restore the package-level configuration even if a check below panics or
	// aborts the test, so later tests do not inherit the bogus settings.
	defer func() { actions = actionsCopy }()
	badCommand := "/bad/command"
	if runtime.GOOS == "windows" {
		badCommand = "C:\\bad\\command"
	}
	actions = Actions{
		ExecuteOn:           []string{operationDownload},
		Command:             badCommand,
		HTTPNotificationURL: "",
	}
	if err := executeAction(operationDownload, "username", "path", ""); err == nil {
		t.Errorf("action with bad command must fail")
	}
	// operationDelete is not in ExecuteOn, so no action should run.
	if err := executeAction(operationDelete, "username", "path", ""); err != nil {
		t.Errorf("action not configured must silently fail")
	}
	actions.Command = ""
	// \x7f is a control character; the URL is presumably rejected when parsed.
	actions.HTTPNotificationURL = "http://foo\x7f.com/"
	if err := executeAction(operationDownload, "username", "path", ""); err == nil {
		t.Errorf("action with bad url must fail")
	}
}
  // TestRemoveNonexistentTransfer checks that removing a transfer that was never
// registered is reported as an error.
func TestRemoveNonexistentTransfer(t *testing.T) {
	if err := removeTransfer(&Transfer{}); err == nil {
		t.Errorf("remove nonexistent transfer must fail")
	}
}
  // TestRemoveNonexistentQuotaScan checks that removing a quota scan for a user
// with no scan in progress is reported as an error.
func TestRemoveNonexistentQuotaScan(t *testing.T) {
	if err := RemoveQuotaScan("username"); err == nil {
		t.Errorf("remove nonexistent quota scan must fail")
	}
}
  // TestGetOSOpenFlags checks the translation of SFTP open flags to os.OpenFile
// flags, including that the SFTP Append flag is not mapped to O_APPEND.
func TestGetOSOpenFlags(t *testing.T) {
	var flags sftp.FileOpenFlags
	flags.Write = true
	flags.Excl = true
	osFlags := getOSOpenFlags(flags)
	if osFlags&os.O_WRONLY == 0 || osFlags&os.O_EXCL == 0 {
		t.Errorf("error getting os flags from sftp file open flags")
	}
	flags.Append = true
	// append flag should be ignored to allow resume. Recompute the flags so
	// the Append case is actually exercised rather than re-checking the
	// previous result.
	osFlags = getOSOpenFlags(flags)
	if osFlags&os.O_WRONLY == 0 || osFlags&os.O_EXCL == 0 || osFlags&os.O_APPEND != 0 {
		t.Errorf("error getting os flags from sftp file open flags")
	}
}
  // TestUploadResumeInvalidOffset checks that WriteAt rejects an offset below
// the transfer's minWriteOffset.
func TestUploadResumeInvalidOffset(t *testing.T) {
	testfile := "testfile"
	file, err := os.Create(testfile)
	if err != nil {
		t.Fatalf("unable to create test file: %v", err)
	}
	// Deferred calls run LIFO: the file is closed before it is removed, since
	// an open file cannot be removed on Windows.
	defer os.Remove(testfile)
	defer file.Close()
	transfer := Transfer{
		file:          file,
		path:          file.Name(),
		start:         time.Now(),
		bytesSent:     0,
		bytesReceived: 0,
		user: dataprovider.User{
			Username: "testuser",
		},
		connectionID:   "",
		transferType:   transferUpload,
		lastActivity:   time.Now(),
		isNewFile:      false,
		protocol:       protocolSFTP,
		transferError:  nil,
		isFinished:     false,
		minWriteOffset: 10,
	}
	// Offset 0 is below minWriteOffset (10), so the write must be refused.
	_, err = transfer.WriteAt([]byte("test"), 0)
	if err == nil {
		t.Errorf("upload with invalid offset must fail")
	}
}
  // TestUploadFiles checks that uploads to missing paths fail in both the atomic
// and the standard upload modes.
func TestUploadFiles(t *testing.T) {
	oldUploadMode := uploadMode
	// Restore the package-level mode even if the test panics or aborts.
	defer func() { uploadMode = oldUploadMode }()
	uploadMode = uploadModeAtomic
	c := Connection{}
	var flags sftp.FileOpenFlags
	flags.Write = true
	flags.Trunc = true
	_, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
	if err == nil {
		t.Errorf("upload to existing file must fail if one or both paths are invalid")
	}
	uploadMode = uploadModeStandard
	_, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0)
	if err == nil {
		t.Errorf("upload to existing file must fail if one or both paths are invalid")
	}
	missingFile := "missing/relative/file.txt"
	if runtime.GOOS == "windows" {
		missingFile = "missing\\relative\\file.txt"
	}
	_, err = c.handleSFTPUploadToNewFile(".", missingFile)
	if err == nil {
		t.Errorf("upload new file in missing path must fail")
	}
}
  // TestWithInvalidHome checks that a user whose home directory is a relative
// path cannot log in, and that isSubDir rejects a relative path for them.
func TestWithInvalidHome(t *testing.T) {
	user := dataprovider.User{HomeDir: "home_rel_path"}
	if _, err := loginUser(user, "password"); err == nil {
		t.Errorf("login a user with an invalid home_dir must fail")
	}
	conn := Connection{User: user}
	if err := conn.isSubDir("dir_rel_path"); err == nil {
		t.Errorf("tested path is not a home subdir")
	}
}
  // TestSFTPCmdTargetPath checks that getSFTPCmdTargetPath returns
// sftp.ErrSshFxOpUnsupported for a user whose home directory is relative.
func TestSFTPCmdTargetPath(t *testing.T) {
	u := dataprovider.User{
		HomeDir:     "home_rel_path",
		Username:    "test",
		Permissions: []string{"*"},
	}
	connection := Connection{User: u}
	_, err := connection.getSFTPCmdTargetPath("invalid_path")
	if err != sftp.ErrSshFxOpUnsupported {
		t.Errorf("getSFTPCmdTargetPath must fail with the expected error: %v", err)
	}
}
  // TestSFTPGetUsedQuota checks that hasSpace reports no space for a user with a
// quota that the data provider presumably does not know.
func TestSFTPGetUsedQuota(t *testing.T) {
	u := dataprovider.User{
		HomeDir:     "home_rel_path",
		Username:    "test_invalid_user",
		QuotaSize:   4096,
		QuotaFiles:  1,
		Permissions: []string{"*"},
	}
	connection := Connection{User: u}
	if connection.hasSpace(false) {
		t.Errorf("has space must return false if the user is invalid")
	}
}
  // TestSCPFileMode checks the four-digit octal strings produced by
// getFileModeAsString, including the defaults used for a zero mode and the
// encoding of the setuid, setgid and sticky bits.
func TestSCPFileMode(t *testing.T) {
	// NOTE(review): these expectations map os.ModeSetgid to the leading digit 4
	// and os.ModeSetuid to 2, whereas POSIX chmod uses 04000 for setuid and
	// 02000 for setgid. Confirm whether getFileModeAsString swaps them on purpose.
	cases := []struct {
		mode     os.FileMode
		isDir    bool
		expected string
	}{
		{0, true, "0755"},
		{0700, true, "0700"},
		{0750, true, "0750"},
		{0777, true, "0777"},
		{0640, false, "0640"},
		{0600, false, "0600"},
		{0, false, "0644"},
		{os.FileMode(0777) | os.ModeSetgid | os.ModeSetuid | os.ModeSticky, false, "7777"},
		{os.FileMode(0644) | os.ModeSetgid, false, "4644"},
		{os.FileMode(0600) | os.ModeSetuid, false, "2600"},
		{os.FileMode(0044) | os.ModeSticky, false, "1044"},
	}
	for _, tc := range cases {
		if got := getFileModeAsString(tc.mode, tc.isDir); got != tc.expected {
			t.Errorf("invalid file mode: %v expected: %v", got, tc.expected)
		}
	}
}
  // TestSCPGetNonExistingDirContent checks that listing a missing directory
// returns an error.
func TestSCPGetNonExistingDirContent(t *testing.T) {
	if _, err := getDirContents("non_existing"); err == nil {
		t.Errorf("get non existing dir contents must fail")
	}
}
  // TestSCPParseUploadMessage checks that parseUploadMessage rejects malformed,
// incomplete, badly sized and unnamed SCP upload headers.
func TestSCPParseUploadMessage(t *testing.T) {
	channel := MockChannel{
		Buffer:       bytes.NewBuffer(make([]byte, 65535)),
		StdErrBuffer: bytes.NewBuffer(make([]byte, 65535)),
		ReadError:    nil,
	}
	cmd := scpCommand{
		connection: Connection{channel: &channel},
		args:       []string{"-t", "/tmp"},
	}
	malformed := []struct {
		message string
		failure string
	}{
		{"invalid", "parsing invalid upload message must fail"},
		{"D0755 0", "parsing incomplete upload message must fail"},
		{"D0755 invalidsize testdir", "parsing upload message with invalid size must fail"},
		{"D0755 0 ", "parsing upload message with invalid name must fail"},
	}
	for _, tc := range malformed {
		if _, _, err := cmd.parseUploadMessage(tc.message); err == nil {
			t.Error(tc.failure)
		}
	}
}
  // TestSCPProtocolMessages checks that the SCP protocol read/write helpers
// surface channel errors unchanged, and that readConfirmationMessage turns a
// 0x02-prefixed reply into an error carrying the reply text.
func TestSCPProtocolMessages(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	connection := Connection{
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		connection: connection,
		args:       []string{"-t", "/tmp"},
	}
	_, err := scpCommand.readProtocolMessage()
	if err != readErr {
		t.Errorf("read protocol message must fail, we are sending a fake error")
	}
	err = scpCommand.sendConfirmationMessage()
	if err != writeErr {
		t.Errorf("write confirmation message must fail, we are sending a fake error")
	}
	err = scpCommand.sendProtocolMessage("E\n")
	if err != writeErr {
		t.Errorf("write protocol message must fail, we are sending a fake error")
	}
	_, err = scpCommand.getNextUploadProtocolMessage()
	if err != readErr {
		t.Errorf("read next upload protocol message must fail, we are sending a fake read error")
	}
	// A valid time header is readable, so the failure must come from writing
	// the acknowledgement.
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   writeErr,
	}
	scpCommand.connection.channel = &mockSSHChannel
	_, err = scpCommand.getNextUploadProtocolMessage()
	if err != writeErr {
		t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
	}
	// Build a reply of 0x02 followed by the error text and a newline.
	respBuffer := []byte{0x02}
	protocolErrorMsg := "protocol error msg"
	respBuffer = append(respBuffer, protocolErrorMsg...)
	respBuffer = append(respBuffer, 0x0A)
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(respBuffer),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.readConfirmationMessage()
	if err == nil || err.Error() != protocolErrorMsg {
		t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
	}
}
  // TestSCPTestDownloadProtocolMessages checks that sendDownloadProtocolMessages
// returns the channel's write or read error, with and without the "-p"
// (preserve times) argument.
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	connection := Connection{
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		connection: connection,
		args:       []string{"-f", "-p", "/tmp"},
	}
	path := "testDir"
	// A directory left over from an aborted run is acceptable; other errors are not.
	if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) {
		t.Fatalf("unable to create test dir: %v", err)
	}
	defer os.Remove(path)
	stat, err := os.Stat(path)
	if err != nil {
		t.Fatalf("unable to stat test dir: %v", err)
	}
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	if err != writeErr {
		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
	}
	// connection.channel points at mockSSHChannel, so reassigning the variable
	// swaps the mock the command uses.
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   nil,
	}
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	if err != readErr {
		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
	}
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	scpCommand.args = []string{"-f", "/tmp"}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	if err != writeErr {
		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
	}
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	if err != readErr {
		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
	}
}
  // TestSCPCommandHandleErrors checks that handle propagates the channel read
// error for a download and rejects an unknown SCP mode flag.
func TestSCPCommandHandleErrors(t *testing.T) {
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	channel := &MockChannel{
		Buffer:       bytes.NewBuffer(make([]byte, 65535)),
		StdErrBuffer: bytes.NewBuffer(make([]byte, 65535)),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	server, client := net.Pipe()
	defer server.Close()
	defer client.Close()
	cmd := scpCommand{
		connection: Connection{
			channel: channel,
			netConn: client,
		},
		args: []string{"-f", "/tmp"},
	}
	// readErr is non-nil, so a plain inequality also covers err == nil.
	if err := cmd.handle(); err != readErr {
		t.Errorf("scp download must fail, we are sending a fake error")
	}
	cmd.args = []string{"-i", "/tmp"}
	if err := cmd.handle(); err == nil {
		t.Errorf("invalid scp command must fail")
	}
}
  // TestSCPRecursiveDownloadErrors checks that handleRecursiveDownload returns
// the channel write error, and fails for a directory that does not exist even
// when the channel works.
func TestSCPRecursiveDownloadErrors(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	server, client := net.Pipe()
	defer server.Close()
	defer client.Close()
	connection := Connection{
		channel: &mockSSHChannel,
		netConn: client,
	}
	scpCommand := scpCommand{
		connection: connection,
		args:       []string{"-r", "-f", "/tmp"},
	}
	path := "testDir"
	// A directory left over from an aborted run is acceptable; other errors are not.
	if err := os.Mkdir(path, 0777); err != nil && !os.IsExist(err) {
		t.Fatalf("unable to create test dir: %v", err)
	}
	defer os.Remove(path)
	stat, err := os.Stat(path)
	if err != nil {
		t.Fatalf("unable to stat test dir: %v", err)
	}
	// stat describes testDir, while the path passed in does not exist on disk.
	err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
	if err != writeErr {
		t.Errorf("recursive download must fail with the expected error: %v", err)
	}
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
	if err == nil {
		t.Errorf("recursive download must fail for a non existing dir")
	}
}
  // TestSCPRecursiveUploadErrors checks that handleRecursiveUpload fails when
// channel reads fail, first with writes failing too and then with writes
// succeeding.
func TestSCPRecursiveUploadErrors(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	cmd := scpCommand{
		args: []string{"-r", "-t", "/tmp"},
	}
	for _, channelWriteErr := range []error{writeErr, nil} {
		cmd.connection.channel = &MockChannel{
			Buffer:       bytes.NewBuffer(buf),
			StdErrBuffer: bytes.NewBuffer(stdErrBuf),
			ReadError:    readErr,
			WriteError:   channelWriteErr,
		}
		if err := cmd.handleRecursiveUpload(); err == nil {
			t.Errorf("recursive upload must fail, we send a fake error message")
		}
	}
}
  // TestSCPCreateDirs checks that handleCreateDir fails for a user whose home
// directory is a relative path.
func TestSCPCreateDirs(t *testing.T) {
	user := dataprovider.User{
		HomeDir:     "home_rel_path",
		Username:    "test",
		Permissions: []string{"*"},
	}
	channel := &MockChannel{
		Buffer:       bytes.NewBuffer(make([]byte, 65535)),
		StdErrBuffer: bytes.NewBuffer(make([]byte, 65535)),
		ReadError:    nil,
		WriteError:   nil,
	}
	cmd := scpCommand{
		connection: Connection{
			User:    user,
			channel: channel,
		},
		args: []string{"-r", "-t", "/tmp"},
	}
	if err := cmd.handleCreateDir("invalid_dir"); err == nil {
		t.Errorf("create invalid dir must fail")
	}
}
  // TestSCPDownloadFileData checks that sendDownloadFileData returns the
// channel's read or write error, with and without the "-p" argument.
func TestSCPDownloadFileData(t *testing.T) {
	testfile := "testfile"
	buf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	stdErrBuf := make([]byte, 65535)
	mockSSHChannelReadErr := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   nil,
	}
	mockSSHChannelWriteErr := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   writeErr,
	}
	connection := Connection{
		channel: &mockSSHChannelReadErr,
	}
	scpCommand := scpCommand{
		connection: connection,
		args:       []string{"-r", "-f", "/tmp"},
	}
	if err := ioutil.WriteFile(testfile, []byte("test"), 0666); err != nil {
		t.Fatalf("unable to create test file: %v", err)
	}
	defer os.Remove(testfile)
	stat, err := os.Stat(testfile)
	if err != nil {
		t.Fatalf("unable to stat test file: %v", err)
	}
	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
	if err != readErr {
		t.Errorf("send download file data must fail with the expected error: %v", err)
	}
	scpCommand.connection.channel = &mockSSHChannelWriteErr
	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
	if err != writeErr {
		t.Errorf("send download file data must fail with the expected error: %v", err)
	}
	scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
	if err != writeErr {
		t.Errorf("send download file data must fail with the expected error: %v", err)
	}
	scpCommand.connection.channel = &mockSSHChannelReadErr
	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
	if err != readErr {
		t.Errorf("send download file data must fail with the expected error: %v", err)
	}
}
  // TestSCPUploadFiledata checks that getUploadFileData fails on channel write
// errors, channel read errors, short input, and when the transfer's file has
// already been closed.
func TestSCPUploadFiledata(t *testing.T) {
	testfile := "testfile"
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   writeErr,
	}
	connection := Connection{
		User: dataprovider.User{
			Username: "testuser",
		},
		protocol: protocolSCP,
		channel:  &mockSSHChannel,
	}
	scpCommand := scpCommand{
		connection: connection,
		args:       []string{"-r", "-t", "/tmp"},
	}
	// createTestFile (re)creates testfile. Each scenario gets a fresh handle
	// because a failed upload presumably closes the transfer's file.
	createTestFile := func() *os.File {
		f, err := os.Create(testfile)
		if err != nil {
			t.Fatalf("unable to create test file: %v", err)
		}
		return f
	}
	file := createTestFile()
	defer os.Remove(testfile)
	// NOTE(review): transferType is transferDownload although this exercises the
	// upload path; kept as-is since Close behaviour may depend on it.
	transfer := Transfer{
		file:           file,
		path:           file.Name(),
		start:          time.Now(),
		bytesSent:      0,
		bytesReceived:  0,
		user:           scpCommand.connection.User,
		connectionID:   "",
		transferType:   transferDownload,
		lastActivity:   time.Now(),
		isNewFile:      true,
		protocol:       connection.protocol,
		transferError:  nil,
		isFinished:     false,
		minWriteOffset: 0,
	}
	addTransfer(&transfer)
	err := scpCommand.getUploadFileData(2, &transfer)
	if err == nil {
		t.Errorf("upload must fail, we send a fake write error message")
	}
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    readErr,
		WriteError:   nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	transfer.file = createTestFile()
	addTransfer(&transfer)
	err = scpCommand.getUploadFileData(2, &transfer)
	if err == nil {
		t.Errorf("upload must fail, we send a fake read error message")
	}
	respBuffer := []byte("12")
	respBuffer = append(respBuffer, 0x02)
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(respBuffer),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	transfer.file = createTestFile()
	addTransfer(&transfer)
	err = scpCommand.getUploadFileData(2, &transfer)
	if err == nil {
		t.Errorf("upload must fail, we have not enough data to read")
	}
	// the file is already closed so we have an error on transfer closing.
	// scpCommand.connection.channel still points at mockSSHChannel, so this
	// reassignment swaps the mock in use.
	mockSSHChannel = MockChannel{
		Buffer:       bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError:    nil,
		WriteError:   nil,
	}
	addTransfer(&transfer)
	err = scpCommand.getUploadFileData(0, &transfer)
	if err == nil {
		t.Errorf("upload must fail, the file is closed")
	}
}
  // TestUploadError checks that closing a failed atomic-mode upload resets
// bytesReceived and leaves neither the target file nor the temporary file on
// disk.
func TestUploadError(t *testing.T) {
	oldUploadMode := uploadMode
	// Restore the package-level mode even if the test panics or aborts.
	defer func() { uploadMode = oldUploadMode }()
	uploadMode = uploadModeAtomic
	connection := Connection{
		User: dataprovider.User{
			Username: "testuser",
		},
		protocol: protocolSCP,
	}
	testfile := "testfile"
	fileTempName := "temptestfile"
	file, err := os.Create(fileTempName)
	if err != nil {
		t.Fatalf("unable to create temporary file: %v", err)
	}
	transfer := Transfer{
		file:           file,
		path:           testfile,
		start:          time.Now(),
		bytesSent:      0,
		bytesReceived:  100,
		user:           connection.User,
		connectionID:   "",
		transferType:   transferUpload,
		lastActivity:   time.Now(),
		isNewFile:      true,
		protocol:       connection.protocol,
		transferError:  nil,
		isFinished:     false,
		minWriteOffset: 0,
	}
	addTransfer(&transfer)
	transfer.TransferError(fmt.Errorf("fake error"))
	transfer.Close()
	if transfer.bytesReceived > 0 {
		t.Errorf("bytes received should be 0 for a failed transfer: %v", transfer.bytesReceived)
	}
	_, err = os.Stat(testfile)
	if !os.IsNotExist(err) {
		t.Errorf("file uploaded must be deleted after an error: %v", err)
	}
	_, err = os.Stat(fileTempName)
	if !os.IsNotExist(err) {
		t.Errorf("file uploaded must be deleted after an error: %v", err)
	}
}
  726. func TestConnectionStatusStruct(t *testing.T) {
  727. var transfers []connectionTransfer
  728. transferUL := connectionTransfer{
  729. OperationType: operationUpload,
  730. StartTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  731. Size: 123,
  732. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  733. Path: "/test.upload",
  734. }
  735. transferDL := connectionTransfer{
  736. OperationType: operationDownload,
  737. StartTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  738. Size: 123,
  739. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  740. Path: "/test.download",
  741. }
  742. transfers = append(transfers, transferUL)
  743. transfers = append(transfers, transferDL)
  744. c := ConnectionStatus{
  745. Username: "test",
  746. ConnectionID: "123",
  747. ClientVersion: "fakeClient-1.0.0",
  748. RemoteAddress: "127.0.0.1:1234",
  749. ConnectionTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  750. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  751. Protocol: "SFTP",
  752. Transfers: transfers,
  753. }
  754. durationString := c.GetConnectionDuration()
  755. if len(durationString) == 0 {
  756. t.Errorf("error getting connection duration")
  757. }
  758. transfersString := c.GetTransfersAsString()
  759. if len(transfersString) == 0 {
  760. t.Errorf("error getting transfers as string")
  761. }
  762. connInfo := c.GetConnectionInfo()
  763. if len(connInfo) == 0 {
  764. t.Errorf("error getting connection info")
  765. }
  766. }