From 4d8fde6a8882e63e8dfafdf9f62c73b7b1036ebf Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Fri, 12 Nov 2021 20:54:39 -0800 Subject: echo: add prometheus and multiple middleware - Integrate prometheus - Integrate CORS - Access control prometheus by IP - Fix redirectors to consider port - Fix redirectors to not redirect IP addresses - Use body limit e.Use(middleware.BodyLimit("2M")) --- echo/echo_default.go | 177 +++++++++++++++++++++++++++++------------- echo/middleware/ip_filter.go | 39 ++++++++++ echo/middleware/redirect.go | 75 ++++++++++++++++++ echo/prometheus/prometheus.go | 16 ++-- 4 files changed, 244 insertions(+), 63 deletions(-) create mode 100644 echo/middleware/ip_filter.go create mode 100644 echo/middleware/redirect.go diff --git a/echo/echo_default.go b/echo/echo_default.go index ddaab44..92a5cfd 100644 --- a/echo/echo_default.go +++ b/echo/echo_default.go @@ -6,13 +6,17 @@ import ( "fmt" "html/template" "io/fs" + "net" "net/http" "path" + "strconv" "sync" gltls "code.crute.us/mcrute/golib/crypto/tls" glmw "code.crute.us/mcrute/golib/echo/middleware" "code.crute.us/mcrute/golib/echo/prometheus" + glnet "code.crute.us/mcrute/golib/net" + glservice "code.crute.us/mcrute/golib/service" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -22,21 +26,23 @@ import ( // Docs: https://echo.labstack.com/guide/ // TODO: -// - Intgrate prometheus -// - Integrate CORS -// - Access control prometheus by IP -// - Fix redirectors to consider port -// - Fix redirectors to not redirect IP addresses // - Integrate CSRF // - Integrate session -// - Use bodylimit e.Use(middleware.BodyLimit("2M")) +// - Enable auto cert management by passing hostnames + +const ( + defaultBodySizeLimit = "10M" +) type EchoConfig struct { Debug bool - BindAddress string - BindTLSAddress string + Hostnames []string + BindAddresses []string + BindTLSAddresses []string TLSCacheDir string + BodySizeLimit string TrustedProxyIPRanges []string + ManagementIPRanges []string EmbeddedTemplates fs.FS DiskTemplates fs.FS TemplateGlob *string @@ -44,14 +50,16 @@ type EchoConfig struct { CombinedHostLogFile string RedirectToWWW bool ContentSecurityPolicy *glmw.ContentSecurityPolicyConfig + DisablePrometheus bool PrometheusConfig *prometheus.PrometheusConfig + CORSConfig *middleware.CORSConfig } type EchoWrapper struct { *echo.Echo - tlsServer http.Server + servers []*http.Server + tlsServers []*http.Server templateFS fs.FS - bindAddress string ocspErrors chan gltls.OcspError ocspManager *gltls.OcspManager initDone bool @@ -79,45 +87,56 @@ func (w *EchoWrapper) RunCertificateManager(ctx context.Context, wg *sync.WaitGr return w.ocspManager.Run(ctx, wg) } -func (w *EchoWrapper) run(ctx context.Context, wg *sync.WaitGroup, f func() error, sf func(context.Context) error) error { - if !w.initDone { - return fmt.Errorf("Echo is not initialized. Call Init()") - } +func (w *EchoWrapper) makeServerJob(s *http.Server, echoInit bool) glservice.RunnerFunc { + return func(ctx context.Context, wg *sync.WaitGroup) error { + if !w.initDone { + return fmt.Errorf("Echo is not initialized. Call Init()") + } - wg.Add(1) - defer wg.Done() + wg.Add(1) + defer wg.Done() - err := make(chan error) - go func() { err <- f() }() - select { - case e := <-err: - return e - default: - } + w.Logger.Infof("Starting server with address: %s", s.Addr) - select { - case <-ctx.Done(): - w.Logger.Info("Shutting down web server") - return sf(ctx) + err := make(chan error) + go func() { + if s.TLSConfig == nil && echoInit { + err <- w.Echo.StartServer(s) + } else if s.TLSConfig == nil && !echoInit { + err <- s.ListenAndServe() + } else { + err <- s.ListenAndServeTLS("", "") + } + }() + select { + case e := <-err: + return e + default: + } + + select { + case <-ctx.Done(): + w.Logger.Info("Shutting down web server") + return s.Shutdown(ctx) + } } } -func (w *EchoWrapper) Serve(ctx context.Context, wg *sync.WaitGroup) error { - return w.run( - ctx, - wg, - func() error { return w.Echo.Start(w.bindAddress) }, - func(ctx context.Context) error { return w.Echo.Shutdown(ctx) }, - ) -} +func (w *EchoWrapper) MakeServerJobs() []glservice.RunnerFunc { + out := []glservice.RunnerFunc{} -func (w *EchoWrapper) ServeTLS(ctx context.Context, wg *sync.WaitGroup) error { - return w.run( - ctx, - wg, - func() error { return w.tlsServer.ListenAndServeTLS("", "") }, - func(ctx context.Context) error { return w.tlsServer.Shutdown(ctx) }, - ) + for i, s := range w.servers { + // The first http (not https) server should do an echo.StartServer to + // configure some internal echo state and print the banner (if + // configured). + out = append(out, w.makeServerJob(s, i == 0)) + } + + for _, s := range w.tlsServers { + out = append(out, w.makeServerJob(s, false)) + } + + return out } func (w *EchoWrapper) GetTemplateFS() fs.FS { @@ -183,25 +202,55 @@ func NewDefaultEchoWithConfig(c EchoConfig) (*EchoWrapper, error) { Errors: cmek, } - ts := http.Server{ - Addr: c.BindTLSAddress, - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - GetCertificate: cm.GetCertificate, - }, - Handler: e, + servers := make([]*http.Server, len(c.BindAddresses)) + for i, a := range c.BindAddresses { + servers[i] = &http.Server{ + Addr: a, + Handler: e, + } + } + + tlsServers := make([]*http.Server, len(c.BindTLSAddresses)) + for i, a := range c.BindTLSAddresses { + tlsServers[i] = &http.Server{ + Addr: a, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: cm.GetCertificate, + }, + Handler: e, + } } e.Use(middleware.Logger()) e.Use(glmw.Recover()) - e.Use(middleware.HTTPSRedirect()) + + if c.BodySizeLimit == "" { + e.Use(middleware.BodyLimit(defaultBodySizeLimit)) + } else if c.BodySizeLimit != "0" { + e.Use(middleware.BodyLimit(c.BodySizeLimit)) + } + + _, tlsPort, err := net.SplitHostPort(tlsServers[0].Addr) + if err != nil { + return nil, fmt.Errorf("Unable to split TLS addr and port: %w", err) + } + tlsPortI, err := strconv.Atoi(tlsPort) + if err != nil { + return nil, fmt.Errorf("Unable to convert TLS port to int: %w", err) + } + e.Use(glmw.HTTPSRedirectWithConfig(glmw.HTTPSRedirectConfig{Port: tlsPortI})) + if c.RedirectToWWW { e.Use(middleware.WWWRedirect()) } + e.Use(middleware.Decompress()) e.Use(middleware.GzipWithConfig(middleware.GzipConfig{ // TODO: This mw causes prometheus responses to show up compressed for // browsers. Why? + // + // Also, this path should use the config path if we keep it Skipper: func(c echo.Context) bool { if c.Path() == "/metrics" { return true @@ -210,24 +259,42 @@ func NewDefaultEchoWithConfig(c EchoConfig) (*EchoWrapper, error) { }, Level: 5, })) + e.Use(glmw.StrictSecure()) + if c.CORSConfig != nil { + e.Use(middleware.CORSWithConfig(*c.CORSConfig)) + } else { + e.Use(middleware.CORS()) + } + if c.ContentSecurityPolicy != nil { e.Use(glmw.ContentSecurityPolicyWithConfig(*c.ContentSecurityPolicy)) } else { return nil, fmt.Errorf("ContentSecurityPolicy is required") } - if c.PrometheusConfig != nil { - prom := prometheus.NewPrometheusWithConfig(c.PrometheusConfig) + if !c.DisablePrometheus { + mips, err := glnet.ParseCIDRSlice(c.ManagementIPRanges) + if err != nil { + return nil, err + } + + var prom *prometheus.Prometheus + if c.PrometheusConfig != nil { + prom = prometheus.NewPrometheusWithConfig(c.PrometheusConfig) + } else { + prom = prometheus.NewPrometheus() + } + e.Use(prom.MiddlewareHandler) - e.GET(c.PrometheusConfig.MetricsPath, prom.MetricsHandler) + e.GET(prom.Config.MetricsPath, prom.MetricsHandler, glmw.NewIPFilter(mips)) } return &EchoWrapper{ Echo: e, - tlsServer: ts, - bindAddress: c.BindAddress, + servers: servers, + tlsServers: tlsServers, ocspErrors: cmek, ocspManager: cm, templateFS: templates, diff --git a/echo/middleware/ip_filter.go b/echo/middleware/ip_filter.go new file mode 100644 index 0000000..007791e --- /dev/null +++ b/echo/middleware/ip_filter.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "net" + + "github.com/labstack/echo/v4" +) + +func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if allowedRanges == nil { + c.Logger().Error("No allowed IPs configured for filter") + return echo.ErrNotFound + } + + ip := net.ParseIP(c.RealIP()) + if ip == nil { + c.Logger().Error("Unable to parse IP in IPFilter") + return echo.ErrNotFound + } + + found := false + for _, ipnet := range allowedRanges { + if ipnet.Contains(ip) { + found = true + break + } + } + + if !found { + c.Logger().Errorf("IP %s not in range for filter", c.RealIP()) + return echo.ErrNotFound + } + + return next(c) + } + } +} diff --git a/echo/middleware/redirect.go b/echo/middleware/redirect.go new file mode 100644 index 0000000..134bfee --- /dev/null +++ b/echo/middleware/redirect.go @@ -0,0 +1,75 @@ +package middleware + +/* +HTTP to HTTPS Redirect Middleware + +This is a duplicate of existing functionality in Echo because the Echo default +middleware doesn't support redirecting to a different HTTPS port which is +needed in dev environments and some prod environments where the server runs on +an off-port. +*/ + +import ( + "fmt" + "net" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +type HTTPSRedirectConfig struct { + Skipper middleware.Skipper + Port int + Code int +} + +var DefaultHTTPSRedirectConfig = HTTPSRedirectConfig{ + Skipper: middleware.DefaultSkipper, + Port: 443, + Code: http.StatusMovedPermanently, +} + +func HTTPSRedirect() echo.MiddlewareFunc { + return HTTPSRedirectWithConfig(DefaultHTTPSRedirectConfig) +} + +func HTTPSRedirectWithConfig(config HTTPSRedirectConfig) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultHTTPSRedirectConfig.Skipper + } + if config.Code == 0 { + config.Code = DefaultHTTPSRedirectConfig.Code + } + if config.Port == 0 { + config.Port = DefaultHTTPSRedirectConfig.Port + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) || c.Scheme() == "https" { + return next(c) + } + + var err error + req := c.Request() + + host := req.URL.Host + if host == "" { + host, _, err = net.SplitHostPort(req.Host) + if err != nil { + return echo.ErrBadRequest + } + } + + // Browers assume 443 if it's an https request, otherwise the port + // needs to be specified in the URL + redir := fmt.Sprintf("https://%s%s", host, req.RequestURI) + if config.Port != 443 { + redir = fmt.Sprintf("https://%s:%d%s", host, config.Port, req.RequestURI) + } + + return c.Redirect(http.StatusMovedPermanently, redir) + } + } +} diff --git a/echo/prometheus/prometheus.go b/echo/prometheus/prometheus.go index 616e425..2fbf252 100644 --- a/echo/prometheus/prometheus.go +++ b/echo/prometheus/prometheus.go @@ -14,7 +14,7 @@ import ( ) type Prometheus struct { - config *PrometheusConfig + Config *PrometheusConfig requestCount *prometheus.CounterVec requestDuration *prometheus.HistogramVec requestSize *prometheus.HistogramVec @@ -65,7 +65,7 @@ func NewPrometheusWithConfig(c *PrometheusConfig) *Prometheus { } return &Prometheus{ - config: c, + Config: c, MetricsHandler: echo.WrapHandler(promhttp.Handler()), requestCount: promauto.NewCounterVec(prometheus.CounterOpts{ Subsystem: c.Subsystem, @@ -92,11 +92,11 @@ func NewPrometheusWithConfig(c *PrometheusConfig) *Prometheus { func (p *Prometheus) MiddlewareHandler(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if c.Path() == p.config.MetricsPath { + if c.Path() == p.Config.MetricsPath { return next(c) } - if p.config.Skipper(c) { + if p.Config.Skipper(c) { return next(c) } @@ -116,9 +116,9 @@ func (p *Prometheus) MiddlewareHandler(next echo.HandlerFunc) echo.HandlerFunc { } } - url := p.config.ExtractUrl(c) - if len(p.config.ContextLabel) > 0 { - u := c.Get(p.config.ContextLabel) + url := p.Config.ExtractUrl(c) + if len(p.Config.ContextLabel) > 0 { + u := c.Get(p.Config.ContextLabel) if u == nil { u = "unknown" } @@ -128,7 +128,7 @@ func (p *Prometheus) MiddlewareHandler(next echo.HandlerFunc) echo.HandlerFunc { s := strconv.Itoa(status) m := c.Request().Method p.requestDuration.WithLabelValues(s, m, url).Observe(elapsed) - p.requestCount.WithLabelValues(s, m, p.config.ExtractHost(c), url).Inc() + p.requestCount.WithLabelValues(s, m, p.Config.ExtractHost(c), url).Inc() p.requestSize.WithLabelValues(s, m, url).Observe(float64(reqSz)) p.responseSize.WithLabelValues(s, m, url).Observe(float64(c.Response().Size)) -- cgit v1.2.3