aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2022-11-15 20:58:02 -0800
committerMike Crute <mike@crute.us>2022-11-15 20:58:23 -0800
commit600378d36f109a5eccffb804134fb78f465768e9 (patch)
treef321a5a45625d6317102b2b8c44105a564ce035d
parent8df3d3d2ea65bf9053ad5191215bb01e29866498 (diff)
downloadgolib-600378d36f109a5eccffb804134fb78f465768e9.tar.bz2
golib-600378d36f109a5eccffb804134fb78f465768e9.tar.xz
golib-600378d36f109a5eccffb804134fb78f465768e9.zip
echo: add CSRF middleware
-rw-r--r--echo/go.mod2
-rw-r--r--echo/middleware/csrf.go337
-rw-r--r--echo/middleware/csrf_test.go47
3 files changed, 385 insertions, 1 deletions
diff --git a/echo/go.mod b/echo/go.mod
index 2fe2e2b..d62094d 100644
--- a/echo/go.mod
+++ b/echo/go.mod
@@ -1,6 +1,6 @@
1module code.crute.us/mcrute/golib/echo 1module code.crute.us/mcrute/golib/echo
2 2
3go 1.17 3go 1.18
4 4
5replace code.crute.us/mcrute/golib => ../ 5replace code.crute.us/mcrute/golib => ../
6 6
diff --git a/echo/middleware/csrf.go b/echo/middleware/csrf.go
new file mode 100644
index 0000000..f709964
--- /dev/null
+++ b/echo/middleware/csrf.go
@@ -0,0 +1,337 @@
1package middleware
2
3import (
4 "crypto/rand"
5 "crypto/subtle"
6 "encoding/base64"
7 "fmt"
8 "net/http"
9 "net/url"
10 "strings"
11
12 "code.crute.us/mcrute/golib/echo/session"
13 "github.com/labstack/echo/v4"
14 "github.com/labstack/echo/v4/middleware"
15)
16
17const (
18 csrfTokenContextKey = "__golib_echo_csrfToken"
19 csrfSkipContextKey = "__golib_echo_skipCsrf"
20)
21
22var (
23 ErrInvalidOrigin = echo.NewHTTPError(http.StatusBadRequest, "invalid origin")
24 ErrInvalidReferer = echo.NewHTTPError(http.StatusBadRequest, "invalid referer")
25 ErrMissingReferer = echo.NewHTTPError(http.StatusBadRequest, "referer header is missing")
26 ErrCSRFTokenMissing = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is missing")
27 ErrCSRFTokenInvalid = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is invalid")
28)
29
30type CSRFAwareSession interface {
31 session.Session
32 GetCSRFSecret() string
33 SetCSRFSecret(string)
34}
35
36func generateRandomBytes(l int) ([]byte, error) {
37 o := make([]byte, l)
38 if _, err := rand.Read(o); err != nil {
39 return nil, err
40 }
41 return o, nil
42}
43
44func xorBytes(a, b []byte) []byte {
45 n := len(a)
46 if len(b) < n {
47 n = len(b)
48 }
49
50 res := make([]byte, n)
51 for i := 0; i < n; i++ {
52 res[i] = a[i] ^ b[i]
53 }
54
55 return res
56}
57
58func mask(token []byte, size int) string {
59 sec, err := generateRandomBytes(size)
60 if err != nil {
61 return ""
62 }
63 val := append(sec, xorBytes(sec, token)...)
64 return base64.StdEncoding.EncodeToString(val)
65}
66
67func decodeAndMask(token string, size int) string {
68 t, err := base64.StdEncoding.DecodeString(token)
69 if err != nil {
70 return ""
71 }
72 return mask(t, size)
73}
74
75func unmask(token string, size int) ([]byte, error) {
76 t, err := base64.StdEncoding.DecodeString(token)
77 if err != nil {
78 return nil, err
79 }
80
81 if len(t) != size*2 {
82 return nil, fmt.Errorf("unmask: token length incorrect")
83 }
84
85 return xorBytes(t[size:], t[:size]), nil
86}
87
88func GetCSRFToken(c echo.Context) string {
89 return c.Get(csrfTokenContextKey).(string)
90}
91
92func DisableCSRFProtection() echo.MiddlewareFunc {
93 return func(next echo.HandlerFunc) echo.HandlerFunc {
94 return func(c echo.Context) error {
95 c.Set(csrfSkipContextKey, true)
96 return next(c)
97 }
98 }
99}
100
101type CSRFConfig[T CSRFAwareSession] struct {
102 Skipper middleware.Skipper
103 HeaderName string
104 FieldName string
105 TokenSize int
106 SessionStore session.Store[T]
107 TrustedOrigins []string
108}
109
110func CSRFProtect[T CSRFAwareSession](store session.Store[T]) echo.MiddlewareFunc {
111 return CSRFProtectWithConfig(CSRFConfig[T]{
112 Skipper: middleware.DefaultSkipper,
113 HeaderName: "X-CSRF-Token",
114 FieldName: "csrf-token",
115 TokenSize: 32,
116 SessionStore: store,
117 TrustedOrigins: []string{},
118 })
119}
120
121func CSRFProtectWithConfig[T CSRFAwareSession](cfg CSRFConfig[T]) echo.MiddlewareFunc {
122 if cfg.Skipper == nil {
123 cfg.Skipper = DefaultContentSecurityPolicyConfig.Skipper
124 }
125 if cfg.SessionStore == nil {
126 panic("CSRFProtectWithConfig: SessionStore must not be nil")
127 }
128
129 trustedOrigins := make([]*url.URL, len(cfg.TrustedOrigins))
130 for i, u := range cfg.TrustedOrigins {
131 pu, err := url.Parse(u)
132 if err != nil {
133 // These are critical errors that should fail at server boot
134 panic(fmt.Errorf("Error parsing TrustedOrigins URL: %w", err))
135 }
136 trustedOrigins[i] = pu
137 }
138
139 return func(next echo.HandlerFunc) echo.HandlerFunc {
140 return func(c echo.Context) error {
141 if cfg.Skipper(c) {
142 return next(c)
143 }
144
145 if skip, ok := c.Get(csrfSkipContextKey).(bool); ok && skip {
146 return next(c)
147 }
148
149 // If there's no CSRF token in the session then set one for future
150 // requests, even if this request is safe.
151 session := cfg.SessionStore.Get(c)
152 if session.GetCSRFSecret() == "" {
153 tok, err := generateRandomBytes(cfg.TokenSize)
154 if err != nil {
155 c.Logger().Errorf("CSRF: error generating token: %w", err)
156 return echo.ErrInternalServerError
157 }
158 session.SetCSRFSecret(base64.StdEncoding.EncodeToString(tok))
159 }
160
161 // Stash this in the request context so that other things can get to
162 // it without needing a direct session dependency.
163 c.Set(csrfTokenContextKey, decodeAndMask(session.GetCSRFSecret(), cfg.TokenSize))
164
165 // These methods are considered safe and do not require CRSF
166 // protection.
167 switch c.Request().Method {
168 case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
169 return next(c)
170 }
171
172 if origin := c.Request().Header.Get("Origin"); origin != "" {
173 if err := validateOrigin(c, origin, trustedOrigins); err != nil {
174 return err
175 }
176 } else if c.Request().URL.Scheme == "https" {
177 // If the Origin header wasn't provided, reject HTTPS requests if
178 // the Referer header doesn't match an allowed value.
179 //
180 // Suppose user visits http://example.com/
181 // An active network attacker (man-in-the-middle, MITM) sends a
182 // POST form that targets https://example.com/detonate-bomb/ and
183 // submits it via JavaScript.
184 //
185 // The attacker will need to provide a CSRF cookie and token, but
186 // that's no problem for a MITM and the session-independent secret
187 // we're using. So the MITM can circumvent the CSRF protection. This
188 // is true for any HTTP connection, but anyone using HTTPS expects
189 // better! For this reason, for https://example.com/ we need
190 // additional protection that treats http://example.com/ as
191 // completely untrusted. Under HTTPS, Barth et al. found that the
192 // Referer header is missing for same-domain requests in only about
193 // 0.2% of cases or less, so we can use strict Referer checking.
194 if err := validateReferer(c, trustedOrigins); err != nil {
195 return err
196 }
197 }
198
199 requestToken := getRequestToken(c.Request(), cfg.FieldName, cfg.HeaderName)
200 if requestToken == "" {
201 return ErrCSRFTokenMissing
202 }
203
204 unmasked, err := unmask(requestToken, cfg.TokenSize)
205 if err != nil {
206 c.Logger().Debugf("CSRF: error unmasking request token: %w", err)
207 return ErrCSRFTokenInvalid
208 }
209
210 sessionToken, err := base64.StdEncoding.DecodeString(session.GetCSRFSecret())
211 if err != nil {
212 c.Logger().Debugf("CSRF: error decoding session token: %w", err)
213 return ErrCSRFTokenInvalid
214 }
215
216 if subtle.ConstantTimeCompare(unmasked, sessionToken) != 1 {
217 c.Logger().Debug("CSRF: session token and request token are not equal")
218 return ErrCSRFTokenInvalid
219 }
220
221 c.Response().Header().Add("Vary", "Cookie")
222
223 return next(c)
224 }
225 }
226}
227
228func getRequestToken(r *http.Request, fieldName, headerName string) string {
229 token := r.Header.Get(headerName)
230 if token == "" {
231 token = r.PostFormValue(fieldName)
232 }
233 if token == "" && r.MultipartForm != nil {
234 vals := r.MultipartForm.Value[fieldName]
235 if len(vals) > 0 {
236 token = vals[0]
237 }
238 }
239 return token
240}
241
242func validateReferer(c echo.Context, trusted []*url.URL) error {
243 referer := c.Request().Referer()
244 if referer == "" {
245 c.Logger().Debugf("CSRF: no referer in request")
246 return ErrMissingReferer
247 }
248
249 u, err := url.Parse(referer)
250 if err != nil {
251 c.Logger().Debugf("CSRF: error parsing referer: %w", err)
252 return ErrInvalidReferer
253 }
254
255 if u.Host == "" {
256 c.Logger().Debugf("CSRF: no host in referer")
257 return ErrInvalidReferer
258 }
259
260 if u.Scheme != "https" {
261 c.Logger().Debugf("CSRF: referer is not https")
262 return ErrInvalidReferer
263 }
264
265 for _, o := range trusted {
266 if isSameDomain(u, o) {
267 return nil
268 }
269 }
270
271 return ErrInvalidReferer
272}
273
274func validateOrigin(c echo.Context, origin string, trusted []*url.URL) error {
275 parsedOrigin, err := url.Parse(origin)
276 if err != nil {
277 c.Logger().Debugf("CSRF: Error parsing origin: %w", err)
278 return ErrInvalidOrigin
279 }
280
281 thisHost := &url.URL{
282 Scheme: c.Request().URL.Scheme,
283 Host: c.Request().Host,
284 }
285
286 if thisHost.Scheme == "" {
287 if c.Request().TLS != nil {
288 thisHost.Scheme = "https"
289 } else {
290 thisHost.Scheme = "http"
291 }
292 }
293
294 if isSameDomain(parsedOrigin, thisHost) {
295 return nil
296 }
297
298 for _, to := range trusted {
299 if isSameDomain(parsedOrigin, to) {
300 return nil
301 }
302 }
303
304 return ErrInvalidOrigin
305}
306
307// isSameDomain checks that the host is either an exact match for the
308// pattern, or in the case that a pattern starts with a dot, that the
309// host has a suffix of that pattern to allow for subdomain matches
310// (e.g. .example.com matches example.com and foo.example.com)
311func isSameDomain(host, pattern *url.URL) bool {
312 if host == nil || pattern == nil {
313 return false
314 }
315
316 if host.Scheme != pattern.Scheme {
317 return false
318 }
319
320 host.Host = strings.ToLower(host.Host)
321 pattern.Host = strings.ToLower(pattern.Host)
322
323 if pattern.Host == host.Host {
324 return true
325 }
326
327 if strings.HasPrefix(pattern.Host, "*") {
328 if strings.HasSuffix(host.Host, pattern.Host[1:]) {
329 return true
330 }
331 if host.Host == pattern.Host[2:] {
332 return true
333 }
334 }
335
336 return false
337}
diff --git a/echo/middleware/csrf_test.go b/echo/middleware/csrf_test.go
new file mode 100644
index 0000000..6ac1d35
--- /dev/null
+++ b/echo/middleware/csrf_test.go
@@ -0,0 +1,47 @@
1package middleware
2
3import (
4 "net/url"
5 "testing"
6
7 "github.com/stretchr/testify/assert"
8)
9
10func TestIsSameDomain(t *testing.T) {
11 tests := []struct {
12 host string
13 pattern string
14 expectedMatch bool
15 }{
16 {"http://example.com", "http://EXAMPLE.com", true}, // Case difference
17 {"http://EXAMPLE.com", "http://EXAMPLE.com", true}, // Case difference
18 {"http://example.com", "http://example.com", true}, // Exact match
19 {"http://example.com", "http://*.example.com", true}, // Star match
20 {"http://www.example.com", "http://*.example.com", true}, // Subdomain match
21 {"http://www.foo.example.com", "http://*.example.com", true}, // Sub-subdomain match
22 {"https://example.com", "http://example.com", false}, // Scheme differs
23 {"http://www.example.com", "http://example.com", false}, // Subdomain vs exact
24 {"http://fooexample.com", "http://example.com", false}, // Similar suffix
25 {"http://fooexample.com", "http://*.example.com", false}, // Similar suffix wildcard
26 }
27
28 for _, tc := range tests {
29 host, err := url.Parse(tc.host)
30 assert.NoError(t, err)
31
32 pattern, err := url.Parse(tc.pattern)
33 assert.NoError(t, err)
34
35 if tc.expectedMatch {
36 assert.True(t, isSameDomain(host, pattern), "%s does not match %s but should", tc.host, tc.pattern)
37 } else {
38 assert.False(t, isSameDomain(host, pattern), "%s matches %s but should not", tc.host, tc.pattern)
39 }
40 }
41
42 testUrl, err := url.Parse("http://example.com")
43 assert.NoError(t, err)
44 assert.False(t, isSameDomain(nil, nil), "nil should never match")
45 assert.False(t, isSameDomain(testUrl, nil), "nil should never match")
46 assert.False(t, isSameDomain(nil, testUrl), "nil should never match")
47}