middleware.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package httpd
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "github.com/go-chi/jwtauth/v5"
  7. "github.com/lestrrat-go/jwx/jwt"
  8. "github.com/drakkan/sftpgo/logger"
  9. "github.com/drakkan/sftpgo/utils"
  10. )
  11. var connAddrKey = &contextKey{"connection address"}
  12. type contextKey struct {
  13. name string
  14. }
  15. func (k *contextKey) String() string {
  16. return "context value " + k.name
  17. }
  18. func saveConnectionAddress(next http.Handler) http.Handler {
  19. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  20. ctx := context.WithValue(r.Context(), connAddrKey, r.RemoteAddr)
  21. next.ServeHTTP(w, r.WithContext(ctx))
  22. })
  23. }
  24. func jwtAuthenticator(next http.Handler) http.Handler {
  25. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  26. token, _, err := jwtauth.FromContext(r.Context())
  27. if err != nil || token == nil {
  28. logger.Debug(logSender, "", "error getting jwt token: %v", err)
  29. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  30. return
  31. }
  32. err = jwt.Validate(token)
  33. if err != nil {
  34. logger.Debug(logSender, "", "error validating jwt token: %v", err)
  35. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  36. return
  37. }
  38. if !utils.IsStringInSlice(tokenAudienceAPI, token.Audience()) {
  39. logger.Debug(logSender, "", "the token audience is not valid for API usage")
  40. sendAPIResponse(w, r, nil, "Your token audience is not valid", http.StatusUnauthorized)
  41. return
  42. }
  43. if isTokenInvalidated(r) {
  44. logger.Debug(logSender, "", "the token has been invalidated")
  45. sendAPIResponse(w, r, nil, "Your token is no longer valid", http.StatusUnauthorized)
  46. return
  47. }
  48. // Token is authenticated, pass it through
  49. next.ServeHTTP(w, r)
  50. })
  51. }
  52. func jwtAuthenticatorWeb(next http.Handler) http.Handler {
  53. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  54. token, _, err := jwtauth.FromContext(r.Context())
  55. if err != nil || token == nil {
  56. logger.Debug(logSender, "", "error getting web jwt token: %v", err)
  57. http.Redirect(w, r, webLoginPath, http.StatusFound)
  58. return
  59. }
  60. err = jwt.Validate(token)
  61. if err != nil {
  62. logger.Debug(logSender, "", "error validating web jwt token: %v", err)
  63. http.Redirect(w, r, webLoginPath, http.StatusFound)
  64. return
  65. }
  66. if !utils.IsStringInSlice(tokenAudienceWeb, token.Audience()) {
  67. logger.Debug(logSender, "", "the token audience is not valid for Web usage")
  68. http.Redirect(w, r, webLoginPath, http.StatusFound)
  69. return
  70. }
  71. if isTokenInvalidated(r) {
  72. logger.Debug(logSender, "", "the token has been invalidated")
  73. http.Redirect(w, r, webLoginPath, http.StatusFound)
  74. return
  75. }
  76. // Token is authenticated, pass it through
  77. next.ServeHTTP(w, r)
  78. })
  79. }
  80. func checkPerm(perm string) func(next http.Handler) http.Handler {
  81. return func(next http.Handler) http.Handler {
  82. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  83. _, claims, err := jwtauth.FromContext(r.Context())
  84. if err != nil {
  85. if isWebAdminRequest(r) {
  86. renderBadRequestPage(w, r, err)
  87. } else {
  88. sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  89. }
  90. return
  91. }
  92. tokenClaims := jwtTokenClaims{}
  93. tokenClaims.Decode(claims)
  94. if !tokenClaims.hasPerm(perm) {
  95. if isWebAdminRequest(r) {
  96. renderForbiddenPage(w, r, "You don't have permission for this action")
  97. } else {
  98. sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  99. }
  100. return
  101. }
  102. next.ServeHTTP(w, r)
  103. })
  104. }
  105. }
  106. func verifyCSRFHeader(next http.Handler) http.Handler {
  107. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  108. tokenString := r.Header.Get(csrfHeaderToken)
  109. token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
  110. if err != nil || token == nil {
  111. logger.Debug(logSender, "", "error validating CSRF header: %v", err)
  112. sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
  113. return
  114. }
  115. if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) {
  116. logger.Debug(logSender, "", "error validating CSRF header audience")
  117. sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
  118. return
  119. }
  120. next.ServeHTTP(w, r)
  121. })
  122. }