1
0

oauth2_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "encoding/json"
  17. "testing"
  18. "time"
  19. "github.com/rs/xid"
  20. sdkkms "github.com/sftpgo/sdk/kms"
  21. "github.com/stretchr/testify/assert"
  22. "github.com/stretchr/testify/require"
  23. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  24. "github.com/drakkan/sftpgo/v2/internal/kms"
  25. "github.com/drakkan/sftpgo/v2/internal/util"
  26. )
  27. func TestMemoryOAuth2Manager(t *testing.T) {
  28. mgr := newOAuth2Manager(0)
  29. m, ok := mgr.(*memoryOAuth2Manager)
  30. require.True(t, ok)
  31. require.Len(t, m.pendingAuths, 0)
  32. _, err := m.getPendingAuth(xid.New().String())
  33. require.Error(t, err)
  34. assert.Contains(t, err.Error(), "no auth request found")
  35. auth := newOAuth2PendingAuth(1, "https://...", "cid", kms.NewPlainSecret("mysecret"))
  36. m.addPendingAuth(auth)
  37. require.Len(t, m.pendingAuths, 1)
  38. a, err := m.getPendingAuth(auth.State)
  39. assert.NoError(t, err)
  40. assert.Equal(t, auth.State, a.State)
  41. assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())
  42. m.removePendingAuth(auth.State)
  43. _, err = m.getPendingAuth(auth.State)
  44. require.Error(t, err)
  45. assert.Contains(t, err.Error(), "no auth request found")
  46. require.Len(t, m.pendingAuths, 0)
  47. state := xid.New().String()
  48. auth = oauth2PendingAuth{
  49. State: state,
  50. Provider: 1,
  51. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
  52. }
  53. m.addPendingAuth(auth)
  54. auth = oauth2PendingAuth{
  55. State: xid.New().String(),
  56. Provider: 1,
  57. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  58. }
  59. m.addPendingAuth(auth)
  60. require.Len(t, m.pendingAuths, 2)
  61. _, err = m.getPendingAuth(auth.State)
  62. require.Error(t, err)
  63. assert.Contains(t, err.Error(), "auth request is too old")
  64. m.cleanup()
  65. require.Len(t, m.pendingAuths, 1)
  66. m.removePendingAuth(state)
  67. require.Len(t, m.pendingAuths, 0)
  68. }
  69. func TestDbOAuth2Manager(t *testing.T) {
  70. if !isSharedProviderSupported() {
  71. t.Skip("this test it is not available with this provider")
  72. }
  73. mgr := newOAuth2Manager(1)
  74. m, ok := mgr.(*dbOAuth2Manager)
  75. require.True(t, ok)
  76. _, err := m.getPendingAuth(xid.New().String())
  77. require.Error(t, err)
  78. auth := newOAuth2PendingAuth(1, "https://...", "client_id", kms.NewPlainSecret("my db secret"))
  79. m.addPendingAuth(auth)
  80. a, err := m.getPendingAuth(auth.State)
  81. assert.NoError(t, err)
  82. assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())
  83. session, err := dataprovider.GetSharedSession(auth.State)
  84. assert.NoError(t, err)
  85. authReq := oauth2PendingAuth{}
  86. err = json.Unmarshal(session.Data.([]byte), &authReq)
  87. assert.NoError(t, err)
  88. assert.Equal(t, sdkkms.SecretStatusSecretBox, authReq.ClientSecret.GetStatus())
  89. m.cleanup()
  90. _, err = m.getPendingAuth(auth.State)
  91. assert.NoError(t, err)
  92. m.removePendingAuth(auth.State)
  93. _, err = m.getPendingAuth(auth.State)
  94. assert.Error(t, err)
  95. auth = oauth2PendingAuth{
  96. State: xid.New().String(),
  97. Provider: 1,
  98. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  99. ClientSecret: kms.NewPlainSecret("db secret"),
  100. }
  101. m.addPendingAuth(auth)
  102. _, err = m.getPendingAuth(auth.State)
  103. assert.Error(t, err)
  104. _, err = dataprovider.GetSharedSession(auth.State)
  105. assert.NoError(t, err)
  106. m.cleanup()
  107. _, err = dataprovider.GetSharedSession(auth.State)
  108. assert.Error(t, err)
  109. _, err = m.decodePendingAuthData("not a byte array")
  110. require.Error(t, err)
  111. assert.Contains(t, err.Error(), "invalid auth request data")
  112. _, err = m.decodePendingAuthData([]byte("{not a json"))
  113. require.Error(t, err)
  114. // adding a request with a non plain secret will fail
  115. auth = oauth2PendingAuth{
  116. State: xid.New().String(),
  117. Provider: 1,
  118. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  119. ClientSecret: kms.NewPlainSecret("db secret"),
  120. }
  121. auth.ClientSecret.SetStatus(sdkkms.SecretStatusSecretBox)
  122. m.addPendingAuth(auth)
  123. _, err = dataprovider.GetSharedSession(auth.State)
  124. assert.Error(t, err)
  125. asJSON, err := json.Marshal(auth)
  126. assert.NoError(t, err)
  127. _, err = m.decodePendingAuthData(asJSON)
  128. assert.Error(t, err)
  129. }