diff options
Diffstat (limited to 'echo/middleware/ip_filter.go')
-rw-r--r-- | echo/middleware/ip_filter.go | 141 |
1 files changed, 118 insertions, 23 deletions
diff --git a/echo/middleware/ip_filter.go b/echo/middleware/ip_filter.go index 007791e..2d79925 100644 --- a/echo/middleware/ip_filter.go +++ b/echo/middleware/ip_filter.go | |||
@@ -1,39 +1,134 @@ | |||
1 | package middleware | 1 | package middleware |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "context" | ||
5 | "fmt" | ||
4 | "net" | 6 | "net" |
7 | "sync" | ||
8 | "time" | ||
5 | 9 | ||
10 | "code.crute.us/mcrute/golib/echo/netbox" | ||
6 | "github.com/labstack/echo/v4" | 11 | "github.com/labstack/echo/v4" |
12 | "github.com/prometheus/client_golang/prometheus" | ||
13 | "github.com/prometheus/client_golang/prometheus/promauto" | ||
7 | ) | 14 | ) |
8 | 15 | ||
9 | func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc { | 16 | var netboxFilterFailures = promauto.NewCounter(prometheus.CounterOpts{ |
10 | return func(next echo.HandlerFunc) echo.HandlerFunc { | 17 | Name: "netbox_ip_filter_refresh_failures", |
11 | return func(c echo.Context) error { | 18 | Help: "Total number of failures refreshing netbox sourced IP ranges", |
12 | if allowedRanges == nil { | 19 | }) |
13 | c.Logger().Error("No allowed IPs configured for filter") | ||
14 | return echo.ErrNotFound | ||
15 | } | ||
16 | 20 | ||
17 | ip := net.ParseIP(c.RealIP()) | 21 | type ipFilter struct { |
18 | if ip == nil { | 22 | sync.RWMutex |
19 | c.Logger().Error("Unable to parse IP in IPFilter") | 23 | allowLocalhost bool |
20 | return echo.ErrNotFound | 24 | allowedRanges []*net.IPNet |
21 | } | 25 | } |
22 | 26 | ||
23 | found := false | 27 | func (f *ipFilter) UpdateRanges(r []*net.IPNet) { |
24 | for _, ipnet := range allowedRanges { | 28 | f.Lock() |
25 | if ipnet.Contains(ip) { | 29 | defer f.Unlock() |
26 | found = true | 30 | f.allowedRanges = r |
27 | break | 31 | } |
28 | } | 32 | |
29 | } | 33 | func (f *ipFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc { |
34 | _, v4Localhost, _ := net.ParseCIDR("127.0.0.0/8") | ||
35 | _, v6Localhost, _ := net.ParseCIDR("::1/128") | ||
36 | |||
37 | return func(c echo.Context) error { | ||
38 | f.RLock() | ||
39 | defer f.RUnlock() | ||
40 | |||
41 | if f.allowedRanges == nil { | ||
42 | c.Logger().Error("No allowed IPs configured for filter") | ||
43 | return echo.ErrNotFound | ||
44 | } | ||
30 | 45 | ||
31 | if !found { | 46 | ip := net.ParseIP(c.RealIP()) |
32 | c.Logger().Errorf("IP %s not in range for filter", c.RealIP()) | 47 | if ip == nil { |
33 | return echo.ErrNotFound | 48 | c.Logger().Error("Unable to parse IP in IPFilter") |
49 | return echo.ErrNotFound | ||
50 | } | ||
51 | |||
52 | found := false | ||
53 | for _, ipnet := range f.allowedRanges { | ||
54 | if ipnet.Contains(ip) { | ||
55 | found = true | ||
56 | break | ||
34 | } | 57 | } |
58 | } | ||
59 | |||
60 | if f.allowLocalhost && (v4Localhost.Contains(ip) || v6Localhost.Contains(ip)) { | ||
61 | found = true | ||
62 | } | ||
63 | |||
64 | if !found { | ||
65 | c.Logger().Errorf("IP %s not in range for filter", c.RealIP()) | ||
66 | return echo.ErrNotFound | ||
67 | } | ||
68 | |||
69 | return next(c) | ||
70 | } | ||
71 | } | ||
35 | 72 | ||
36 | return next(c) | 73 | func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc { |
74 | return (&ipFilter{allowedRanges: allowedRanges}).Middleware | ||
75 | } | ||
76 | |||
77 | type NetboxIPFilter struct { | ||
78 | NetboxClient netbox.NetboxClient | ||
79 | Tag string | ||
80 | IncludeLocalhost bool | ||
81 | Logger echo.Logger | ||
82 | f *ipFilter | ||
83 | hasInit bool | ||
84 | } | ||
85 | |||
86 | func (f *NetboxIPFilter) Init() error { | ||
87 | nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag) | ||
88 | if err != nil { | ||
89 | return err | ||
90 | } | ||
91 | f.Logger.Debugf("Got prefixes: %s", nets) | ||
92 | |||
93 | f.f = &ipFilter{ | ||
94 | allowedRanges: nets, | ||
95 | allowLocalhost: f.IncludeLocalhost, | ||
96 | } | ||
97 | f.hasInit = true | ||
98 | |||
99 | return nil | ||
100 | } | ||
101 | |||
102 | func (f *NetboxIPFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc { | ||
103 | return f.f.Middleware(next) | ||
104 | } | ||
105 | |||
106 | func (f *NetboxIPFilter) RunRefresh(c context.Context, wg *sync.WaitGroup) error { | ||
107 | wg.Add(1) | ||
108 | defer wg.Done() | ||
109 | |||
110 | if !f.hasInit { | ||
111 | return fmt.Errorf("NetboxIPFilter: has not been initialized before RunRefresh called") | ||
112 | } | ||
113 | |||
114 | f.Logger.Info("Starting netbox IP address filter refresh loop") | ||
115 | |||
116 | t := time.NewTicker(time.Hour) | ||
117 | defer t.Stop() | ||
118 | |||
119 | for { | ||
120 | select { | ||
121 | case <-t.C: | ||
122 | if nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag); err != nil { | ||
123 | f.Logger.Errorf("Error refreshing netbox prefixes for IP filter: %w", err) | ||
124 | netboxFilterFailures.Inc() | ||
125 | } else { | ||
126 | f.Logger.Debugf("Got prefixes: %s", nets) | ||
127 | f.f.UpdateRanges(nets) | ||
128 | } | ||
129 | case <-c.Done(): | ||
130 | f.Logger.Info("Shutting down netbox IP address filter refresh loop") | ||
131 | return nil | ||
37 | } | 132 | } |
38 | } | 133 | } |
39 | } | 134 | } |