| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 | 
							- package httpd
 
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/drakkan/sftpgo/ldapauthserver/config"
	"github.com/drakkan/sftpgo/ldapauthserver/logger"
	"github.com/drakkan/sftpgo/ldapauthserver/utils"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/render"
)
 
- const (
 
- 	logSender      = "httpd"
 
- 	versionPath    = "/api/v1/version"
 
- 	checkAuthPath  = "/api/v1/check_auth"
 
- 	maxRequestSize = 1 << 18 // 256KB
 
- )
 
- var (
 
- 	ldapConfig config.LDAPConfig
 
- 	httpAuth   httpAuthProvider
 
- 	certMgr    *certManager
 
- 	rootCAs    *x509.CertPool
 
- )
 
- // StartHTTPServer initializes and starts the HTTP Server
 
- func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error {
 
- 	var err error
 
- 	authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir)
 
- 	httpAuth, err = newBasicAuthProvider(authUserFile)
 
- 	if err != nil {
 
- 		return err
 
- 	}
 
- 	router := chi.NewRouter()
 
- 	router.Use(middleware.RequestID)
 
- 	router.Use(middleware.RealIP)
 
- 	router.Use(logger.NewStructuredLogger(logger.GetLogger()))
 
- 	router.Use(middleware.Recoverer)
 
- 	router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
- 		sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
 
- 	}))
 
- 	router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
- 		sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed)
 
- 	}))
 
- 	router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
 
- 		render.JSON(w, r, utils.GetAppVersion())
 
- 	})
 
- 	router.Group(func(router chi.Router) {
 
- 		router.Use(checkAuth)
 
- 		router.Post(checkAuthPath, checkSFTPGoUserAuth)
 
- 	})
 
- 	ldapConfig = config.GetLDAPConfig()
 
- 	loadCACerts(configDir)
 
- 	certificateFile := getConfigPath(httpConfig.CertificateFile, configDir)
 
- 	certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir)
 
- 	httpServer := &http.Server{
 
- 		Addr:           fmt.Sprintf("%s:%d", httpConfig.BindAddress, httpConfig.BindPort),
 
- 		Handler:        router,
 
- 		ReadTimeout:    70 * time.Second,
 
- 		WriteTimeout:   70 * time.Second,
 
- 		IdleTimeout:    120 * time.Second,
 
- 		MaxHeaderBytes: 1 << 16, // 64KB
 
- 	}
 
- 	if len(certificateFile) > 0 && len(certificateKeyFile) > 0 {
 
- 		certMgr, err = newCertManager(certificateFile, certificateKeyFile)
 
- 		if err != nil {
 
- 			return err
 
- 		}
 
- 		config := &tls.Config{
 
- 			GetCertificate: certMgr.GetCertificateFunc(),
 
- 			MinVersion:     tls.VersionTLS12,
 
- 		}
 
- 		httpServer.TLSConfig = config
 
- 		return httpServer.ListenAndServeTLS("", "")
 
- 	}
 
- 	return httpServer.ListenAndServe()
 
- }
 
- func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
 
- 	var errorString string
 
- 	if err != nil {
 
- 		errorString = err.Error()
 
- 	}
 
- 	resp := apiResponse{
 
- 		Error:      errorString,
 
- 		Message:    message,
 
- 		HTTPStatus: code,
 
- 	}
 
- 	ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
 
- 	render.JSON(w, r.WithContext(ctx), resp)
 
- }
 
- func loadCACerts(configDir string) error {
 
- 	var err error
 
- 	rootCAs, err = x509.SystemCertPool()
 
- 	if err != nil {
 
- 		rootCAs = x509.NewCertPool()
 
- 	}
 
- 	for _, ca := range ldapConfig.CACertificates {
 
- 		caPath := getConfigPath(ca, configDir)
 
- 		certs, err := os.ReadFile(caPath)
 
- 		if err != nil {
 
- 			logger.Warn(logSender, "", "error loading ca cert %q: %v", caPath, err)
 
- 			return err
 
- 		}
 
- 		if !rootCAs.AppendCertsFromPEM(certs) {
 
- 			logger.Warn(logSender, "", "unable to add ca cert %q", caPath)
 
- 		} else {
 
- 			logger.Debug(logSender, "", "ca cert %q added to the trusted certificates", caPath)
 
- 		}
 
- 	}
 
- 	return nil
 
- }
 
- // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
 
- func ReloadTLSCertificate() {
 
- 	if certMgr != nil {
 
- 		certMgr.loadCertificate()
 
- 	}
 
- }
 
- func getConfigPath(name, configDir string) string {
 
- 	if !utils.IsFileInputValid(name) {
 
- 		return ""
 
- 	}
 
- 	if len(name) > 0 && !filepath.IsAbs(name) {
 
- 		return filepath.Join(configDir, name)
 
- 	}
 
- 	return name
 
- }
 
 
  |