diff options
author | Mike Crute <mike@crute.us> | 2022-11-15 20:58:02 -0800 |
---|---|---|
committer | Mike Crute <mike@crute.us> | 2022-11-15 20:58:23 -0800 |
commit | 600378d36f109a5eccffb804134fb78f465768e9 (patch) | |
tree | f321a5a45625d6317102b2b8c44105a564ce035d | |
parent | 8df3d3d2ea65bf9053ad5191215bb01e29866498 (diff) | |
download | golib-600378d36f109a5eccffb804134fb78f465768e9.tar.bz2 golib-600378d36f109a5eccffb804134fb78f465768e9.tar.xz golib-600378d36f109a5eccffb804134fb78f465768e9.zip |
echo: add CSRF middleware
-rw-r--r-- | echo/go.mod | 2 | ||||
-rw-r--r-- | echo/middleware/csrf.go | 337 | ||||
-rw-r--r-- | echo/middleware/csrf_test.go | 47 |
3 files changed, 385 insertions, 1 deletions
diff --git a/echo/go.mod b/echo/go.mod index 2fe2e2b..d62094d 100644 --- a/echo/go.mod +++ b/echo/go.mod | |||
@@ -1,6 +1,6 @@ | |||
1 | module code.crute.us/mcrute/golib/echo | 1 | module code.crute.us/mcrute/golib/echo |
2 | 2 | ||
3 | go 1.17 | 3 | go 1.18 |
4 | 4 | ||
5 | replace code.crute.us/mcrute/golib => ../ | 5 | replace code.crute.us/mcrute/golib => ../ |
6 | 6 | ||
diff --git a/echo/middleware/csrf.go b/echo/middleware/csrf.go new file mode 100644 index 0000000..f709964 --- /dev/null +++ b/echo/middleware/csrf.go | |||
@@ -0,0 +1,337 @@ | |||
1 | package middleware | ||
2 | |||
3 | import ( | ||
4 | "crypto/rand" | ||
5 | "crypto/subtle" | ||
6 | "encoding/base64" | ||
7 | "fmt" | ||
8 | "net/http" | ||
9 | "net/url" | ||
10 | "strings" | ||
11 | |||
12 | "code.crute.us/mcrute/golib/echo/session" | ||
13 | "github.com/labstack/echo/v4" | ||
14 | "github.com/labstack/echo/v4/middleware" | ||
15 | ) | ||
16 | |||
17 | const ( | ||
18 | csrfTokenContextKey = "__golib_echo_csrfToken" | ||
19 | csrfSkipContextKey = "__golib_echo_skipCsrf" | ||
20 | ) | ||
21 | |||
22 | var ( | ||
23 | ErrInvalidOrigin = echo.NewHTTPError(http.StatusBadRequest, "invalid origin") | ||
24 | ErrInvalidReferer = echo.NewHTTPError(http.StatusBadRequest, "invalid referer") | ||
25 | ErrMissingReferer = echo.NewHTTPError(http.StatusBadRequest, "referer header is missing") | ||
26 | ErrCSRFTokenMissing = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is missing") | ||
27 | ErrCSRFTokenInvalid = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is invalid") | ||
28 | ) | ||
29 | |||
30 | type CSRFAwareSession interface { | ||
31 | session.Session | ||
32 | GetCSRFSecret() string | ||
33 | SetCSRFSecret(string) | ||
34 | } | ||
35 | |||
36 | func generateRandomBytes(l int) ([]byte, error) { | ||
37 | o := make([]byte, l) | ||
38 | if _, err := rand.Read(o); err != nil { | ||
39 | return nil, err | ||
40 | } | ||
41 | return o, nil | ||
42 | } | ||
43 | |||
44 | func xorBytes(a, b []byte) []byte { | ||
45 | n := len(a) | ||
46 | if len(b) < n { | ||
47 | n = len(b) | ||
48 | } | ||
49 | |||
50 | res := make([]byte, n) | ||
51 | for i := 0; i < n; i++ { | ||
52 | res[i] = a[i] ^ b[i] | ||
53 | } | ||
54 | |||
55 | return res | ||
56 | } | ||
57 | |||
58 | func mask(token []byte, size int) string { | ||
59 | sec, err := generateRandomBytes(size) | ||
60 | if err != nil { | ||
61 | return "" | ||
62 | } | ||
63 | val := append(sec, xorBytes(sec, token)...) | ||
64 | return base64.StdEncoding.EncodeToString(val) | ||
65 | } | ||
66 | |||
67 | func decodeAndMask(token string, size int) string { | ||
68 | t, err := base64.StdEncoding.DecodeString(token) | ||
69 | if err != nil { | ||
70 | return "" | ||
71 | } | ||
72 | return mask(t, size) | ||
73 | } | ||
74 | |||
75 | func unmask(token string, size int) ([]byte, error) { | ||
76 | t, err := base64.StdEncoding.DecodeString(token) | ||
77 | if err != nil { | ||
78 | return nil, err | ||
79 | } | ||
80 | |||
81 | if len(t) != size*2 { | ||
82 | return nil, fmt.Errorf("unmask: token length incorrect") | ||
83 | } | ||
84 | |||
85 | return xorBytes(t[size:], t[:size]), nil | ||
86 | } | ||
87 | |||
88 | func GetCSRFToken(c echo.Context) string { | ||
89 | return c.Get(csrfTokenContextKey).(string) | ||
90 | } | ||
91 | |||
92 | func DisableCSRFProtection() echo.MiddlewareFunc { | ||
93 | return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
94 | return func(c echo.Context) error { | ||
95 | c.Set(csrfSkipContextKey, true) | ||
96 | return next(c) | ||
97 | } | ||
98 | } | ||
99 | } | ||
100 | |||
101 | type CSRFConfig[T CSRFAwareSession] struct { | ||
102 | Skipper middleware.Skipper | ||
103 | HeaderName string | ||
104 | FieldName string | ||
105 | TokenSize int | ||
106 | SessionStore session.Store[T] | ||
107 | TrustedOrigins []string | ||
108 | } | ||
109 | |||
110 | func CSRFProtect[T CSRFAwareSession](store session.Store[T]) echo.MiddlewareFunc { | ||
111 | return CSRFProtectWithConfig(CSRFConfig[T]{ | ||
112 | Skipper: middleware.DefaultSkipper, | ||
113 | HeaderName: "X-CSRF-Token", | ||
114 | FieldName: "csrf-token", | ||
115 | TokenSize: 32, | ||
116 | SessionStore: store, | ||
117 | TrustedOrigins: []string{}, | ||
118 | }) | ||
119 | } | ||
120 | |||
121 | func CSRFProtectWithConfig[T CSRFAwareSession](cfg CSRFConfig[T]) echo.MiddlewareFunc { | ||
122 | if cfg.Skipper == nil { | ||
123 | cfg.Skipper = DefaultContentSecurityPolicyConfig.Skipper | ||
124 | } | ||
125 | if cfg.SessionStore == nil { | ||
126 | panic("CSRFProtectWithConfig: SessionStore must not be nil") | ||
127 | } | ||
128 | |||
129 | trustedOrigins := make([]*url.URL, len(cfg.TrustedOrigins)) | ||
130 | for i, u := range cfg.TrustedOrigins { | ||
131 | pu, err := url.Parse(u) | ||
132 | if err != nil { | ||
133 | // These are critical errors that should fail at server boot | ||
134 | panic(fmt.Errorf("Error parsing TrustedOrigins URL: %w", err)) | ||
135 | } | ||
136 | trustedOrigins[i] = pu | ||
137 | } | ||
138 | |||
139 | return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
140 | return func(c echo.Context) error { | ||
141 | if cfg.Skipper(c) { | ||
142 | return next(c) | ||
143 | } | ||
144 | |||
145 | if skip, ok := c.Get(csrfSkipContextKey).(bool); ok && skip { | ||
146 | return next(c) | ||
147 | } | ||
148 | |||
149 | // If there's no CSRF token in the session then set one for future | ||
150 | // requests, even if this request is safe. | ||
151 | session := cfg.SessionStore.Get(c) | ||
152 | if session.GetCSRFSecret() == "" { | ||
153 | tok, err := generateRandomBytes(cfg.TokenSize) | ||
154 | if err != nil { | ||
155 | c.Logger().Errorf("CSRF: error generating token: %w", err) | ||
156 | return echo.ErrInternalServerError | ||
157 | } | ||
158 | session.SetCSRFSecret(base64.StdEncoding.EncodeToString(tok)) | ||
159 | } | ||
160 | |||
161 | // Stash this in the request context so that other things can get to | ||
162 | // it without needing a direct session dependency. | ||
163 | c.Set(csrfTokenContextKey, decodeAndMask(session.GetCSRFSecret(), cfg.TokenSize)) | ||
164 | |||
165 | // These methods are considered safe and do not require CRSF | ||
166 | // protection. | ||
167 | switch c.Request().Method { | ||
168 | case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: | ||
169 | return next(c) | ||
170 | } | ||
171 | |||
172 | if origin := c.Request().Header.Get("Origin"); origin != "" { | ||
173 | if err := validateOrigin(c, origin, trustedOrigins); err != nil { | ||
174 | return err | ||
175 | } | ||
176 | } else if c.Request().URL.Scheme == "https" { | ||
177 | // If the Origin header wasn't provided, reject HTTPS requests if | ||
178 | // the Referer header doesn't match an allowed value. | ||
179 | // | ||
180 | // Suppose user visits http://example.com/ | ||
181 | // An active network attacker (man-in-the-middle, MITM) sends a | ||
182 | // POST form that targets https://example.com/detonate-bomb/ and | ||
183 | // submits it via JavaScript. | ||
184 | // | ||
185 | // The attacker will need to provide a CSRF cookie and token, but | ||
186 | // that's no problem for a MITM and the session-independent secret | ||
187 | // we're using. So the MITM can circumvent the CSRF protection. This | ||
188 | // is true for any HTTP connection, but anyone using HTTPS expects | ||
189 | // better! For this reason, for https://example.com/ we need | ||
190 | // additional protection that treats http://example.com/ as | ||
191 | // completely untrusted. Under HTTPS, Barth et al. found that the | ||
192 | // Referer header is missing for same-domain requests in only about | ||
193 | // 0.2% of cases or less, so we can use strict Referer checking. | ||
194 | if err := validateReferer(c, trustedOrigins); err != nil { | ||
195 | return err | ||
196 | } | ||
197 | } | ||
198 | |||
199 | requestToken := getRequestToken(c.Request(), cfg.FieldName, cfg.HeaderName) | ||
200 | if requestToken == "" { | ||
201 | return ErrCSRFTokenMissing | ||
202 | } | ||
203 | |||
204 | unmasked, err := unmask(requestToken, cfg.TokenSize) | ||
205 | if err != nil { | ||
206 | c.Logger().Debugf("CSRF: error unmasking request token: %w", err) | ||
207 | return ErrCSRFTokenInvalid | ||
208 | } | ||
209 | |||
210 | sessionToken, err := base64.StdEncoding.DecodeString(session.GetCSRFSecret()) | ||
211 | if err != nil { | ||
212 | c.Logger().Debugf("CSRF: error decoding session token: %w", err) | ||
213 | return ErrCSRFTokenInvalid | ||
214 | } | ||
215 | |||
216 | if subtle.ConstantTimeCompare(unmasked, sessionToken) != 1 { | ||
217 | c.Logger().Debug("CSRF: session token and request token are not equal") | ||
218 | return ErrCSRFTokenInvalid | ||
219 | } | ||
220 | |||
221 | c.Response().Header().Add("Vary", "Cookie") | ||
222 | |||
223 | return next(c) | ||
224 | } | ||
225 | } | ||
226 | } | ||
227 | |||
228 | func getRequestToken(r *http.Request, fieldName, headerName string) string { | ||
229 | token := r.Header.Get(headerName) | ||
230 | if token == "" { | ||
231 | token = r.PostFormValue(fieldName) | ||
232 | } | ||
233 | if token == "" && r.MultipartForm != nil { | ||
234 | vals := r.MultipartForm.Value[fieldName] | ||
235 | if len(vals) > 0 { | ||
236 | token = vals[0] | ||
237 | } | ||
238 | } | ||
239 | return token | ||
240 | } | ||
241 | |||
242 | func validateReferer(c echo.Context, trusted []*url.URL) error { | ||
243 | referer := c.Request().Referer() | ||
244 | if referer == "" { | ||
245 | c.Logger().Debugf("CSRF: no referer in request") | ||
246 | return ErrMissingReferer | ||
247 | } | ||
248 | |||
249 | u, err := url.Parse(referer) | ||
250 | if err != nil { | ||
251 | c.Logger().Debugf("CSRF: error parsing referer: %w", err) | ||
252 | return ErrInvalidReferer | ||
253 | } | ||
254 | |||
255 | if u.Host == "" { | ||
256 | c.Logger().Debugf("CSRF: no host in referer") | ||
257 | return ErrInvalidReferer | ||
258 | } | ||
259 | |||
260 | if u.Scheme != "https" { | ||
261 | c.Logger().Debugf("CSRF: referer is not https") | ||
262 | return ErrInvalidReferer | ||
263 | } | ||
264 | |||
265 | for _, o := range trusted { | ||
266 | if isSameDomain(u, o) { | ||
267 | return nil | ||
268 | } | ||
269 | } | ||
270 | |||
271 | return ErrInvalidReferer | ||
272 | } | ||
273 | |||
274 | func validateOrigin(c echo.Context, origin string, trusted []*url.URL) error { | ||
275 | parsedOrigin, err := url.Parse(origin) | ||
276 | if err != nil { | ||
277 | c.Logger().Debugf("CSRF: Error parsing origin: %w", err) | ||
278 | return ErrInvalidOrigin | ||
279 | } | ||
280 | |||
281 | thisHost := &url.URL{ | ||
282 | Scheme: c.Request().URL.Scheme, | ||
283 | Host: c.Request().Host, | ||
284 | } | ||
285 | |||
286 | if thisHost.Scheme == "" { | ||
287 | if c.Request().TLS != nil { | ||
288 | thisHost.Scheme = "https" | ||
289 | } else { | ||
290 | thisHost.Scheme = "http" | ||
291 | } | ||
292 | } | ||
293 | |||
294 | if isSameDomain(parsedOrigin, thisHost) { | ||
295 | return nil | ||
296 | } | ||
297 | |||
298 | for _, to := range trusted { | ||
299 | if isSameDomain(parsedOrigin, to) { | ||
300 | return nil | ||
301 | } | ||
302 | } | ||
303 | |||
304 | return ErrInvalidOrigin | ||
305 | } | ||
306 | |||
307 | // isSameDomain checks that the host is either an exact match for the | ||
308 | // pattern, or in the case that a pattern starts with a dot, that the | ||
309 | // host has a suffix of that pattern to allow for subdomain matches | ||
310 | // (e.g. .example.com matches example.com and foo.example.com) | ||
311 | func isSameDomain(host, pattern *url.URL) bool { | ||
312 | if host == nil || pattern == nil { | ||
313 | return false | ||
314 | } | ||
315 | |||
316 | if host.Scheme != pattern.Scheme { | ||
317 | return false | ||
318 | } | ||
319 | |||
320 | host.Host = strings.ToLower(host.Host) | ||
321 | pattern.Host = strings.ToLower(pattern.Host) | ||
322 | |||
323 | if pattern.Host == host.Host { | ||
324 | return true | ||
325 | } | ||
326 | |||
327 | if strings.HasPrefix(pattern.Host, "*") { | ||
328 | if strings.HasSuffix(host.Host, pattern.Host[1:]) { | ||
329 | return true | ||
330 | } | ||
331 | if host.Host == pattern.Host[2:] { | ||
332 | return true | ||
333 | } | ||
334 | } | ||
335 | |||
336 | return false | ||
337 | } | ||
diff --git a/echo/middleware/csrf_test.go b/echo/middleware/csrf_test.go new file mode 100644 index 0000000..6ac1d35 --- /dev/null +++ b/echo/middleware/csrf_test.go | |||
@@ -0,0 +1,47 @@ | |||
1 | package middleware | ||
2 | |||
3 | import ( | ||
4 | "net/url" | ||
5 | "testing" | ||
6 | |||
7 | "github.com/stretchr/testify/assert" | ||
8 | ) | ||
9 | |||
10 | func TestIsSameDomain(t *testing.T) { | ||
11 | tests := []struct { | ||
12 | host string | ||
13 | pattern string | ||
14 | expectedMatch bool | ||
15 | }{ | ||
16 | {"http://example.com", "http://EXAMPLE.com", true}, // Case difference | ||
17 | {"http://EXAMPLE.com", "http://EXAMPLE.com", true}, // Case difference | ||
18 | {"http://example.com", "http://example.com", true}, // Exact match | ||
19 | {"http://example.com", "http://*.example.com", true}, // Star match | ||
20 | {"http://www.example.com", "http://*.example.com", true}, // Subdomain match | ||
21 | {"http://www.foo.example.com", "http://*.example.com", true}, // Sub-subdomain match | ||
22 | {"https://example.com", "http://example.com", false}, // Scheme differs | ||
23 | {"http://www.example.com", "http://example.com", false}, // Subdomain vs exact | ||
24 | {"http://fooexample.com", "http://example.com", false}, // Similar suffix | ||
25 | {"http://fooexample.com", "http://*.example.com", false}, // Similar suffix wildcard | ||
26 | } | ||
27 | |||
28 | for _, tc := range tests { | ||
29 | host, err := url.Parse(tc.host) | ||
30 | assert.NoError(t, err) | ||
31 | |||
32 | pattern, err := url.Parse(tc.pattern) | ||
33 | assert.NoError(t, err) | ||
34 | |||
35 | if tc.expectedMatch { | ||
36 | assert.True(t, isSameDomain(host, pattern), "%s does not match %s but should", tc.host, tc.pattern) | ||
37 | } else { | ||
38 | assert.False(t, isSameDomain(host, pattern), "%s matches %s but should not", tc.host, tc.pattern) | ||
39 | } | ||
40 | } | ||
41 | |||
42 | testUrl, err := url.Parse("http://example.com") | ||
43 | assert.NoError(t, err) | ||
44 | assert.False(t, isSameDomain(nil, nil), "nil should never match") | ||
45 | assert.False(t, isSameDomain(testUrl, nil), "nil should never match") | ||
46 | assert.False(t, isSameDomain(nil, testUrl), "nil should never match") | ||
47 | } | ||