浏览代码

parse IP proxy header also if listening on UNIX domain socket

Fixes #867

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 3 年之前
父节点
当前提交
118744a860
共有 7 个文件被更改,包括 48 次插入5 次删除
  1. 1 1
      docs/repo.md
  2. 5 0
      httpd/httpd.go
  3. 12 0
      httpd/internal_test.go
  4. 7 2
      httpd/server.go
  5. 12 0
      webdavd/internal_test.go
  6. 6 2
      webdavd/server.go
  7. 5 0
      webdavd/webdavd.go

+ 1 - 1
docs/repo.md

@@ -1,6 +1,6 @@
 # SFTPGo repositories
 # SFTPGo repositories
 
 
-These repository are available through Oregon State University's free mirror service. Special thanks to Lance Albertson, Director of the Oregon State University Open Source Lab, who helped me with the initial setup.
+These repositories are available through Oregon State University's free mirror service. Special thanks to Lance Albertson, Director of the Oregon State University Open Source Lab, who helped me with the initial setup.
 
 
 ## APT repo
 ## APT repo
 
 

+ 5 - 0
httpd/httpd.go

@@ -476,6 +476,11 @@ func (b *Binding) checkBranding() {
 }
 }
 
 
 func (b *Binding) parseAllowedProxy() error {
 func (b *Binding) parseAllowedProxy() error {
+	if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
+		// unix domain socket
+		b.allowHeadersFrom = []func(net.IP) bool{func(ip net.IP) bool { return true }}
+		return nil
+	}
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 12 - 0
httpd/internal_test.go

@@ -1532,6 +1532,18 @@ func TestJWTTokenCleanup(t *testing.T) {
 	stopCleanupTicker()
 	stopCleanupTicker()
 }
 }
 
 
+func TestAllowedProxyUnixDomainSocket(t *testing.T) {
+	b := Binding{
+		Address:      filepath.Join(os.TempDir(), "sock"),
+		ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"},
+	}
+	err := b.parseAllowedProxy()
+	assert.NoError(t, err)
+	if assert.Len(t, b.allowHeadersFrom, 1) {
+		assert.True(t, b.allowHeadersFrom[0](nil))
+	}
+}
+
 func TestProxyHeaders(t *testing.T) {
 func TestProxyHeaders(t *testing.T) {
 	username := "adminTest"
 	username := "adminTest"
 	password := "testPwd"
 	password := "testPwd"

+ 7 - 2
httpd/server.go

@@ -9,6 +9,7 @@ import (
 	"log"
 	"log"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"path/filepath"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -972,9 +973,13 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
 func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
 func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
-		ip := net.ParseIP(ipAddr)
+		var ip net.IP
+		isUnixSocket := filepath.IsAbs(s.binding.Address)
+		if !isUnixSocket {
+			ip = net.ParseIP(ipAddr)
+		}
 		areHeadersAllowed := false
 		areHeadersAllowed := false
-		if ip != nil {
+		if isUnixSocket || ip != nil {
 			for _, allow := range s.binding.allowHeadersFrom {
 			for _, allow := range s.binding.allowHeadersFrom {
 				if allow(ip) {
 				if allow(ip) {
 					parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)
 					parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)

+ 12 - 0
webdavd/internal_test.go

@@ -412,6 +412,18 @@ func TestUserInvalidParams(t *testing.T) {
 	writeLog(req, http.StatusOK, nil)
 	writeLog(req, http.StatusOK, nil)
 }
 }
 
 
+func TestAllowedProxyUnixDomainSocket(t *testing.T) {
+	b := Binding{
+		Address:      filepath.Join(os.TempDir(), "sock"),
+		ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"},
+	}
+	err := b.parseAllowedProxy()
+	assert.NoError(t, err)
+	if assert.Len(t, b.allowHeadersFrom, 1) {
+		assert.True(t, b.allowHeadersFrom[0](nil))
+	}
+}
+
 func TestRemoteAddress(t *testing.T) {
 func TestRemoteAddress(t *testing.T) {
 	remoteAddr1 := "100.100.100.100"
 	remoteAddr1 := "100.100.100.100"
 	remoteAddr2 := "172.172.172.172"
 	remoteAddr2 := "172.172.172.172"

+ 6 - 2
webdavd/server.go

@@ -331,8 +331,12 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
 
 
 func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
 func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
-	ip := net.ParseIP(ipAddr)
-	if ip != nil {
+	var ip net.IP
+	isUnixSocket := filepath.IsAbs(s.binding.Address)
+	if !isUnixSocket {
+		ip = net.ParseIP(ipAddr)
+	}
+	if isUnixSocket || ip != nil {
 		for _, allow := range s.binding.allowHeadersFrom {
 		for _, allow := range s.binding.allowHeadersFrom {
 			if allow(ip) {
 			if allow(ip) {
 				parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)
 				parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)

+ 5 - 0
webdavd/webdavd.go

@@ -109,6 +109,11 @@ type Binding struct {
 }
 }
 
 
 func (b *Binding) parseAllowedProxy() error {
 func (b *Binding) parseAllowedProxy() error {
+	if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
+		// unix domain socket
+		b.allowHeadersFrom = []func(net.IP) bool{func(ip net.IP) bool { return true }}
+		return nil
+	}
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	if err != nil {
 	if err != nil {
 		return err
 		return err