middleware.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package httpd
  2. import (
  3. "errors"
  4. "net/http"
  5. "runtime/debug"
  6. "github.com/go-chi/chi/v5/middleware"
  7. "github.com/go-chi/jwtauth/v5"
  8. "github.com/lestrrat-go/jwx/jwt"
  9. "github.com/drakkan/sftpgo/v2/logger"
  10. "github.com/drakkan/sftpgo/v2/util"
  11. )
  12. var (
  13. forwardedProtoKey = &contextKey{"forwarded proto"}
  14. errInvalidToken = errors.New("invalid JWT token")
  15. )
  16. type contextKey struct {
  17. name string
  18. }
  19. func (k *contextKey) String() string {
  20. return "context value " + k.name
  21. }
  22. func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
  23. token, _, err := jwtauth.FromContext(r.Context())
  24. var redirectPath string
  25. if audience == tokenAudienceWebAdmin {
  26. redirectPath = webLoginPath
  27. } else {
  28. redirectPath = webClientLoginPath
  29. }
  30. isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser)
  31. if err != nil || token == nil {
  32. logger.Debug(logSender, "", "error getting jwt token: %v", err)
  33. if isAPIToken {
  34. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  35. } else {
  36. http.Redirect(w, r, redirectPath, http.StatusFound)
  37. }
  38. return errInvalidToken
  39. }
  40. err = jwt.Validate(token)
  41. if err != nil {
  42. logger.Debug(logSender, "", "error validating jwt token: %v", err)
  43. if isAPIToken {
  44. sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
  45. } else {
  46. http.Redirect(w, r, redirectPath, http.StatusFound)
  47. }
  48. return errInvalidToken
  49. }
  50. if !util.IsStringInSlice(audience, token.Audience()) {
  51. logger.Debug(logSender, "", "the token is not valid for audience %#v", audience)
  52. if isAPIToken {
  53. sendAPIResponse(w, r, nil, "Your token audience is not valid", http.StatusUnauthorized)
  54. } else {
  55. http.Redirect(w, r, redirectPath, http.StatusFound)
  56. }
  57. return errInvalidToken
  58. }
  59. if isTokenInvalidated(r) {
  60. logger.Debug(logSender, "", "the token has been invalidated")
  61. if isAPIToken {
  62. sendAPIResponse(w, r, nil, "Your token is no longer valid", http.StatusUnauthorized)
  63. } else {
  64. http.Redirect(w, r, redirectPath, http.StatusFound)
  65. }
  66. return errInvalidToken
  67. }
  68. return nil
  69. }
  70. func jwtAuthenticatorAPI(next http.Handler) http.Handler {
  71. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  72. if err := validateJWTToken(w, r, tokenAudienceAPI); err != nil {
  73. return
  74. }
  75. // Token is authenticated, pass it through
  76. next.ServeHTTP(w, r)
  77. })
  78. }
  79. func jwtAuthenticatorAPIUser(next http.Handler) http.Handler {
  80. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  81. if err := validateJWTToken(w, r, tokenAudienceAPIUser); err != nil {
  82. return
  83. }
  84. // Token is authenticated, pass it through
  85. next.ServeHTTP(w, r)
  86. })
  87. }
  88. func jwtAuthenticatorWebAdmin(next http.Handler) http.Handler {
  89. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  90. if err := validateJWTToken(w, r, tokenAudienceWebAdmin); err != nil {
  91. return
  92. }
  93. // Token is authenticated, pass it through
  94. next.ServeHTTP(w, r)
  95. })
  96. }
  97. func jwtAuthenticatorWebClient(next http.Handler) http.Handler {
  98. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  99. if err := validateJWTToken(w, r, tokenAudienceWebClient); err != nil {
  100. return
  101. }
  102. // Token is authenticated, pass it through
  103. next.ServeHTTP(w, r)
  104. })
  105. }
  106. func checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler {
  107. return func(next http.Handler) http.Handler {
  108. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  109. _, claims, err := jwtauth.FromContext(r.Context())
  110. if err != nil {
  111. if isWebRequest(r) {
  112. renderClientBadRequestPage(w, r, err)
  113. } else {
  114. sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  115. }
  116. return
  117. }
  118. tokenClaims := jwtTokenClaims{}
  119. tokenClaims.Decode(claims)
  120. // for web client perms are negated and not granted
  121. if tokenClaims.hasPerm(perm) {
  122. if isWebRequest(r) {
  123. renderClientForbiddenPage(w, r, "You don't have permission for this action")
  124. } else {
  125. sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  126. }
  127. return
  128. }
  129. next.ServeHTTP(w, r)
  130. })
  131. }
  132. }
  133. func checkPerm(perm string) func(next http.Handler) http.Handler {
  134. return func(next http.Handler) http.Handler {
  135. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  136. _, claims, err := jwtauth.FromContext(r.Context())
  137. if err != nil {
  138. if isWebRequest(r) {
  139. renderBadRequestPage(w, r, err)
  140. } else {
  141. sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  142. }
  143. return
  144. }
  145. tokenClaims := jwtTokenClaims{}
  146. tokenClaims.Decode(claims)
  147. if !tokenClaims.hasPerm(perm) {
  148. if isWebRequest(r) {
  149. renderForbiddenPage(w, r, "You don't have permission for this action")
  150. } else {
  151. sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  152. }
  153. return
  154. }
  155. next.ServeHTTP(w, r)
  156. })
  157. }
  158. }
  159. func verifyCSRFHeader(next http.Handler) http.Handler {
  160. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  161. tokenString := r.Header.Get(csrfHeaderToken)
  162. token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
  163. if err != nil || token == nil {
  164. logger.Debug(logSender, "", "error validating CSRF header: %v", err)
  165. sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
  166. return
  167. }
  168. if !util.IsStringInSlice(tokenAudienceCSRF, token.Audience()) {
  169. logger.Debug(logSender, "", "error validating CSRF header audience")
  170. sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
  171. return
  172. }
  173. next.ServeHTTP(w, r)
  174. })
  175. }
  176. func recoverer(next http.Handler) http.Handler {
  177. fn := func(w http.ResponseWriter, r *http.Request) {
  178. defer func() {
  179. if rvr := recover(); rvr != nil {
  180. if rvr == http.ErrAbortHandler {
  181. panic(rvr)
  182. }
  183. logEntry := middleware.GetLogEntry(r)
  184. if logEntry != nil {
  185. logEntry.Panic(rvr, debug.Stack())
  186. } else {
  187. middleware.PrintPrettyStack(rvr)
  188. }
  189. w.WriteHeader(http.StatusInternalServerError)
  190. }
  191. }()
  192. next.ServeHTTP(w, r)
  193. }
  194. return http.HandlerFunc(fn)
  195. }