server_test.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. //go:build glidertests
  2. package ssh
  3. import (
  4. "bytes"
  5. "context"
  6. "io"
  7. "testing"
  8. "time"
  9. )
  10. func TestAddHostKey(t *testing.T) {
  11. s := Server{}
  12. signer, err := generateSigner()
  13. if err != nil {
  14. t.Fatal(err)
  15. }
  16. s.AddHostKey(signer)
  17. if len(s.HostSigners) != 1 {
  18. t.Fatal("Key was not properly added")
  19. }
  20. signer, err = generateSigner()
  21. if err != nil {
  22. t.Fatal(err)
  23. }
  24. s.AddHostKey(signer)
  25. if len(s.HostSigners) != 1 {
  26. t.Fatal("Key was not properly replaced")
  27. }
  28. }
  29. func TestServerShutdown(t *testing.T) {
  30. l := newLocalListener()
  31. testBytes := []byte("Hello world\n")
  32. s := &Server{
  33. Handler: func(s Session) {
  34. s.Write(testBytes)
  35. time.Sleep(50 * time.Millisecond)
  36. },
  37. }
  38. go func() {
  39. err := s.Serve(l)
  40. if err != nil && err != ErrServerClosed {
  41. t.Fatal(err)
  42. }
  43. }()
  44. sessDone := make(chan struct{})
  45. sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
  46. go func() {
  47. defer cleanup()
  48. defer close(sessDone)
  49. var stdout bytes.Buffer
  50. sess.Stdout = &stdout
  51. if err := sess.Run(""); err != nil {
  52. t.Fatal(err)
  53. }
  54. if !bytes.Equal(stdout.Bytes(), testBytes) {
  55. t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes())
  56. }
  57. }()
  58. srvDone := make(chan struct{})
  59. go func() {
  60. defer close(srvDone)
  61. err := s.Shutdown(context.Background())
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. }()
  66. timeout := time.After(2 * time.Second)
  67. select {
  68. case <-timeout:
  69. t.Fatal("timeout")
  70. return
  71. case <-srvDone:
  72. // TODO: add timeout for sessDone
  73. <-sessDone
  74. return
  75. }
  76. }
  77. func TestServerClose(t *testing.T) {
  78. l := newLocalListener()
  79. s := &Server{
  80. Handler: func(s Session) {
  81. time.Sleep(5 * time.Second)
  82. },
  83. }
  84. go func() {
  85. err := s.Serve(l)
  86. if err != nil && err != ErrServerClosed {
  87. t.Fatal(err)
  88. }
  89. }()
  90. clientDoneChan := make(chan struct{})
  91. closeDoneChan := make(chan struct{})
  92. sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
  93. go func() {
  94. defer cleanup()
  95. defer close(clientDoneChan)
  96. <-closeDoneChan
  97. if err := sess.Run(""); err != nil && err != io.EOF {
  98. t.Fatal(err)
  99. }
  100. }()
  101. go func() {
  102. err := s.Close()
  103. if err != nil {
  104. t.Fatal(err)
  105. }
  106. close(closeDoneChan)
  107. }()
  108. timeout := time.After(100 * time.Millisecond)
  109. select {
  110. case <-timeout:
  111. t.Error("timeout")
  112. return
  113. case <-s.getDoneChan():
  114. <-clientDoneChan
  115. return
  116. }
  117. }