middleware.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package httpd
  2. import (
  3. "errors"
  4. "net/http"
  5. "runtime/debug"
  6. "github.com/go-chi/chi/v5/middleware"
  7. "github.com/go-chi/jwtauth/v5"
  8. "github.com/lestrrat-go/jwx/jwt"
  9. "github.com/drakkan/sftpgo/v2/logger"
  10. "github.com/drakkan/sftpgo/v2/utils"
  11. )
  // contextKey is a private key type for request-context values; values are
// compared by pointer identity, so each &contextKey{...} is a distinct key.
type contextKey struct {
	name string
}

// String returns the key's name prefixed with "context value ", which makes
// keys readable when they are printed.
func (c *contextKey) String() string {
	const prefix = "context value "
	return prefix + c.name
}

var (
	// forwardedProtoKey is the context key named "forwarded proto";
	// presumably it carries the forwarded protocol of the request — confirm
	// at the call sites.
	forwardedProtoKey = &contextKey{name: "forwarded proto"}
	// errInvalidToken is the sentinel error validateJWTToken returns after it
	// has already written a rejection response.
	errInvalidToken = errors.New("invalid JWT token")
)
  22. func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
  23. token, _, err := jwtauth.FromContext(r.Context())
  24. var redirectPath string
  25. if audience == tokenAudienceWebAdmin {
  26. redirectPath = webLoginPath
  27. } else {
  28. redirectPath = webClientLoginPath
  29. }
  30. isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser)
  31. if err != nil || token == nil {
  32. logger.Debug(logSender, "", "error getting jwt token: %v", err)
  33. if isAPIToken {
  34. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  35. } else {
  36. http.Redirect(w, r, redirectPath, http.StatusFound)
  37. }
  38. return errInvalidToken
  39. }
  40. err = jwt.Validate(token)
  41. if err != nil {
  42. logger.Debug(logSender, "", "error validating jwt token: %v", err)
  43. if isAPIToken {
  44. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  45. } else {
  46. http.Redirect(w, r, redirectPath, http.StatusFound)
  47. }
  48. return errInvalidToken
  49. }
  50. if !utils.IsStringInSlice(audience, token.Audience()) {
  51. logger.Debug(logSender, "", "the token is not valid for audience %#v", audience)
  52. if isAPIToken {
  53. sendAPIResponse(w, r, nil, "Your token audience is not valid", http.StatusUnauthorized)
  54. } else {
  55. http.Redirect(w, r, redirectPath, http.StatusFound)
  56. }
  57. return errInvalidToken
  58. }
  59. if isTokenInvalidated(r) {
  60. logger.Debug(logSender, "", "the token has been invalidated")
  61. if isAPIToken {
  62. sendAPIResponse(w, r, nil, "Your token is no longer valid", http.StatusUnauthorized)
  63. } else {
  64. http.Redirect(w, r, redirectPath, http.StatusFound)
  65. }
  66. return errInvalidToken
  67. }
  68. return nil
  69. }
  70. func jwtAuthenticatorAPI(next http.Handler) http.Handler {
  71. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  72. if err := validateJWTToken(w, r, tokenAudienceAPI); err != nil {
  73. return
  74. }
  75. // Token is authenticated, pass it through
  76. next.ServeHTTP(w, r)
  77. })
  78. }
  79. func jwtAuthenticatorAPIUser(next http.Handler) http.Handler {
  80. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  81. if err := validateJWTToken(w, r, tokenAudienceAPIUser); err != nil {
  82. return
  83. }
  84. // Token is authenticated, pass it through
  85. next.ServeHTTP(w, r)
  86. })
  87. }
  88. func jwtAuthenticatorWebAdmin(next http.Handler) http.Handler {
  89. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  90. if err := validateJWTToken(w, r, tokenAudienceWebAdmin); err != nil {
  91. return
  92. }
  93. // Token is authenticated, pass it through
  94. next.ServeHTTP(w, r)
  95. })
  96. }
  97. func jwtAuthenticatorWebClient(next http.Handler) http.Handler {
  98. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  99. if err := validateJWTToken(w, r, tokenAudienceWebClient); err != nil {
  100. return
  101. }
  102. // Token is authenticated, pass it through
  103. next.ServeHTTP(w, r)
  104. })
  105. }
  106. //nolint:unparam
  107. func checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler {
  108. return func(next http.Handler) http.Handler {
  109. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  110. _, claims, err := jwtauth.FromContext(r.Context())
  111. if err != nil {
  112. if isWebRequest(r) {
  113. renderClientBadRequestPage(w, r, err)
  114. } else {
  115. sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  116. }
  117. return
  118. }
  119. tokenClaims := jwtTokenClaims{}
  120. tokenClaims.Decode(claims)
  121. // for web client perms are negated and not granted
  122. if tokenClaims.hasPerm(perm) {
  123. if isWebRequest(r) {
  124. renderClientForbiddenPage(w, r, "You don't have permission for this action")
  125. } else {
  126. sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  127. }
  128. return
  129. }
  130. next.ServeHTTP(w, r)
  131. })
  132. }
  133. }
  134. func checkPerm(perm string) func(next http.Handler) http.Handler {
  135. return func(next http.Handler) http.Handler {
  136. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  137. _, claims, err := jwtauth.FromContext(r.Context())
  138. if err != nil {
  139. if isWebRequest(r) {
  140. renderBadRequestPage(w, r, err)
  141. } else {
  142. sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  143. }
  144. return
  145. }
  146. tokenClaims := jwtTokenClaims{}
  147. tokenClaims.Decode(claims)
  148. if !tokenClaims.hasPerm(perm) {
  149. if isWebRequest(r) {
  150. renderForbiddenPage(w, r, "You don't have permission for this action")
  151. } else {
  152. sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  153. }
  154. return
  155. }
  156. next.ServeHTTP(w, r)
  157. })
  158. }
  159. }
  160. func verifyCSRFHeader(next http.Handler) http.Handler {
  161. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  162. tokenString := r.Header.Get(csrfHeaderToken)
  163. token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
  164. if err != nil || token == nil {
  165. logger.Debug(logSender, "", "error validating CSRF header: %v", err)
  166. sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
  167. return
  168. }
  169. if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) {
  170. logger.Debug(logSender, "", "error validating CSRF header audience")
  171. sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
  172. return
  173. }
  174. next.ServeHTTP(w, r)
  175. })
  176. }
  177. func recoverer(next http.Handler) http.Handler {
  178. fn := func(w http.ResponseWriter, r *http.Request) {
  179. defer func() {
  180. if rvr := recover(); rvr != nil {
  181. if rvr == http.ErrAbortHandler {
  182. panic(rvr)
  183. }
  184. logEntry := middleware.GetLogEntry(r)
  185. if logEntry != nil {
  186. logEntry.Panic(rvr, debug.Stack())
  187. } else {
  188. middleware.PrintPrettyStack(rvr)
  189. }
  190. w.WriteHeader(http.StatusInternalServerError)
  191. }
  192. }()
  193. next.ServeHTTP(w, r)
  194. }
  195. return http.HandlerFunc(fn)
  196. }