basic_auth.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package web
  import (
	"bytes"
	"crypto/subtle"
	"encoding/base64"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"backup-x/entity"
	"backup-x/util"
)
  // ViewFunc is the signature of an HTTP view handler that BasicAuth wraps.
// Its underlying type is identical to http.HandlerFunc's, so values convert
// between the two types with a plain conversion.
type ViewFunc func(w http.ResponseWriter, r *http.Request)
  14. type loginDetect struct {
  15. FailTimes int
  16. }
  17. var ld = &loginDetect{}
  18. // BasicAuth basic auth
  19. func BasicAuth(f ViewFunc) ViewFunc {
  20. return func(w http.ResponseWriter, r *http.Request) {
  21. conf, _ := entity.GetConfigCache()
  22. // 帐号或密码为空。跳过
  23. if conf.Username == "" && conf.Password == "" {
  24. // 执行被装饰的函数
  25. f(w, r)
  26. return
  27. }
  28. // 认证帐号密码
  29. basicAuthPrefix := "Basic "
  30. // 获取 request header
  31. auth := r.Header.Get("Authorization")
  32. // 如果是 http basic auth
  33. if strings.HasPrefix(auth, basicAuthPrefix) {
  34. // 解码认证信息
  35. payload, err := base64.StdEncoding.DecodeString(
  36. auth[len(basicAuthPrefix):],
  37. )
  38. if err == nil {
  39. pair := bytes.SplitN(payload, []byte(":"), 2)
  40. pwd, _ := util.DecryptByEncryptKey(conf.EncryptKey, conf.Password)
  41. if len(pair) == 2 &&
  42. bytes.Equal(pair[0], []byte(conf.Username)) &&
  43. bytes.Equal(pair[1], []byte(pwd)) {
  44. ld.FailTimes = 0
  45. // 执行被装饰的函数
  46. f(w, r)
  47. return
  48. }
  49. }
  50. ld.FailTimes = ld.FailTimes + 1
  51. if ld.FailTimes > 5 {
  52. log.Printf("%s 登陆失败超过5次! 并延时60s响应\n", r.RemoteAddr)
  53. time.Sleep(60 * time.Second)
  54. ld.FailTimes = 0
  55. }
  56. log.Printf("%s 登陆失败!\n", r.RemoteAddr)
  57. }
  58. // 认证失败,提示 401 Unauthorized
  59. // Restricted 可以改成其他的值
  60. w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
  61. // 401 状态码
  62. w.WriteHeader(http.StatusUnauthorized)
  63. log.Printf("%s 请求登陆!\n", r.RemoteAddr)
  64. }
  65. }