gui_csrf.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package main
import (
	"bufio"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/calmh/syncthing/osutil"
)
  15. var csrfTokens []string
  16. var csrfMut sync.Mutex
  17. // Check for CSRF token on /rest/ URLs. If a correct one is not given, reject
  18. // the request with 403. For / and /index.html, set a new CSRF cookie if none
  19. // is currently set.
  20. func csrfMiddleware(w http.ResponseWriter, r *http.Request) {
  21. if strings.HasPrefix(r.URL.Path, "/rest/") {
  22. token := r.Header.Get("X-CSRF-Token")
  23. if !validCsrfToken(token) {
  24. http.Error(w, "CSRF Error", 403)
  25. }
  26. } else if r.URL.Path == "/" || r.URL.Path == "/index.html" {
  27. cookie, err := r.Cookie("CSRF-Token")
  28. if err != nil || !validCsrfToken(cookie.Value) {
  29. cookie = &http.Cookie{
  30. Name: "CSRF-Token",
  31. Value: newCsrfToken(),
  32. }
  33. http.SetCookie(w, cookie)
  34. }
  35. }
  36. }
  37. func validCsrfToken(token string) bool {
  38. csrfMut.Lock()
  39. defer csrfMut.Unlock()
  40. for _, t := range csrfTokens {
  41. if t == token {
  42. return true
  43. }
  44. }
  45. return false
  46. }
  47. func newCsrfToken() string {
  48. bs := make([]byte, 30)
  49. _, err := rand.Reader.Read(bs)
  50. if err != nil {
  51. l.Fatalln(err)
  52. }
  53. token := base64.StdEncoding.EncodeToString(bs)
  54. csrfMut.Lock()
  55. csrfTokens = append(csrfTokens, token)
  56. if len(csrfTokens) > 10 {
  57. csrfTokens = csrfTokens[len(csrfTokens)-10:]
  58. }
  59. defer csrfMut.Unlock()
  60. saveCsrfTokens()
  61. return token
  62. }
  63. func saveCsrfTokens() {
  64. name := filepath.Join(confDir, "csrftokens.txt")
  65. tmp := fmt.Sprintf("%s.tmp.%d", name, time.Now().UnixNano())
  66. f, err := os.OpenFile(tmp, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644)
  67. if err != nil {
  68. return
  69. }
  70. defer os.Remove(tmp)
  71. for _, t := range csrfTokens {
  72. _, err := fmt.Fprintln(f, t)
  73. if err != nil {
  74. return
  75. }
  76. }
  77. err = f.Close()
  78. if err != nil {
  79. return
  80. }
  81. osutil.Rename(tmp, name)
  82. }
  83. func loadCsrfTokens() {
  84. name := filepath.Join(confDir, "csrftokens.txt")
  85. f, err := os.Open(name)
  86. if err != nil {
  87. return
  88. }
  89. defer f.Close()
  90. s := bufio.NewScanner(f)
  91. for s.Scan() {
  92. csrfTokens = append(csrfTokens, s.Text())
  93. }
  94. }