internal_test.go 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772
  1. package sftpd
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "os"
  10. "path/filepath"
  11. "runtime"
  12. "sync"
  13. "testing"
  14. "time"
  15. "github.com/eikenb/pipeat"
  16. "github.com/pkg/sftp"
  17. "github.com/stretchr/testify/assert"
  18. "golang.org/x/crypto/ssh"
  19. "github.com/drakkan/sftpgo/dataprovider"
  20. "github.com/drakkan/sftpgo/utils"
  21. "github.com/drakkan/sftpgo/vfs"
  22. )
  23. const osWindows = "windows"
  24. type MockChannel struct {
  25. Buffer *bytes.Buffer
  26. StdErrBuffer *bytes.Buffer
  27. ReadError error
  28. WriteError error
  29. ShortWriteErr bool
  30. }
  31. func (c *MockChannel) Read(data []byte) (int, error) {
  32. if c.ReadError != nil {
  33. return 0, c.ReadError
  34. }
  35. return c.Buffer.Read(data)
  36. }
  37. func (c *MockChannel) Write(data []byte) (int, error) {
  38. if c.WriteError != nil {
  39. return 0, c.WriteError
  40. }
  41. if c.ShortWriteErr {
  42. return 0, nil
  43. }
  44. return c.Buffer.Write(data)
  45. }
  46. func (c *MockChannel) Close() error {
  47. return nil
  48. }
  49. func (c *MockChannel) CloseWrite() error {
  50. return nil
  51. }
  52. func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
  53. return true, nil
  54. }
  55. func (c *MockChannel) Stderr() io.ReadWriter {
  56. return c.StdErrBuffer
  57. }
  58. // MockOsFs mockable OsFs
  59. type MockOsFs struct {
  60. vfs.Fs
  61. err error
  62. statErr error
  63. isAtomicUploadSupported bool
  64. }
  65. // Name returns the name for the Fs implementation
  66. func (fs MockOsFs) Name() string {
  67. return "mockOsFs"
  68. }
  69. // IsUploadResumeSupported returns true if upload resume is supported
  70. func (MockOsFs) IsUploadResumeSupported() bool {
  71. return false
  72. }
  73. // IsAtomicUploadSupported returns true if atomic upload is supported
  74. func (fs MockOsFs) IsAtomicUploadSupported() bool {
  75. return fs.isAtomicUploadSupported
  76. }
  77. // Stat returns a FileInfo describing the named file
  78. func (fs MockOsFs) Stat(name string) (os.FileInfo, error) {
  79. if fs.statErr != nil {
  80. return nil, fs.statErr
  81. }
  82. return os.Stat(name)
  83. }
  84. // Remove removes the named file or (empty) directory.
  85. func (fs MockOsFs) Remove(name string, isDir bool) error {
  86. if fs.err != nil {
  87. return fs.err
  88. }
  89. return os.Remove(name)
  90. }
  91. // Rename renames (moves) source to target
  92. func (fs MockOsFs) Rename(source, target string) error {
  93. if fs.err != nil {
  94. return fs.err
  95. }
  96. return os.Rename(source, target)
  97. }
  98. func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
  99. return &MockOsFs{
  100. Fs: vfs.NewOsFs(connectionID, rootDir, nil),
  101. err: err,
  102. statErr: statErr,
  103. isAtomicUploadSupported: atomicUpload,
  104. }
  105. }
  106. func TestNewActionNotification(t *testing.T) {
  107. user := dataprovider.User{
  108. Username: "username",
  109. }
  110. user.FsConfig.Provider = 0
  111. user.FsConfig.S3Config = vfs.S3FsConfig{
  112. Bucket: "s3bucket",
  113. Endpoint: "endpoint",
  114. }
  115. user.FsConfig.GCSConfig = vfs.GCSFsConfig{
  116. Bucket: "gcsbucket",
  117. }
  118. a := newActionNotification(user, operationDownload, "path", "target", "", 123, nil)
  119. assert.Equal(t, user.Username, a.Username)
  120. assert.Equal(t, 0, len(a.Bucket))
  121. assert.Equal(t, 0, len(a.Endpoint))
  122. user.FsConfig.Provider = 1
  123. a = newActionNotification(user, operationDownload, "path", "target", "", 123, nil)
  124. assert.Equal(t, "s3bucket", a.Bucket)
  125. assert.Equal(t, "endpoint", a.Endpoint)
  126. user.FsConfig.Provider = 2
  127. a = newActionNotification(user, operationDownload, "path", "target", "", 123, nil)
  128. assert.Equal(t, "gcsbucket", a.Bucket)
  129. assert.Equal(t, 0, len(a.Endpoint))
  130. }
  131. func TestWrongActions(t *testing.T) {
  132. actionsCopy := actions
  133. badCommand := "/bad/command"
  134. if runtime.GOOS == osWindows {
  135. badCommand = "C:\\bad\\command"
  136. }
  137. actions = Actions{
  138. ExecuteOn: []string{operationDownload},
  139. Command: badCommand,
  140. HTTPNotificationURL: "",
  141. }
  142. user := dataprovider.User{
  143. Username: "username",
  144. }
  145. err := executeAction(newActionNotification(user, operationDownload, "path", "", "", 0, nil))
  146. assert.Error(t, err, "action with bad command must fail")
  147. err = executeAction(newActionNotification(user, operationDelete, "path", "", "", 0, nil))
  148. assert.NoError(t, err)
  149. actions.Command = ""
  150. actions.HTTPNotificationURL = "http://foo\x7f.com/"
  151. err = executeAction(newActionNotification(user, operationDownload, "path", "", "", 0, nil))
  152. assert.Error(t, err, "action with bad url must fail")
  153. actions = actionsCopy
  154. }
  155. func TestActionHTTP(t *testing.T) {
  156. actionsCopy := actions
  157. actions = Actions{
  158. ExecuteOn: []string{operationDownload},
  159. Command: "",
  160. HTTPNotificationURL: "http://127.0.0.1:8080/",
  161. }
  162. user := dataprovider.User{
  163. Username: "username",
  164. }
  165. err := executeAction(newActionNotification(user, operationDownload, "path", "", "", 0, nil))
  166. assert.NoError(t, err)
  167. actions = actionsCopy
  168. }
  169. func TestRemoveNonexistentTransfer(t *testing.T) {
  170. transfer := Transfer{}
  171. err := removeTransfer(&transfer)
  172. assert.Error(t, err, "remove nonexistent transfer must fail")
  173. }
  174. func TestRemoveNonexistentQuotaScan(t *testing.T) {
  175. err := RemoveQuotaScan("username")
  176. assert.Error(t, err, "remove nonexistent quota scan must fail")
  177. }
  178. func TestGetOSOpenFlags(t *testing.T) {
  179. var flags sftp.FileOpenFlags
  180. flags.Write = true
  181. flags.Excl = true
  182. osFlags := getOSOpenFlags(flags)
  183. assert.NotEqual(t, 0, osFlags&os.O_WRONLY)
  184. assert.NotEqual(t, 0, osFlags&os.O_EXCL)
  185. flags.Append = true
  186. // append flag should be ignored to allow resume
  187. assert.NotEqual(t, 0, osFlags&os.O_WRONLY)
  188. assert.NotEqual(t, 0, osFlags&os.O_EXCL)
  189. }
  190. func TestUploadResumeInvalidOffset(t *testing.T) {
  191. testfile := "testfile" //nolint:goconst
  192. file, err := os.Create(testfile)
  193. assert.NoError(t, err)
  194. transfer := Transfer{
  195. file: file,
  196. path: file.Name(),
  197. start: time.Now(),
  198. bytesSent: 0,
  199. bytesReceived: 0,
  200. user: dataprovider.User{
  201. Username: "testuser",
  202. },
  203. connectionID: "",
  204. transferType: transferUpload,
  205. lastActivity: time.Now(),
  206. isNewFile: false,
  207. protocol: protocolSFTP,
  208. transferError: nil,
  209. isFinished: false,
  210. minWriteOffset: 10,
  211. lock: new(sync.Mutex),
  212. }
  213. _, err = transfer.WriteAt([]byte("test"), 0)
  214. assert.Error(t, err, "upload with invalid offset must fail")
  215. err = transfer.Close()
  216. if assert.Error(t, err) {
  217. assert.Contains(t, err.Error(), "Invalid write offset")
  218. }
  219. err = os.Remove(testfile)
  220. assert.NoError(t, err)
  221. }
  222. func TestReadWriteErrors(t *testing.T) {
  223. testfile := "testfile"
  224. file, err := os.Create(testfile)
  225. assert.NoError(t, err)
  226. transfer := Transfer{
  227. file: file,
  228. path: file.Name(),
  229. start: time.Now(),
  230. bytesSent: 0,
  231. bytesReceived: 0,
  232. user: dataprovider.User{
  233. Username: "testuser",
  234. },
  235. connectionID: "",
  236. transferType: transferDownload,
  237. lastActivity: time.Now(),
  238. isNewFile: false,
  239. protocol: protocolSFTP,
  240. transferError: nil,
  241. isFinished: false,
  242. minWriteOffset: 0,
  243. expectedSize: 10,
  244. lock: new(sync.Mutex),
  245. }
  246. err = file.Close()
  247. assert.NoError(t, err)
  248. _, err = transfer.WriteAt([]byte("test"), 0)
  249. assert.Error(t, err, "writing to closed file must fail")
  250. buf := make([]byte, 32768)
  251. _, err = transfer.ReadAt(buf, 0)
  252. assert.Error(t, err, "reading from a closed file must fail")
  253. err = transfer.Close()
  254. assert.Error(t, err, "upoload must fail: the expected size does not match")
  255. r, _, err := pipeat.Pipe()
  256. assert.NoError(t, err)
  257. transfer = Transfer{
  258. readerAt: r,
  259. writerAt: nil,
  260. start: time.Now(),
  261. bytesSent: 0,
  262. bytesReceived: 0,
  263. user: dataprovider.User{
  264. Username: "testuser",
  265. },
  266. connectionID: "",
  267. transferType: transferDownload,
  268. lastActivity: time.Now(),
  269. isNewFile: false,
  270. protocol: protocolSFTP,
  271. transferError: nil,
  272. isFinished: false,
  273. lock: new(sync.Mutex),
  274. }
  275. err = transfer.closeIO()
  276. assert.NoError(t, err)
  277. _, err = transfer.ReadAt(buf, 0)
  278. assert.Error(t, err, "reading from a closed pipe must fail")
  279. r, w, err := pipeat.Pipe()
  280. assert.NoError(t, err)
  281. transfer = Transfer{
  282. readerAt: nil,
  283. writerAt: vfs.NewPipeWriter(w),
  284. start: time.Now(),
  285. bytesSent: 0,
  286. bytesReceived: 0,
  287. user: dataprovider.User{
  288. Username: "testuser",
  289. },
  290. connectionID: "",
  291. transferType: transferDownload,
  292. lastActivity: time.Now(),
  293. isNewFile: false,
  294. protocol: protocolSFTP,
  295. transferError: nil,
  296. isFinished: false,
  297. lock: new(sync.Mutex),
  298. }
  299. err = r.Close()
  300. assert.NoError(t, err)
  301. errFake := fmt.Errorf("fake upload error")
  302. go func() {
  303. time.Sleep(100 * time.Millisecond)
  304. transfer.writerAt.Done(errFake)
  305. }()
  306. err = transfer.closeIO()
  307. assert.EqualError(t, err, errFake.Error())
  308. _, err = transfer.WriteAt([]byte("test"), 0)
  309. assert.Error(t, err, "writing to closed pipe must fail")
  310. err = os.Remove(testfile)
  311. assert.NoError(t, err)
  312. }
  313. func TestTransferCancelFn(t *testing.T) {
  314. testfile := "testfile"
  315. file, err := os.Create(testfile)
  316. assert.NoError(t, err)
  317. isCancelled := false
  318. cancelFn := func() {
  319. isCancelled = true
  320. }
  321. transfer := Transfer{
  322. file: file,
  323. cancelFn: cancelFn,
  324. path: file.Name(),
  325. start: time.Now(),
  326. bytesSent: 0,
  327. bytesReceived: 0,
  328. user: dataprovider.User{
  329. Username: "testuser",
  330. },
  331. connectionID: "",
  332. transferType: transferDownload,
  333. lastActivity: time.Now(),
  334. isNewFile: false,
  335. protocol: protocolSFTP,
  336. transferError: nil,
  337. isFinished: false,
  338. minWriteOffset: 0,
  339. expectedSize: 10,
  340. lock: new(sync.Mutex),
  341. }
  342. errFake := errors.New("fake error, this will trigger cancelFn")
  343. transfer.TransferError(errFake)
  344. err = transfer.Close()
  345. assert.EqualError(t, err, errFake.Error())
  346. assert.True(t, isCancelled, "cancelFn not called!")
  347. err = os.Remove(testfile)
  348. assert.NoError(t, err)
  349. }
  350. func TestMockFsErrors(t *testing.T) {
  351. errFake := errors.New("fake error")
  352. fs := newMockOsFs(errFake, errFake, false, "123", os.TempDir())
  353. u := dataprovider.User{}
  354. u.Username = "test_username"
  355. u.Permissions = make(map[string][]string)
  356. u.Permissions["/"] = []string{dataprovider.PermAny}
  357. u.HomeDir = os.TempDir()
  358. c := Connection{
  359. fs: fs,
  360. User: u,
  361. }
  362. testfile := filepath.Join(u.HomeDir, "testfile")
  363. request := sftp.NewRequest("Remove", testfile)
  364. err := ioutil.WriteFile(testfile, []byte("test"), 0666)
  365. assert.NoError(t, err)
  366. err = c.handleSFTPRemove(testfile, request)
  367. assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
  368. _, err = c.Filewrite(request)
  369. assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
  370. var flags sftp.FileOpenFlags
  371. flags.Write = true
  372. flags.Trunc = false
  373. flags.Append = true
  374. _, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, false)
  375. assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
  376. err = os.Remove(testfile)
  377. assert.NoError(t, err)
  378. }
  379. func TestUploadFiles(t *testing.T) {
  380. oldUploadMode := uploadMode
  381. uploadMode = uploadModeAtomic
  382. c := Connection{
  383. fs: vfs.NewOsFs("123", os.TempDir(), nil),
  384. }
  385. var flags sftp.FileOpenFlags
  386. flags.Write = true
  387. flags.Trunc = true
  388. _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, false)
  389. assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
  390. uploadMode = uploadModeStandard
  391. _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, false)
  392. assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
  393. missingFile := "missing/relative/file.txt"
  394. if runtime.GOOS == osWindows {
  395. missingFile = "missing\\relative\\file.txt"
  396. }
  397. _, err = c.handleSFTPUploadToNewFile(".", missingFile, false)
  398. assert.Error(t, err, "upload new file in missing path must fail")
  399. c.fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
  400. f, err := ioutil.TempFile("", "temp")
  401. assert.NoError(t, err)
  402. err = f.Close()
  403. assert.NoError(t, err)
  404. _, err = c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, false)
  405. assert.NoError(t, err)
  406. if assert.Equal(t, 1, len(activeTransfers)) {
  407. transfer := activeTransfers[0]
  408. assert.Equal(t, int64(123), transfer.initialSize)
  409. err = transfer.Close()
  410. assert.NoError(t, err)
  411. assert.Equal(t, 0, len(activeTransfers))
  412. }
  413. err = os.Remove(f.Name())
  414. assert.NoError(t, err)
  415. uploadMode = oldUploadMode
  416. }
  417. func TestWithInvalidHome(t *testing.T) {
  418. u := dataprovider.User{}
  419. u.HomeDir = "home_rel_path" //nolint:goconst
  420. _, err := loginUser(u, dataprovider.SSHLoginMethodPassword, "", nil)
  421. assert.Error(t, err, "login a user with an invalid home_dir must fail")
  422. u.HomeDir = os.TempDir()
  423. fs, err := u.GetFilesystem("123")
  424. assert.NoError(t, err)
  425. c := Connection{
  426. User: u,
  427. fs: fs,
  428. }
  429. _, err = c.fs.ResolvePath("../upper_path")
  430. assert.Error(t, err, "tested path is not a home subdir")
  431. }
  432. func TestSFTPCmdTargetPath(t *testing.T) {
  433. u := dataprovider.User{}
  434. if runtime.GOOS == osWindows {
  435. u.HomeDir = "C:\\invalid_home"
  436. } else {
  437. u.HomeDir = "/invalid_home"
  438. }
  439. u.Username = "testuser"
  440. u.Permissions = make(map[string][]string)
  441. u.Permissions["/"] = []string{dataprovider.PermAny}
  442. fs, err := u.GetFilesystem("123")
  443. assert.NoError(t, err)
  444. connection := Connection{
  445. User: u,
  446. fs: fs,
  447. }
  448. _, err = connection.getSFTPCmdTargetPath("invalid_path")
  449. assert.EqualError(t, err, sftp.ErrSSHFxNoSuchFile.Error())
  450. }
  451. func TestGetSFTPErrorFromOSError(t *testing.T) {
  452. err := os.ErrNotExist
  453. fs := vfs.NewOsFs("", os.TempDir(), nil)
  454. err = vfs.GetSFTPError(fs, err)
  455. assert.EqualError(t, err, sftp.ErrSSHFxNoSuchFile.Error())
  456. err = os.ErrPermission
  457. err = vfs.GetSFTPError(fs, err)
  458. assert.EqualError(t, err, sftp.ErrSSHFxPermissionDenied.Error())
  459. err = vfs.GetSFTPError(fs, nil)
  460. assert.NoError(t, err)
  461. }
  462. func TestSetstatModeIgnore(t *testing.T) {
  463. originalMode := setstatMode
  464. setstatMode = 1
  465. connection := Connection{}
  466. err := connection.handleSFTPSetstat("invalid", nil)
  467. assert.NoError(t, err)
  468. setstatMode = originalMode
  469. }
  470. func TestSFTPGetUsedQuota(t *testing.T) {
  471. u := dataprovider.User{}
  472. u.HomeDir = "home_rel_path"
  473. u.Username = "test_invalid_user"
  474. u.QuotaSize = 4096
  475. u.QuotaFiles = 1
  476. u.Permissions = make(map[string][]string)
  477. u.Permissions["/"] = []string{dataprovider.PermAny}
  478. connection := Connection{
  479. User: u,
  480. }
  481. assert.False(t, connection.hasSpace(false))
  482. }
  483. func TestSupportedSSHCommands(t *testing.T) {
  484. cmds := GetSupportedSSHCommands()
  485. assert.Equal(t, len(supportedSSHCommands), len(cmds))
  486. for _, c := range cmds {
  487. assert.True(t, utils.IsStringInSlice(c, supportedSSHCommands))
  488. }
  489. }
  490. func TestSSHCommandPath(t *testing.T) {
  491. buf := make([]byte, 65535)
  492. stdErrBuf := make([]byte, 65535)
  493. mockSSHChannel := MockChannel{
  494. Buffer: bytes.NewBuffer(buf),
  495. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  496. ReadError: nil,
  497. }
  498. connection := Connection{
  499. channel: &mockSSHChannel,
  500. }
  501. sshCommand := sshCommand{
  502. command: "test",
  503. connection: connection,
  504. args: []string{},
  505. }
  506. assert.Equal(t, "", sshCommand.getDestPath())
  507. sshCommand.args = []string{"-t", "/tmp/../path"}
  508. assert.Equal(t, "/path", sshCommand.getDestPath())
  509. sshCommand.args = []string{"-t", "/tmp/"}
  510. assert.Equal(t, "/tmp/", sshCommand.getDestPath())
  511. sshCommand.args = []string{"-t", "tmp/"}
  512. assert.Equal(t, "/tmp/", sshCommand.getDestPath())
  513. sshCommand.args = []string{"-t", "/tmp/../../../path"}
  514. assert.Equal(t, "/path", sshCommand.getDestPath())
  515. sshCommand.args = []string{"-t", ".."}
  516. assert.Equal(t, "/", sshCommand.getDestPath())
  517. sshCommand.args = []string{"-t", "."}
  518. assert.Equal(t, "/", sshCommand.getDestPath())
  519. sshCommand.args = []string{"-t", "//"}
  520. assert.Equal(t, "/", sshCommand.getDestPath())
  521. sshCommand.args = []string{"-t", "../.."}
  522. assert.Equal(t, "/", sshCommand.getDestPath())
  523. sshCommand.args = []string{"-t", "/.."}
  524. assert.Equal(t, "/", sshCommand.getDestPath())
  525. sshCommand.args = []string{"-f", "/a space.txt"}
  526. assert.Equal(t, "/a space.txt", sshCommand.getDestPath())
  527. }
  528. func TestSSHParseCommandPayload(t *testing.T) {
  529. cmd := "command -a -f /ab\\ à/some\\ spaces\\ \\ \\(\\).txt"
  530. name, args, _ := parseCommandPayload(cmd)
  531. assert.Equal(t, "command", name)
  532. assert.Equal(t, 3, len(args))
  533. assert.Equal(t, "/ab à/some spaces ().txt", args[2])
  534. _, _, err := parseCommandPayload("")
  535. assert.Error(t, err, "parsing invalid command must fail")
  536. }
  537. func TestSSHCommandErrors(t *testing.T) {
  538. buf := make([]byte, 65535)
  539. stdErrBuf := make([]byte, 65535)
  540. readErr := fmt.Errorf("test read error")
  541. mockSSHChannel := MockChannel{
  542. Buffer: bytes.NewBuffer(buf),
  543. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  544. ReadError: readErr,
  545. }
  546. server, client := net.Pipe()
  547. defer func() {
  548. err := server.Close()
  549. assert.NoError(t, err)
  550. }()
  551. defer func() {
  552. err := client.Close()
  553. assert.NoError(t, err)
  554. }()
  555. user := dataprovider.User{}
  556. user.Permissions = make(map[string][]string)
  557. user.Permissions["/"] = []string{dataprovider.PermAny}
  558. fs, err := user.GetFilesystem("123")
  559. assert.NoError(t, err)
  560. connection := Connection{
  561. channel: &mockSSHChannel,
  562. netConn: client,
  563. User: user,
  564. fs: fs,
  565. }
  566. cmd := sshCommand{
  567. command: "md5sum",
  568. connection: connection,
  569. args: []string{},
  570. }
  571. err = cmd.handle()
  572. assert.Error(t, err, "ssh command must fail, we are sending a fake error")
  573. cmd = sshCommand{
  574. command: "md5sum",
  575. connection: connection,
  576. args: []string{"/../../test_file.dat"},
  577. }
  578. err = cmd.handle()
  579. assert.Error(t, err, "ssh command must fail, we are requesting an invalid path")
  580. cmd = sshCommand{
  581. command: "git-receive-pack",
  582. connection: connection,
  583. args: []string{"/../../testrepo"},
  584. }
  585. err = cmd.handle()
  586. assert.Error(t, err, "ssh command must fail, we are requesting an invalid path")
  587. cmd.connection.User.HomeDir = os.TempDir()
  588. cmd.connection.User.QuotaFiles = 1
  589. cmd.connection.User.UsedQuotaFiles = 2
  590. fs, err = cmd.connection.User.GetFilesystem("123")
  591. assert.NoError(t, err)
  592. cmd.connection.fs = fs
  593. err = cmd.handle()
  594. assert.EqualError(t, err, errQuotaExceeded.Error())
  595. cmd.connection.User.QuotaFiles = 0
  596. cmd.connection.User.UsedQuotaFiles = 0
  597. cmd.connection.User.Permissions = make(map[string][]string)
  598. cmd.connection.User.Permissions["/"] = []string{dataprovider.PermListItems}
  599. err = cmd.handle()
  600. assert.EqualError(t, err, errPermissionDenied.Error())
  601. cmd.connection.User.Permissions["/"] = []string{dataprovider.PermAny}
  602. cmd.command = "invalid_command"
  603. command, err := cmd.getSystemCommand()
  604. assert.NoError(t, err)
  605. err = cmd.executeSystemCommand(command)
  606. assert.Error(t, err, "invalid command must fail")
  607. command, err = cmd.getSystemCommand()
  608. assert.NoError(t, err)
  609. _, err = command.cmd.StderrPipe()
  610. assert.NoError(t, err)
  611. err = cmd.executeSystemCommand(command)
  612. assert.Error(t, err, "command must fail, pipe was already assigned")
  613. err = cmd.executeSystemCommand(command)
  614. assert.Error(t, err, "command must fail, pipe was already assigned")
  615. command, err = cmd.getSystemCommand()
  616. assert.NoError(t, err)
  617. _, err = command.cmd.StdoutPipe()
  618. assert.NoError(t, err)
  619. err = cmd.executeSystemCommand(command)
  620. assert.Error(t, err, "command must fail, pipe was already assigned")
  621. }
  622. func TestCommandsWithExtensionsFilter(t *testing.T) {
  623. buf := make([]byte, 65535)
  624. stdErrBuf := make([]byte, 65535)
  625. mockSSHChannel := MockChannel{
  626. Buffer: bytes.NewBuffer(buf),
  627. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  628. }
  629. server, client := net.Pipe()
  630. defer server.Close()
  631. defer client.Close()
  632. user := dataprovider.User{
  633. Username: "test",
  634. HomeDir: os.TempDir(),
  635. Status: 1,
  636. }
  637. user.Filters.FileExtensions = []dataprovider.ExtensionsFilter{
  638. {
  639. Path: "/subdir",
  640. AllowedExtensions: []string{".jpg"},
  641. DeniedExtensions: []string{},
  642. },
  643. }
  644. fs, err := user.GetFilesystem("123")
  645. assert.NoError(t, err)
  646. connection := Connection{
  647. channel: &mockSSHChannel,
  648. netConn: client,
  649. User: user,
  650. fs: fs,
  651. }
  652. cmd := sshCommand{
  653. command: "md5sum",
  654. connection: connection,
  655. args: []string{"subdir/test.png"},
  656. }
  657. err = cmd.handleHashCommands()
  658. assert.EqualError(t, err, errPermissionDenied.Error())
  659. cmd = sshCommand{
  660. command: "rsync",
  661. connection: connection,
  662. args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"},
  663. }
  664. _, err = cmd.getSystemCommand()
  665. assert.EqualError(t, err, errUnsupportedConfig.Error())
  666. cmd = sshCommand{
  667. command: "git-receive-pack",
  668. connection: connection,
  669. args: []string{"/subdir"},
  670. }
  671. _, err = cmd.getSystemCommand()
  672. assert.EqualError(t, err, errUnsupportedConfig.Error())
  673. cmd = sshCommand{
  674. command: "git-receive-pack",
  675. connection: connection,
  676. args: []string{"/subdir/dir"},
  677. }
  678. _, err = cmd.getSystemCommand()
  679. assert.EqualError(t, err, errUnsupportedConfig.Error())
  680. cmd = sshCommand{
  681. command: "git-receive-pack",
  682. connection: connection,
  683. args: []string{"/adir/subdir"},
  684. }
  685. _, err = cmd.getSystemCommand()
  686. assert.NoError(t, err)
  687. }
  688. func TestSSHCommandsRemoteFs(t *testing.T) {
  689. buf := make([]byte, 65535)
  690. stdErrBuf := make([]byte, 65535)
  691. mockSSHChannel := MockChannel{
  692. Buffer: bytes.NewBuffer(buf),
  693. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  694. }
  695. server, client := net.Pipe()
  696. defer func() {
  697. err := server.Close()
  698. assert.NoError(t, err)
  699. }()
  700. defer func() {
  701. err := client.Close()
  702. assert.NoError(t, err)
  703. }()
  704. user := dataprovider.User{}
  705. user.FsConfig = dataprovider.Filesystem{
  706. Provider: 1,
  707. S3Config: vfs.S3FsConfig{
  708. Bucket: "s3bucket",
  709. Endpoint: "endpoint",
  710. Region: "eu-west-1",
  711. },
  712. }
  713. fs, err := user.GetFilesystem("123")
  714. assert.NoError(t, err)
  715. connection := Connection{
  716. channel: &mockSSHChannel,
  717. netConn: client,
  718. User: user,
  719. fs: fs,
  720. }
  721. cmd := sshCommand{
  722. command: "md5sum",
  723. connection: connection,
  724. args: []string{},
  725. }
  726. err = cmd.handleHashCommands()
  727. assert.Error(t, err, "command must fail for a non local filesystem")
  728. command, err := cmd.getSystemCommand()
  729. assert.NoError(t, err)
  730. err = cmd.executeSystemCommand(command)
  731. assert.Error(t, err, "command must fail for a non local filesystem")
  732. }
  733. func TestSSHCommandQuotaScan(t *testing.T) {
  734. buf := make([]byte, 65535)
  735. stdErrBuf := make([]byte, 65535)
  736. readErr := fmt.Errorf("test read error")
  737. mockSSHChannel := MockChannel{
  738. Buffer: bytes.NewBuffer(buf),
  739. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  740. ReadError: readErr,
  741. }
  742. server, client := net.Pipe()
  743. defer func() {
  744. err := server.Close()
  745. assert.NoError(t, err)
  746. }()
  747. defer func() {
  748. err := client.Close()
  749. assert.NoError(t, err)
  750. }()
  751. permissions := make(map[string][]string)
  752. permissions["/"] = []string{dataprovider.PermAny}
  753. user := dataprovider.User{
  754. Permissions: permissions,
  755. QuotaFiles: 1,
  756. HomeDir: "invalid_path",
  757. }
  758. fs, err := user.GetFilesystem("123")
  759. assert.NoError(t, err)
  760. connection := Connection{
  761. channel: &mockSSHChannel,
  762. netConn: client,
  763. User: user,
  764. fs: fs,
  765. }
  766. cmd := sshCommand{
  767. command: "git-receive-pack",
  768. connection: connection,
  769. args: []string{"/testrepo"},
  770. }
  771. err = cmd.rescanHomeDir()
  772. assert.Error(t, err, "scanning an invalid home dir must fail")
  773. }
  774. func TestGitVirtualFolders(t *testing.T) {
  775. permissions := make(map[string][]string)
  776. permissions["/"] = []string{dataprovider.PermAny}
  777. user := dataprovider.User{
  778. Permissions: permissions,
  779. HomeDir: os.TempDir(),
  780. }
  781. fs, err := user.GetFilesystem("123")
  782. assert.NoError(t, err)
  783. conn := Connection{
  784. User: user,
  785. fs: fs,
  786. }
  787. cmd := sshCommand{
  788. command: "git-receive-pack",
  789. connection: conn,
  790. args: []string{"/vdir"},
  791. }
  792. cmd.connection.User.VirtualFolders = append(cmd.connection.User.VirtualFolders, vfs.VirtualFolder{
  793. VirtualPath: "/vdir",
  794. MappedPath: os.TempDir(),
  795. })
  796. _, err = cmd.getSystemCommand()
  797. assert.EqualError(t, err, errUnsupportedConfig.Error())
  798. cmd.connection.User.VirtualFolders = nil
  799. cmd.connection.User.VirtualFolders = append(cmd.connection.User.VirtualFolders, vfs.VirtualFolder{
  800. VirtualPath: "/vdir",
  801. MappedPath: os.TempDir(),
  802. })
  803. cmd.args = []string{"/vdir/subdir"}
  804. _, err = cmd.getSystemCommand()
  805. assert.EqualError(t, err, errUnsupportedConfig.Error())
  806. cmd.args = []string{"/adir/subdir"}
  807. _, err = cmd.getSystemCommand()
  808. assert.NoError(t, err)
  809. }
  810. func TestRsyncOptions(t *testing.T) {
  811. permissions := make(map[string][]string)
  812. permissions["/"] = []string{dataprovider.PermAny}
  813. user := dataprovider.User{
  814. Permissions: permissions,
  815. HomeDir: os.TempDir(),
  816. }
  817. fs, err := user.GetFilesystem("123")
  818. assert.NoError(t, err)
  819. conn := Connection{
  820. User: user,
  821. fs: fs,
  822. }
  823. sshCmd := sshCommand{
  824. command: "rsync",
  825. connection: conn,
  826. args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"},
  827. }
  828. cmd, err := sshCmd.getSystemCommand()
  829. assert.NoError(t, err)
  830. assert.True(t, utils.IsStringInSlice("--safe-links", cmd.cmd.Args),
  831. "--safe-links must be added if the user has the create symlinks permission")
  832. permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs,
  833. dataprovider.PermListItems, dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermRename}
  834. user.Permissions = permissions
  835. fs, err = user.GetFilesystem("123")
  836. assert.NoError(t, err)
  837. conn = Connection{
  838. User: user,
  839. fs: fs,
  840. }
  841. sshCmd = sshCommand{
  842. command: "rsync",
  843. connection: conn,
  844. args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"},
  845. }
  846. cmd, err = sshCmd.getSystemCommand()
  847. assert.NoError(t, err)
  848. assert.True(t, utils.IsStringInSlice("--munge-links", cmd.cmd.Args),
  849. "--munge-links must be added if the user has the create symlinks permission")
  850. sshCmd.connection.User.VirtualFolders = append(sshCmd.connection.User.VirtualFolders, vfs.VirtualFolder{
  851. VirtualPath: "/vdir",
  852. MappedPath: os.TempDir(),
  853. })
  854. _, err = sshCmd.getSystemCommand()
  855. assert.EqualError(t, err, errUnsupportedConfig.Error())
  856. }
  857. func TestSystemCommandErrors(t *testing.T) {
  858. buf := make([]byte, 65535)
  859. stdErrBuf := make([]byte, 65535)
  860. readErr := fmt.Errorf("test read error")
  861. writeErr := fmt.Errorf("test write error")
  862. mockSSHChannel := MockChannel{
  863. Buffer: bytes.NewBuffer(buf),
  864. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  865. ReadError: nil,
  866. WriteError: writeErr,
  867. }
  868. server, client := net.Pipe()
  869. defer func() {
  870. err := server.Close()
  871. assert.NoError(t, err)
  872. }()
  873. defer func() {
  874. err := client.Close()
  875. assert.NoError(t, err)
  876. }()
  877. permissions := make(map[string][]string)
  878. permissions["/"] = []string{dataprovider.PermAny}
  879. user := dataprovider.User{
  880. Permissions: permissions,
  881. HomeDir: os.TempDir(),
  882. }
  883. fs, err := user.GetFilesystem("123")
  884. assert.NoError(t, err)
  885. connection := Connection{
  886. channel: &mockSSHChannel,
  887. netConn: client,
  888. User: user,
  889. fs: fs,
  890. }
  891. sshCmd := sshCommand{
  892. command: "ls",
  893. connection: connection,
  894. args: []string{"/"},
  895. }
  896. systemCmd, err := sshCmd.getSystemCommand()
  897. assert.NoError(t, err)
  898. systemCmd.cmd.Dir = os.TempDir()
  899. // FIXME: the command completes but the fake client was unable to read the response
  900. // no error is reported in this case. We can see that the expected code is executed
  901. // reading the test coverage
  902. sshCmd.executeSystemCommand(systemCmd) //nolint:errcheck
  903. mockSSHChannel = MockChannel{
  904. Buffer: bytes.NewBuffer(buf),
  905. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  906. ReadError: readErr,
  907. WriteError: nil,
  908. }
  909. sshCmd.connection.channel = &mockSSHChannel
  910. transfer := Transfer{
  911. transferType: transferDownload,
  912. lock: new(sync.Mutex)}
  913. destBuff := make([]byte, 65535)
  914. dst := bytes.NewBuffer(destBuff)
  915. _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel, 0)
  916. assert.EqualError(t, err, readErr.Error())
  917. mockSSHChannel = MockChannel{
  918. Buffer: bytes.NewBuffer(buf),
  919. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  920. ReadError: nil,
  921. WriteError: nil,
  922. }
  923. sshCmd.connection.channel = &mockSSHChannel
  924. _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel, 1)
  925. assert.EqualError(t, err, errQuotaExceeded.Error())
  926. mockSSHChannel = MockChannel{
  927. Buffer: bytes.NewBuffer(buf),
  928. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  929. ReadError: nil,
  930. WriteError: nil,
  931. ShortWriteErr: true,
  932. }
  933. sshCmd.connection.channel = &mockSSHChannel
  934. _, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst, 0)
  935. assert.EqualError(t, err, io.ErrShortWrite.Error())
  936. }
  937. func TestTransferUpdateQuota(t *testing.T) {
  938. transfer := Transfer{
  939. transferType: transferUpload,
  940. bytesReceived: 123,
  941. lock: new(sync.Mutex)}
  942. transfer.TransferError(errors.New("fake error"))
  943. assert.False(t, transfer.updateQuota(1))
  944. }
  945. func TestGetConnectionInfo(t *testing.T) {
  946. c := ConnectionStatus{
  947. Username: "test_user",
  948. ConnectionID: "123",
  949. ClientVersion: "client",
  950. RemoteAddress: "127.0.0.1:1234",
  951. Protocol: protocolSSH,
  952. SSHCommand: "sha1sum /test_file.dat",
  953. }
  954. info := c.GetConnectionInfo()
  955. assert.Contains(t, info, "sha1sum /test_file.dat")
  956. }
  957. func TestSCPFileMode(t *testing.T) {
  958. mode := getFileModeAsString(0, true)
  959. assert.Equal(t, "0755", mode)
  960. mode = getFileModeAsString(0700, true)
  961. assert.Equal(t, "0700", mode)
  962. mode = getFileModeAsString(0750, true)
  963. assert.Equal(t, "0750", mode)
  964. mode = getFileModeAsString(0777, true)
  965. assert.Equal(t, "0777", mode)
  966. mode = getFileModeAsString(0640, false)
  967. assert.Equal(t, "0640", mode)
  968. mode = getFileModeAsString(0600, false)
  969. assert.Equal(t, "0600", mode)
  970. mode = getFileModeAsString(0, false)
  971. assert.Equal(t, "0644", mode)
  972. fileMode := uint32(0777)
  973. fileMode = fileMode | uint32(os.ModeSetgid)
  974. fileMode = fileMode | uint32(os.ModeSetuid)
  975. fileMode = fileMode | uint32(os.ModeSticky)
  976. mode = getFileModeAsString(os.FileMode(fileMode), false)
  977. assert.Equal(t, "7777", mode)
  978. fileMode = uint32(0644)
  979. fileMode = fileMode | uint32(os.ModeSetgid)
  980. mode = getFileModeAsString(os.FileMode(fileMode), false)
  981. assert.Equal(t, "4644", mode)
  982. fileMode = uint32(0600)
  983. fileMode = fileMode | uint32(os.ModeSetuid)
  984. mode = getFileModeAsString(os.FileMode(fileMode), false)
  985. assert.Equal(t, "2600", mode)
  986. fileMode = uint32(0044)
  987. fileMode = fileMode | uint32(os.ModeSticky)
  988. mode = getFileModeAsString(os.FileMode(fileMode), false)
  989. assert.Equal(t, "1044", mode)
  990. }
  991. func TestSCPParseUploadMessage(t *testing.T) {
  992. buf := make([]byte, 65535)
  993. stdErrBuf := make([]byte, 65535)
  994. mockSSHChannel := MockChannel{
  995. Buffer: bytes.NewBuffer(buf),
  996. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  997. ReadError: nil,
  998. }
  999. connection := Connection{
  1000. channel: &mockSSHChannel,
  1001. fs: vfs.NewOsFs("", os.TempDir(), nil),
  1002. }
  1003. scpCommand := scpCommand{
  1004. sshCommand: sshCommand{
  1005. command: "scp",
  1006. connection: connection,
  1007. args: []string{"-t", "/tmp"},
  1008. },
  1009. }
  1010. _, _, err := scpCommand.parseUploadMessage("invalid")
  1011. assert.Error(t, err, "parsing invalid upload message must fail")
  1012. _, _, err = scpCommand.parseUploadMessage("D0755 0")
  1013. assert.Error(t, err, "parsing incomplete upload message must fail")
  1014. _, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
  1015. assert.Error(t, err, "parsing upload message with invalid size must fail")
  1016. _, _, err = scpCommand.parseUploadMessage("D0755 0 ")
  1017. assert.Error(t, err, "parsing upload message with invalid name must fail")
  1018. }
  1019. func TestSCPProtocolMessages(t *testing.T) {
  1020. buf := make([]byte, 65535)
  1021. stdErrBuf := make([]byte, 65535)
  1022. readErr := fmt.Errorf("test read error")
  1023. writeErr := fmt.Errorf("test write error")
  1024. mockSSHChannel := MockChannel{
  1025. Buffer: bytes.NewBuffer(buf),
  1026. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1027. ReadError: readErr,
  1028. WriteError: writeErr,
  1029. }
  1030. connection := Connection{
  1031. channel: &mockSSHChannel,
  1032. }
  1033. scpCommand := scpCommand{
  1034. sshCommand: sshCommand{
  1035. command: "scp",
  1036. connection: connection,
  1037. args: []string{"-t", "/tmp"},
  1038. },
  1039. }
  1040. _, err := scpCommand.readProtocolMessage()
  1041. assert.EqualError(t, err, readErr.Error())
  1042. err = scpCommand.sendConfirmationMessage()
  1043. assert.EqualError(t, err, writeErr.Error())
  1044. err = scpCommand.sendProtocolMessage("E\n")
  1045. assert.EqualError(t, err, writeErr.Error())
  1046. _, err = scpCommand.getNextUploadProtocolMessage()
  1047. assert.EqualError(t, err, readErr.Error())
  1048. mockSSHChannel = MockChannel{
  1049. Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
  1050. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1051. ReadError: nil,
  1052. WriteError: writeErr,
  1053. }
  1054. scpCommand.connection.channel = &mockSSHChannel
  1055. _, err = scpCommand.getNextUploadProtocolMessage()
  1056. assert.EqualError(t, err, writeErr.Error())
  1057. respBuffer := []byte{0x02}
  1058. protocolErrorMsg := "protocol error msg"
  1059. respBuffer = append(respBuffer, protocolErrorMsg...)
  1060. respBuffer = append(respBuffer, 0x0A)
  1061. mockSSHChannel = MockChannel{
  1062. Buffer: bytes.NewBuffer(respBuffer),
  1063. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1064. ReadError: nil,
  1065. WriteError: nil,
  1066. }
  1067. scpCommand.connection.channel = &mockSSHChannel
  1068. err = scpCommand.readConfirmationMessage()
  1069. if assert.Error(t, err) {
  1070. assert.Equal(t, protocolErrorMsg, err.Error())
  1071. }
  1072. }
  1073. func TestSCPTestDownloadProtocolMessages(t *testing.T) {
  1074. buf := make([]byte, 65535)
  1075. stdErrBuf := make([]byte, 65535)
  1076. readErr := fmt.Errorf("test read error")
  1077. writeErr := fmt.Errorf("test write error")
  1078. mockSSHChannel := MockChannel{
  1079. Buffer: bytes.NewBuffer(buf),
  1080. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1081. ReadError: readErr,
  1082. WriteError: writeErr,
  1083. }
  1084. connection := Connection{
  1085. channel: &mockSSHChannel,
  1086. }
  1087. scpCommand := scpCommand{
  1088. sshCommand: sshCommand{
  1089. command: "scp",
  1090. connection: connection,
  1091. args: []string{"-f", "-p", "/tmp"},
  1092. },
  1093. }
  1094. path := "testDir"
  1095. err := os.Mkdir(path, 0777)
  1096. assert.NoError(t, err)
  1097. stat, err := os.Stat(path)
  1098. assert.NoError(t, err)
  1099. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  1100. assert.EqualError(t, err, writeErr.Error())
  1101. mockSSHChannel = MockChannel{
  1102. Buffer: bytes.NewBuffer(buf),
  1103. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1104. ReadError: readErr,
  1105. WriteError: nil,
  1106. }
  1107. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  1108. assert.EqualError(t, err, readErr.Error())
  1109. mockSSHChannel = MockChannel{
  1110. Buffer: bytes.NewBuffer(buf),
  1111. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1112. ReadError: readErr,
  1113. WriteError: writeErr,
  1114. }
  1115. scpCommand.args = []string{"-f", "/tmp"}
  1116. scpCommand.connection.channel = &mockSSHChannel
  1117. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  1118. assert.EqualError(t, err, writeErr.Error())
  1119. mockSSHChannel = MockChannel{
  1120. Buffer: bytes.NewBuffer(buf),
  1121. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1122. ReadError: readErr,
  1123. WriteError: nil,
  1124. }
  1125. scpCommand.connection.channel = &mockSSHChannel
  1126. err = scpCommand.sendDownloadProtocolMessages(path, stat)
  1127. assert.EqualError(t, err, readErr.Error())
  1128. err = os.Remove(path)
  1129. assert.NoError(t, err)
  1130. }
  1131. func TestSCPCommandHandleErrors(t *testing.T) {
  1132. buf := make([]byte, 65535)
  1133. stdErrBuf := make([]byte, 65535)
  1134. readErr := fmt.Errorf("test read error")
  1135. writeErr := fmt.Errorf("test write error")
  1136. mockSSHChannel := MockChannel{
  1137. Buffer: bytes.NewBuffer(buf),
  1138. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1139. ReadError: readErr,
  1140. WriteError: writeErr,
  1141. }
  1142. server, client := net.Pipe()
  1143. defer func() {
  1144. err := server.Close()
  1145. assert.NoError(t, err)
  1146. }()
  1147. defer func() {
  1148. err := client.Close()
  1149. assert.NoError(t, err)
  1150. }()
  1151. connection := Connection{
  1152. channel: &mockSSHChannel,
  1153. netConn: client,
  1154. }
  1155. scpCommand := scpCommand{
  1156. sshCommand: sshCommand{
  1157. command: "scp",
  1158. connection: connection,
  1159. args: []string{"-f", "/tmp"},
  1160. },
  1161. }
  1162. err := scpCommand.handle()
  1163. assert.EqualError(t, err, readErr.Error())
  1164. scpCommand.args = []string{"-i", "/tmp"}
  1165. err = scpCommand.handle()
  1166. assert.Error(t, err, "invalid scp command must fail")
  1167. }
  1168. func TestSCPErrorsMockFs(t *testing.T) {
  1169. errFake := errors.New("fake error")
  1170. fs := newMockOsFs(errFake, errFake, false, "1234", os.TempDir())
  1171. u := dataprovider.User{}
  1172. u.Username = "test"
  1173. u.Permissions = make(map[string][]string)
  1174. u.Permissions["/"] = []string{dataprovider.PermAny}
  1175. u.HomeDir = os.TempDir()
  1176. buf := make([]byte, 65535)
  1177. stdErrBuf := make([]byte, 65535)
  1178. mockSSHChannel := MockChannel{
  1179. Buffer: bytes.NewBuffer(buf),
  1180. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1181. }
  1182. server, client := net.Pipe()
  1183. defer func() {
  1184. err := server.Close()
  1185. assert.NoError(t, err)
  1186. }()
  1187. defer func() {
  1188. err := client.Close()
  1189. assert.NoError(t, err)
  1190. }()
  1191. connection := Connection{
  1192. channel: &mockSSHChannel,
  1193. netConn: client,
  1194. fs: fs,
  1195. User: u,
  1196. }
  1197. scpCommand := scpCommand{
  1198. sshCommand: sshCommand{
  1199. command: "scp",
  1200. connection: connection,
  1201. args: []string{"-r", "-t", "/tmp"},
  1202. },
  1203. }
  1204. err := scpCommand.handleUpload("test", 0)
  1205. assert.EqualError(t, err, errFake.Error())
  1206. testfile := filepath.Join(u.HomeDir, "testfile")
  1207. err = ioutil.WriteFile(testfile, []byte("test"), 0666)
  1208. assert.NoError(t, err)
  1209. stat, err := os.Stat(u.HomeDir)
  1210. assert.NoError(t, err)
  1211. err = scpCommand.handleRecursiveDownload(u.HomeDir, stat)
  1212. assert.EqualError(t, err, errFake.Error())
  1213. scpCommand.sshCommand.connection.fs = newMockOsFs(errFake, nil, true, "123", os.TempDir())
  1214. err = scpCommand.handleUpload(filepath.Base(testfile), 0)
  1215. assert.EqualError(t, err, errFake.Error())
  1216. err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4, false)
  1217. assert.NoError(t, err)
  1218. err = os.Remove(testfile)
  1219. assert.NoError(t, err)
  1220. }
  1221. func TestSCPRecursiveDownloadErrors(t *testing.T) {
  1222. buf := make([]byte, 65535)
  1223. stdErrBuf := make([]byte, 65535)
  1224. readErr := fmt.Errorf("test read error")
  1225. writeErr := fmt.Errorf("test write error")
  1226. mockSSHChannel := MockChannel{
  1227. Buffer: bytes.NewBuffer(buf),
  1228. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1229. ReadError: readErr,
  1230. WriteError: writeErr,
  1231. }
  1232. server, client := net.Pipe()
  1233. defer func() {
  1234. err := server.Close()
  1235. assert.NoError(t, err)
  1236. }()
  1237. defer func() {
  1238. err := client.Close()
  1239. assert.NoError(t, err)
  1240. }()
  1241. connection := Connection{
  1242. channel: &mockSSHChannel,
  1243. netConn: client,
  1244. fs: vfs.NewOsFs("123", os.TempDir(), nil),
  1245. }
  1246. scpCommand := scpCommand{
  1247. sshCommand: sshCommand{
  1248. command: "scp",
  1249. connection: connection,
  1250. args: []string{"-r", "-f", "/tmp"},
  1251. },
  1252. }
  1253. path := "testDir"
  1254. err := os.Mkdir(path, 0777)
  1255. assert.NoError(t, err)
  1256. stat, err := os.Stat(path)
  1257. assert.NoError(t, err)
  1258. err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
  1259. assert.EqualError(t, err, writeErr.Error())
  1260. mockSSHChannel = MockChannel{
  1261. Buffer: bytes.NewBuffer(buf),
  1262. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1263. ReadError: nil,
  1264. WriteError: nil,
  1265. }
  1266. scpCommand.connection.channel = &mockSSHChannel
  1267. err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
  1268. assert.Error(t, err, "recursive upload download must fail for a non existing dir")
  1269. err = os.Remove(path)
  1270. assert.NoError(t, err)
  1271. }
  1272. func TestSCPRecursiveUploadErrors(t *testing.T) {
  1273. buf := make([]byte, 65535)
  1274. stdErrBuf := make([]byte, 65535)
  1275. readErr := fmt.Errorf("test read error")
  1276. writeErr := fmt.Errorf("test write error")
  1277. mockSSHChannel := MockChannel{
  1278. Buffer: bytes.NewBuffer(buf),
  1279. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1280. ReadError: readErr,
  1281. WriteError: writeErr,
  1282. }
  1283. connection := Connection{
  1284. channel: &mockSSHChannel,
  1285. }
  1286. scpCommand := scpCommand{
  1287. sshCommand: sshCommand{
  1288. command: "scp",
  1289. connection: connection,
  1290. args: []string{"-r", "-t", "/tmp"},
  1291. },
  1292. }
  1293. err := scpCommand.handleRecursiveUpload()
  1294. assert.Error(t, err, "recursive upload must fail, we send a fake error message")
  1295. mockSSHChannel = MockChannel{
  1296. Buffer: bytes.NewBuffer(buf),
  1297. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1298. ReadError: readErr,
  1299. WriteError: nil,
  1300. }
  1301. scpCommand.connection.channel = &mockSSHChannel
  1302. err = scpCommand.handleRecursiveUpload()
  1303. assert.Error(t, err, "recursive upload must fail, we send a fake error message")
  1304. }
  1305. func TestSCPCreateDirs(t *testing.T) {
  1306. buf := make([]byte, 65535)
  1307. stdErrBuf := make([]byte, 65535)
  1308. u := dataprovider.User{}
  1309. u.HomeDir = "home_rel_path"
  1310. u.Username = "test"
  1311. u.Permissions = make(map[string][]string)
  1312. u.Permissions["/"] = []string{dataprovider.PermAny}
  1313. mockSSHChannel := MockChannel{
  1314. Buffer: bytes.NewBuffer(buf),
  1315. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1316. ReadError: nil,
  1317. WriteError: nil,
  1318. }
  1319. fs, err := u.GetFilesystem("123")
  1320. assert.NoError(t, err)
  1321. connection := Connection{
  1322. User: u,
  1323. channel: &mockSSHChannel,
  1324. fs: fs,
  1325. }
  1326. scpCommand := scpCommand{
  1327. sshCommand: sshCommand{
  1328. command: "scp",
  1329. connection: connection,
  1330. args: []string{"-r", "-t", "/tmp"},
  1331. },
  1332. }
  1333. err = scpCommand.handleCreateDir("invalid_dir")
  1334. assert.Error(t, err, "create invalid dir must fail")
  1335. }
  1336. func TestSCPDownloadFileData(t *testing.T) {
  1337. testfile := "testfile"
  1338. buf := make([]byte, 65535)
  1339. readErr := fmt.Errorf("test read error")
  1340. writeErr := fmt.Errorf("test write error")
  1341. stdErrBuf := make([]byte, 65535)
  1342. mockSSHChannelReadErr := MockChannel{
  1343. Buffer: bytes.NewBuffer(buf),
  1344. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1345. ReadError: readErr,
  1346. WriteError: nil,
  1347. }
  1348. mockSSHChannelWriteErr := MockChannel{
  1349. Buffer: bytes.NewBuffer(buf),
  1350. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1351. ReadError: nil,
  1352. WriteError: writeErr,
  1353. }
  1354. connection := Connection{
  1355. channel: &mockSSHChannelReadErr,
  1356. }
  1357. scpCommand := scpCommand{
  1358. sshCommand: sshCommand{
  1359. command: "scp",
  1360. connection: connection,
  1361. args: []string{"-r", "-f", "/tmp"},
  1362. },
  1363. }
  1364. err := ioutil.WriteFile(testfile, []byte("test"), 0666)
  1365. assert.NoError(t, err)
  1366. stat, err := os.Stat(testfile)
  1367. assert.NoError(t, err)
  1368. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  1369. assert.EqualError(t, err, readErr.Error())
  1370. scpCommand.connection.channel = &mockSSHChannelWriteErr
  1371. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  1372. assert.EqualError(t, err, writeErr.Error())
  1373. scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
  1374. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  1375. assert.EqualError(t, err, writeErr.Error())
  1376. scpCommand.connection.channel = &mockSSHChannelReadErr
  1377. err = scpCommand.sendDownloadFileData(testfile, stat, nil)
  1378. assert.EqualError(t, err, readErr.Error())
  1379. err = os.Remove(testfile)
  1380. assert.NoError(t, err)
  1381. }
  1382. func TestSCPUploadFiledata(t *testing.T) {
  1383. testfile := "testfile"
  1384. buf := make([]byte, 65535)
  1385. stdErrBuf := make([]byte, 65535)
  1386. readErr := fmt.Errorf("test read error")
  1387. writeErr := fmt.Errorf("test write error")
  1388. mockSSHChannel := MockChannel{
  1389. Buffer: bytes.NewBuffer(buf),
  1390. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1391. ReadError: readErr,
  1392. WriteError: writeErr,
  1393. }
  1394. connection := Connection{
  1395. User: dataprovider.User{
  1396. Username: "testuser",
  1397. },
  1398. protocol: protocolSCP,
  1399. channel: &mockSSHChannel,
  1400. fs: vfs.NewOsFs("", os.TempDir(), nil),
  1401. }
  1402. scpCommand := scpCommand{
  1403. sshCommand: sshCommand{
  1404. command: "scp",
  1405. connection: connection,
  1406. args: []string{"-r", "-t", "/tmp"},
  1407. },
  1408. }
  1409. file, err := os.Create(testfile)
  1410. assert.NoError(t, err)
  1411. transfer := Transfer{
  1412. file: file,
  1413. path: file.Name(),
  1414. start: time.Now(),
  1415. bytesSent: 0,
  1416. bytesReceived: 0,
  1417. user: scpCommand.connection.User,
  1418. connectionID: "",
  1419. transferType: transferDownload,
  1420. lastActivity: time.Now(),
  1421. isNewFile: true,
  1422. protocol: connection.protocol,
  1423. transferError: nil,
  1424. isFinished: false,
  1425. minWriteOffset: 0,
  1426. lock: new(sync.Mutex),
  1427. }
  1428. addTransfer(&transfer)
  1429. err = scpCommand.getUploadFileData(2, &transfer)
  1430. assert.Error(t, err, "upload must fail, we send a fake write error message")
  1431. mockSSHChannel = MockChannel{
  1432. Buffer: bytes.NewBuffer(buf),
  1433. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1434. ReadError: readErr,
  1435. WriteError: nil,
  1436. }
  1437. scpCommand.connection.channel = &mockSSHChannel
  1438. file, err = os.Create(testfile)
  1439. assert.NoError(t, err)
  1440. transfer.file = file
  1441. transfer.isFinished = false
  1442. addTransfer(&transfer)
  1443. err = scpCommand.getUploadFileData(2, &transfer)
  1444. assert.Error(t, err, "upload must fail, we send a fake read error message")
  1445. respBuffer := []byte("12")
  1446. respBuffer = append(respBuffer, 0x02)
  1447. mockSSHChannel = MockChannel{
  1448. Buffer: bytes.NewBuffer(respBuffer),
  1449. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1450. ReadError: nil,
  1451. WriteError: nil,
  1452. }
  1453. scpCommand.connection.channel = &mockSSHChannel
  1454. file, err = os.Create(testfile)
  1455. assert.NoError(t, err)
  1456. transfer.file = file
  1457. transfer.isFinished = false
  1458. addTransfer(&transfer)
  1459. err = scpCommand.getUploadFileData(2, &transfer)
  1460. assert.Error(t, err, "upload must fail, we have not enough data to read")
  1461. // the file is already closed so we have an error on trasfer closing
  1462. mockSSHChannel = MockChannel{
  1463. Buffer: bytes.NewBuffer(buf),
  1464. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1465. ReadError: nil,
  1466. WriteError: nil,
  1467. }
  1468. addTransfer(&transfer)
  1469. err = scpCommand.getUploadFileData(0, &transfer)
  1470. assert.EqualError(t, err, errTransferClosed.Error())
  1471. mockSSHChannel = MockChannel{
  1472. Buffer: bytes.NewBuffer(buf),
  1473. StdErrBuffer: bytes.NewBuffer(stdErrBuf),
  1474. ReadError: nil,
  1475. WriteError: nil,
  1476. }
  1477. addTransfer(&transfer)
  1478. err = scpCommand.getUploadFileData(2, &transfer)
  1479. assert.True(t, errors.Is(err, os.ErrClosed))
  1480. err = os.Remove(testfile)
  1481. assert.NoError(t, err)
  1482. }
  1483. func TestUploadError(t *testing.T) {
  1484. oldUploadMode := uploadMode
  1485. uploadMode = uploadModeAtomic
  1486. connection := Connection{
  1487. User: dataprovider.User{
  1488. Username: "testuser",
  1489. },
  1490. protocol: protocolSCP,
  1491. }
  1492. testfile := "testfile"
  1493. fileTempName := "temptestfile"
  1494. file, err := os.Create(fileTempName)
  1495. assert.NoError(t, err)
  1496. transfer := Transfer{
  1497. file: file,
  1498. path: testfile,
  1499. start: time.Now(),
  1500. bytesSent: 0,
  1501. bytesReceived: 100,
  1502. user: connection.User,
  1503. connectionID: "",
  1504. transferType: transferUpload,
  1505. lastActivity: time.Now(),
  1506. isNewFile: true,
  1507. protocol: connection.protocol,
  1508. transferError: nil,
  1509. isFinished: false,
  1510. minWriteOffset: 0,
  1511. lock: new(sync.Mutex),
  1512. }
  1513. addTransfer(&transfer)
  1514. errFake := errors.New("fake error")
  1515. transfer.TransferError(errFake)
  1516. err = transfer.Close()
  1517. assert.EqualError(t, err, errFake.Error())
  1518. assert.Equal(t, int64(0), transfer.bytesReceived)
  1519. assert.NoFileExists(t, testfile)
  1520. assert.NoFileExists(t, fileTempName)
  1521. uploadMode = oldUploadMode
  1522. }
  1523. func TestConnectionStatusStruct(t *testing.T) {
  1524. var transfers []connectionTransfer
  1525. transferUL := connectionTransfer{
  1526. OperationType: operationUpload,
  1527. StartTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1528. Size: 123,
  1529. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1530. Path: "/test.upload",
  1531. }
  1532. transferDL := connectionTransfer{
  1533. OperationType: operationDownload,
  1534. StartTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1535. Size: 123,
  1536. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1537. Path: "/test.download",
  1538. }
  1539. transfers = append(transfers, transferUL)
  1540. transfers = append(transfers, transferDL)
  1541. c := ConnectionStatus{
  1542. Username: "test",
  1543. ConnectionID: "123",
  1544. ClientVersion: "fakeClient-1.0.0",
  1545. RemoteAddress: "127.0.0.1:1234",
  1546. ConnectionTime: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1547. LastActivity: utils.GetTimeAsMsSinceEpoch(time.Now()),
  1548. Protocol: "SFTP",
  1549. Transfers: transfers,
  1550. }
  1551. durationString := c.GetConnectionDuration()
  1552. assert.NotEqual(t, 0, len(durationString))
  1553. transfersString := c.GetTransfersAsString()
  1554. assert.NotEqual(t, 0, len(transfersString))
  1555. connInfo := c.GetConnectionInfo()
  1556. assert.NotEqual(t, 0, len(connInfo))
  1557. }
  1558. func TestProxyProtocolVersion(t *testing.T) {
  1559. c := Configuration{
  1560. ProxyProtocol: 1,
  1561. }
  1562. proxyListener, err := c.getProxyListener(nil)
  1563. assert.NoError(t, err)
  1564. assert.Nil(t, proxyListener.Policy)
  1565. c.ProxyProtocol = 2
  1566. proxyListener, _ = c.getProxyListener(nil)
  1567. assert.NoError(t, err)
  1568. assert.NotNil(t, proxyListener.Policy)
  1569. c.ProxyProtocol = 1
  1570. c.ProxyAllowed = []string{"invalid"}
  1571. _, err = c.getProxyListener(nil)
  1572. assert.Error(t, err)
  1573. c.ProxyProtocol = 2
  1574. _, err = c.getProxyListener(nil)
  1575. assert.Error(t, err)
  1576. }
  1577. func TestLoadHostKeys(t *testing.T) {
  1578. c := Configuration{}
  1579. c.HostKeys = []string{".", "missing file"}
  1580. err := c.checkAndLoadHostKeys("..", &ssh.ServerConfig{})
  1581. assert.Error(t, err)
  1582. testfile := filepath.Join(os.TempDir(), "invalidkey")
  1583. err = ioutil.WriteFile(testfile, []byte("some bytes"), 0666)
  1584. assert.NoError(t, err)
  1585. c.HostKeys = []string{testfile}
  1586. err = c.checkAndLoadHostKeys("..", &ssh.ServerConfig{})
  1587. assert.Error(t, err)
  1588. err = os.Remove(testfile)
  1589. assert.NoError(t, err)
  1590. }
  1591. func TestCertCheckerInitErrors(t *testing.T) {
  1592. c := Configuration{}
  1593. c.TrustedUserCAKeys = []string{".", "missing file"}
  1594. err := c.initializeCertChecker("")
  1595. assert.Error(t, err)
  1596. testfile := filepath.Join(os.TempDir(), "invalidkey")
  1597. err = ioutil.WriteFile(testfile, []byte("some bytes"), 0666)
  1598. assert.NoError(t, err)
  1599. c.TrustedUserCAKeys = []string{testfile}
  1600. err = c.initializeCertChecker("")
  1601. assert.Error(t, err)
  1602. err = os.Remove(testfile)
  1603. assert.NoError(t, err)
  1604. }