package middleware

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"code.crute.us/mcrute/golib/clients/netbox"
	"github.com/labstack/echo/v4"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// netboxFilterFailures counts failed attempts by NetboxIPFilter.RunRefresh
// to fetch prefixes from Netbox.
var netboxFilterFailures = promauto.NewCounter(prometheus.CounterOpts{
	Name: "netbox_ip_filter_refresh_failures",
	Help: "Total number of failures refreshing netbox sourced IP ranges",
})

// ipFilter is an IP allow-list whose ranges may be replaced (via
// UpdateRanges) while requests are being served.
type ipFilter struct {
	sync.RWMutex
	allowLocalhost bool
	allowedRanges  []*net.IPNet
}

// UpdateRanges replaces the allowed ranges under the write lock. The slice
// is stored as-is and must not be modified by the caller afterwards.
func (f *ipFilter) UpdateRanges(r []*net.IPNet) {
	f.Lock()
	defer f.Unlock()
	f.allowedRanges = r
}

// ranges returns the current allow-list. UpdateRanges swaps the whole slice
// rather than mutating it, so the returned slice can be read after the lock
// is released.
func (f *ipFilter) ranges() []*net.IPNet {
	f.RLock()
	defer f.RUnlock()
	return f.allowedRanges
}

// isAllowed reports whether ip is a loopback address (when allowLocalhost
// is set) or falls inside any of the given networks.
func (f *ipFilter) isAllowed(ip net.IP, allowed []*net.IPNet) bool {
	// IsLoopback matches exactly 127.0.0.0/8 and ::1.
	if f.allowLocalhost && ip.IsLoopback() {
		return true
	}
	for _, n := range allowed {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}

// Middleware returns a handler that calls next only when the client IP (as
// reported by echo.Context.RealIP) is allowed by the filter. Requests are
// rejected with echo.ErrNotFound when no ranges are configured, when the IP
// cannot be parsed, or when the IP is not allowed.
func (f *ipFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Snapshot the ranges rather than holding the read lock across
		// next(c): a long-running handler would otherwise block
		// UpdateRanges, and a waiting writer blocks all new readers.
		allowed := f.ranges()
		if allowed == nil {
			c.Logger().Error("No allowed IPs configured for filter")
			return echo.ErrNotFound
		}

		ip := net.ParseIP(c.RealIP())
		if ip == nil {
			c.Logger().Error("Unable to parse IP in IPFilter")
			return echo.ErrNotFound
		}

		if !f.isAllowed(ip, allowed) {
			c.Logger().Errorf("IP %s not in range for filter", c.RealIP())
			return echo.ErrNotFound
		}

		return next(c)
	}
}

// NewIPFilter returns middleware that only admits requests whose client IP
// lies within allowedRanges. Loopback addresses get no special treatment,
// and a nil allowedRanges rejects every request.
func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc {
	return (&ipFilter{allowedRanges: allowedRanges}).Middleware
}

// NetboxIPFilter is an IP allow-list filter whose ranges are the Netbox
// prefixes tagged with Tag. Init must succeed before Middleware or
// RunRefresh is used.
type NetboxIPFilter struct {
	NetboxClient     netbox.NetboxClient
	Tag              string
	IncludeLocalhost bool // also admit loopback addresses
	Logger           echo.Logger

	f       *ipFilter
	hasInit bool
}

// Init fetches the prefixes tagged f.Tag from Netbox and builds the
// underlying filter. The Netbox client's error is returned unchanged.
func (f *NetboxIPFilter) Init() error {
	nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag)
	if err != nil {
		return err
	}
	f.Logger.Debugf("Got prefixes: %s", nets)

	f.f = &ipFilter{
		allowedRanges:  nets,
		allowLocalhost: f.IncludeLocalhost,
	}
	f.hasInit = true
	return nil
}

// Middleware wraps next with the Netbox-backed filter. If Init has not
// built the filter yet, the returned handler fails closed (logs and returns
// echo.ErrNotFound) instead of dereferencing a nil filter on every request.
func (f *NetboxIPFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
	if f.f == nil {
		return func(c echo.Context) error {
			c.Logger().Error("NetboxIPFilter: Middleware used before Init")
			return echo.ErrNotFound
		}
	}
	return f.f.Middleware(next)
}

// RunRefresh re-fetches the tagged prefixes from Netbox every hour and
// swaps them into the filter until ctx is cancelled, at which point it
// returns nil. A failed fetch is logged, counted in
// netboxFilterFailures, and the previous ranges are kept. It returns an
// error immediately if Init has not succeeded.
//
// RunRefresh calls wg.Add(1) itself. NOTE(review): when started as
// `go f.RunRefresh(...)`, that Add can race with a concurrent wg.Wait();
// confirm callers' ordering if a strict shutdown guarantee is needed.
func (f *NetboxIPFilter) RunRefresh(ctx context.Context, wg *sync.WaitGroup) error {
	wg.Add(1)
	defer wg.Done()

	if !f.hasInit {
		return fmt.Errorf("NetboxIPFilter: has not been initialized before RunRefresh called")
	}

	f.Logger.Info("Starting netbox IP address filter refresh loop")

	t := time.NewTicker(time.Hour)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			f.refresh()
		case <-ctx.Done():
			f.Logger.Info("Shutting down netbox IP address filter refresh loop")
			return nil
		}
	}
}

// refresh performs a single fetch-and-swap of the allowed prefixes.
func (f *NetboxIPFilter) refresh() {
	nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag)
	if err != nil {
		// %v, not %w: %w is only understood by fmt.Errorf; Sprintf-based
		// loggers render it as "%!w(...)" instead of the error text.
		f.Logger.Errorf("Error refreshing netbox prefixes for IP filter: %v", err)
		netboxFilterFailures.Inc()
		return
	}
	f.Logger.Debugf("Got prefixes: %s", nets)
	f.f.UpdateRanges(nets)
}