aboutsummaryrefslogtreecommitdiff
path: root/echo/middleware/ip_filter.go
diff options
context:
space:
mode:
Diffstat (limited to 'echo/middleware/ip_filter.go')
-rw-r--r--echo/middleware/ip_filter.go141
1 files changed, 118 insertions, 23 deletions
diff --git a/echo/middleware/ip_filter.go b/echo/middleware/ip_filter.go
index 007791e..2d79925 100644
--- a/echo/middleware/ip_filter.go
+++ b/echo/middleware/ip_filter.go
@@ -1,39 +1,134 @@
1package middleware 1package middleware
2 2
3import ( 3import (
4 "context"
5 "fmt"
4 "net" 6 "net"
7 "sync"
8 "time"
5 9
10 "code.crute.us/mcrute/golib/echo/netbox"
6 "github.com/labstack/echo/v4" 11 "github.com/labstack/echo/v4"
12 "github.com/prometheus/client_golang/prometheus"
13 "github.com/prometheus/client_golang/prometheus/promauto"
7) 14)
8 15
9func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc { 16var netboxFilterFailures = promauto.NewCounter(prometheus.CounterOpts{
10 return func(next echo.HandlerFunc) echo.HandlerFunc { 17 Name: "netbox_ip_filter_refresh_failures",
11 return func(c echo.Context) error { 18 Help: "Total number of failures refreshing netbox sourced IP ranges",
12 if allowedRanges == nil { 19})
13 c.Logger().Error("No allowed IPs configured for filter")
14 return echo.ErrNotFound
15 }
16 20
17 ip := net.ParseIP(c.RealIP()) 21type ipFilter struct {
18 if ip == nil { 22 sync.RWMutex
19 c.Logger().Error("Unable to parse IP in IPFilter") 23 allowLocalhost bool
20 return echo.ErrNotFound 24 allowedRanges []*net.IPNet
21 } 25}
22 26
23 found := false 27func (f *ipFilter) UpdateRanges(r []*net.IPNet) {
24 for _, ipnet := range allowedRanges { 28 f.Lock()
25 if ipnet.Contains(ip) { 29 defer f.Unlock()
26 found = true 30 f.allowedRanges = r
27 break 31}
28 } 32
29 } 33func (f *ipFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
34 _, v4Localhost, _ := net.ParseCIDR("127.0.0.0/8")
35 _, v6Localhost, _ := net.ParseCIDR("::1/128")
36
37 return func(c echo.Context) error {
38 f.RLock()
39 defer f.RUnlock()
40
41 if f.allowedRanges == nil {
42 c.Logger().Error("No allowed IPs configured for filter")
43 return echo.ErrNotFound
44 }
30 45
31 if !found { 46 ip := net.ParseIP(c.RealIP())
32 c.Logger().Errorf("IP %s not in range for filter", c.RealIP()) 47 if ip == nil {
33 return echo.ErrNotFound 48 c.Logger().Error("Unable to parse IP in IPFilter")
49 return echo.ErrNotFound
50 }
51
52 found := false
53 for _, ipnet := range f.allowedRanges {
54 if ipnet.Contains(ip) {
55 found = true
56 break
34 } 57 }
58 }
59
60 if f.allowLocalhost && (v4Localhost.Contains(ip) || v6Localhost.Contains(ip)) {
61 found = true
62 }
63
64 if !found {
65 c.Logger().Errorf("IP %s not in range for filter", c.RealIP())
66 return echo.ErrNotFound
67 }
68
69 return next(c)
70 }
71}
35 72
36 return next(c) 73func NewIPFilter(allowedRanges []*net.IPNet) echo.MiddlewareFunc {
74 return (&ipFilter{allowedRanges: allowedRanges}).Middleware
75}
76
77type NetboxIPFilter struct {
78 NetboxClient netbox.NetboxClient
79 Tag string
80 IncludeLocalhost bool
81 Logger echo.Logger
82 f *ipFilter
83 hasInit bool
84}
85
86func (f *NetboxIPFilter) Init() error {
87 nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag)
88 if err != nil {
89 return err
90 }
91 f.Logger.Debugf("Got prefixes: %s", nets)
92
93 f.f = &ipFilter{
94 allowedRanges: nets,
95 allowLocalhost: f.IncludeLocalhost,
96 }
97 f.hasInit = true
98
99 return nil
100}
101
102func (f *NetboxIPFilter) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
103 return f.f.Middleware(next)
104}
105
106func (f *NetboxIPFilter) RunRefresh(c context.Context, wg *sync.WaitGroup) error {
107 wg.Add(1)
108 defer wg.Done()
109
110 if !f.hasInit {
111 return fmt.Errorf("NetboxIPFilter: has not been initialized before RunRefresh called")
112 }
113
114 f.Logger.Info("Starting netbox IP address filter refresh loop")
115
116 t := time.NewTicker(time.Hour)
117 defer t.Stop()
118
119 for {
120 select {
121 case <-t.C:
122 if nets, err := f.NetboxClient.GetPrefixesWithTag(f.Tag); err != nil {
123 f.Logger.Errorf("Error refreshing netbox prefixes for IP filter: %w", err)
124 netboxFilterFailures.Inc()
125 } else {
126 f.Logger.Debugf("Got prefixes: %s", nets)
127 f.f.UpdateRanges(nets)
128 }
129 case <-c.Done():
130 f.Logger.Info("Shutting down netbox IP address filter refresh loop")
131 return nil
37 } 132 }
38 } 133 }
39} 134}