aboutsummaryrefslogtreecommitdiff
path: root/echo/middleware/csrf.go
blob: f709964d258ea06998f670a03f62d48ec1d8ec09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
package middleware

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"code.crute.us/mcrute/golib/echo/session"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
)

// Keys under which the CSRF middleware stores per-request values on the
// echo.Context.
const (
	// csrfTokenContextKey holds the masked CSRF token string for the
	// current request; GetCSRFToken reads it back.
	csrfTokenContextKey = "__golib_echo_csrfToken"
	// csrfSkipContextKey holds a bool set by DisableCSRFProtection; when it
	// is true the CSRF middleware passes the request through unchecked.
	csrfSkipContextKey  = "__golib_echo_skipCsrf"
)

// Errors returned by the CSRF middleware. Each one is an echo HTTP error
// carrying status 400 (Bad Request).
var (
	// ErrInvalidOrigin is returned when the Origin header cannot be parsed
	// or matches neither the request host nor a trusted origin.
	ErrInvalidOrigin    = echo.NewHTTPError(http.StatusBadRequest, "invalid origin")
	// ErrInvalidReferer is returned when the Referer header cannot be
	// parsed, has no host, is not https, or matches no allowed origin.
	ErrInvalidReferer   = echo.NewHTTPError(http.StatusBadRequest, "invalid referer")
	// ErrMissingReferer is returned when a Referer header is required but
	// the request does not carry one.
	ErrMissingReferer   = echo.NewHTTPError(http.StatusBadRequest, "referer header is missing")
	// ErrCSRFTokenMissing is returned when no token is found in the request
	// header or form.
	ErrCSRFTokenMissing = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is missing")
	// ErrCSRFTokenInvalid is returned when the request token cannot be
	// unmasked or does not equal the secret stored in the session.
	ErrCSRFTokenInvalid = echo.NewHTTPError(http.StatusBadRequest, "CSRF token is invalid")
)

// CSRFAwareSession is a session that can additionally store the per-session
// CSRF secret used by the CSRF middleware.
type CSRFAwareSession interface {
	session.Session
	// GetCSRFSecret returns the stored secret. The middleware treats an
	// empty string as "no secret yet" and otherwise expects the value it
	// previously stored: random bytes in standard base64 encoding.
	GetCSRFSecret() string
	// SetCSRFSecret stores the secret; the middleware calls it with a
	// standard-base64 string when the session has no secret.
	SetCSRFSecret(string)
}

// generateRandomBytes returns a fresh slice of l bytes filled from
// crypto/rand. If the random source reports an error, that error is
// returned and the slice is nil.
func generateRandomBytes(l int) ([]byte, error) {
	buf := make([]byte, l)

	_, readErr := rand.Read(buf)
	if readErr != nil {
		return nil, readErr
	}

	return buf, nil
}

// xorBytes returns the byte-wise XOR of a and b. The result is as long as
// the shorter input; surplus bytes of the longer input are ignored. The
// returned slice is always newly allocated and never nil.
func xorBytes(a, b []byte) []byte {
	shorter, longer := a, b
	if len(longer) < len(shorter) {
		shorter, longer = longer, shorter
	}

	out := make([]byte, len(shorter))
	for idx, v := range shorter {
		out[idx] = v ^ longer[idx]
	}
	return out
}

// mask hides token behind a fresh one-time pad of size random bytes. The
// result is the standard-base64 encoding of the pad followed by pad XOR
// token. It returns "" if the random bytes cannot be generated.
func mask(token []byte, size int) string {
	pad, err := generateRandomBytes(size)
	if err != nil {
		return ""
	}

	hidden := xorBytes(pad, token)

	// Lay out pad || hidden in one buffer before encoding.
	combined := make([]byte, 0, len(pad)+len(hidden))
	combined = append(combined, pad...)
	combined = append(combined, hidden...)

	return base64.StdEncoding.EncodeToString(combined)
}

// decodeAndMask decodes a standard-base64 token and returns it masked with
// a fresh pad of size bytes (see mask). It returns "" when token is not
// valid base64.
func decodeAndMask(token string, size int) string {
	raw, decodeErr := base64.StdEncoding.DecodeString(token)
	if decodeErr == nil {
		return mask(raw, size)
	}
	return ""
}

// unmask reverses mask: it base64-decodes token, requires the decoded
// length to be exactly 2*size, and XORs the second half with the first
// half (the pad) to recover the original bytes. It returns an error when
// the token is not valid base64 or has the wrong length.
func unmask(token string, size int) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(token)
	switch {
	case err != nil:
		return nil, err
	case len(raw) != size*2:
		return nil, fmt.Errorf("unmask: token length incorrect")
	}

	pad, hidden := raw[:size], raw[size:]
	return xorBytes(hidden, pad), nil
}

// GetCSRFToken returns the masked CSRF token that the CSRF middleware
// stored on c for the current request.
//
// It returns "" when no token is stored. That happens when the middleware
// is not installed on the route, or when it passed the request through
// early (its Skipper returned true or DisableCSRFProtection ran), because
// in those cases the token key is never set.
func GetCSRFToken(c echo.Context) string {
	// Comma-ok assertion: a missing or non-string value must not panic.
	token, _ := c.Get(csrfTokenContextKey).(string)
	return token
}

// DisableCSRFProtection returns a middleware that flags the request so the
// CSRF middleware passes it through without any checks. The flag is set
// before the next handler runs, so it only has an effect when this
// middleware runs ahead of the CSRF middleware in the chain.
func DisableCSRFProtection() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		flagAndContinue := func(c echo.Context) error {
			c.Set(csrfSkipContextKey, true)
			return next(c)
		}
		return flagAndContinue
	}
}

// CSRFConfig configures the middleware built by CSRFProtectWithConfig.
type CSRFConfig[T CSRFAwareSession] struct {
	// Skipper, when it returns true for a request, makes the middleware
	// pass that request straight to the next handler.
	Skipper        middleware.Skipper
	// HeaderName is the request header that is checked first for the
	// CSRF token.
	HeaderName     string
	// FieldName is the form field that is checked for the CSRF token when
	// the header is empty.
	FieldName      string
	// TokenSize is the number of random bytes in the per-session secret
	// and in the mask prepended to each issued token.
	TokenSize      int
	// SessionStore provides the session that holds the CSRF secret. It
	// must not be nil; CSRFProtectWithConfig panics otherwise.
	SessionStore   session.Store[T]
	// TrustedOrigins lists additional origins, as URL strings, that are
	// accepted by the Origin and Referer checks. Each entry is parsed with
	// url.Parse when the middleware is built; a parse error panics.
	TrustedOrigins []string
}

// CSRFProtect builds the CSRF middleware around store using the stock
// settings: the default skipper, the "X-CSRF-Token" header, the
// "csrf-token" form field, 32-byte tokens and no extra trusted origins.
func CSRFProtect[T CSRFAwareSession](store session.Store[T]) echo.MiddlewareFunc {
	defaults := CSRFConfig[T]{
		SessionStore:   store,
		Skipper:        middleware.DefaultSkipper,
		TokenSize:      32,
		HeaderName:     "X-CSRF-Token",
		FieldName:      "csrf-token",
		TrustedOrigins: []string{},
	}
	return CSRFProtectWithConfig(defaults)
}

// CSRFProtectWithConfig builds a CSRF-protection middleware from cfg.
//
// Zero-valued fields get defaults: Skipper falls back to
// middleware.DefaultSkipper, HeaderName to "X-CSRF-Token", FieldName to
// "csrf-token" and a non-positive TokenSize to 32. It panics when
// cfg.SessionStore is nil or when an entry of cfg.TrustedOrigins cannot be
// parsed as a URL, so misconfiguration fails at server start.
//
// For every request that is not skipped the middleware:
//
//  1. makes sure the session holds a CSRF secret, generating one if needed;
//  2. stores a freshly masked copy of the secret on the context, where
//     GetCSRFToken can read it;
//  3. lets GET, HEAD, OPTIONS and TRACE through without further checks;
//  4. for all other methods validates the Origin header if present, or for
//     HTTPS requests without one the Referer header, and then requires a
//     request token (header first, then form field) that unmasks to the
//     session secret.
//
// Failures are reported with the package's Err* values; a failure to
// generate random bytes yields echo.ErrInternalServerError.
func CSRFProtectWithConfig[T CSRFAwareSession](cfg CSRFConfig[T]) echo.MiddlewareFunc {
	if cfg.Skipper == nil {
		cfg.Skipper = middleware.DefaultSkipper
	}
	if cfg.SessionStore == nil {
		panic("CSRFProtectWithConfig: SessionStore must not be nil")
	}
	// A zero TokenSize would produce an empty secret that no request token
	// can ever match, so fall back to the same defaults CSRFProtect uses.
	if cfg.HeaderName == "" {
		cfg.HeaderName = "X-CSRF-Token"
	}
	if cfg.FieldName == "" {
		cfg.FieldName = "csrf-token"
	}
	if cfg.TokenSize <= 0 {
		cfg.TokenSize = 32
	}

	trustedOrigins := make([]*url.URL, len(cfg.TrustedOrigins))
	for i, u := range cfg.TrustedOrigins {
		pu, err := url.Parse(u)
		if err != nil {
			// These are critical errors that should fail at server boot
			panic(fmt.Errorf("Error parsing TrustedOrigins URL: %w", err))
		}
		trustedOrigins[i] = pu
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if cfg.Skipper(c) {
				return next(c)
			}

			if skip, ok := c.Get(csrfSkipContextKey).(bool); ok && skip {
				return next(c)
			}

			// If there's no CSRF secret in the session then set one for
			// future requests, even if this request is safe.
			sess := cfg.SessionStore.Get(c)
			if sess.GetCSRFSecret() == "" {
				tok, err := generateRandomBytes(cfg.TokenSize)
				if err != nil {
					// %v, not %w: the logger formats like Sprintf, where
					// %w is not a valid verb.
					c.Logger().Errorf("CSRF: error generating token: %v", err)
					return echo.ErrInternalServerError
				}
				sess.SetCSRFSecret(base64.StdEncoding.EncodeToString(tok))
			}

			// Stash this in the request context so that other things can
			// get to it without needing a direct session dependency.
			c.Set(csrfTokenContextKey, decodeAndMask(sess.GetCSRFSecret(), cfg.TokenSize))

			req := c.Request()

			// These methods are considered safe and do not require CSRF
			// protection.
			switch req.Method {
			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
				return next(c)
			}

			// The URL of an incoming server request normally carries no
			// scheme, so fall back to the connection's TLS state (the same
			// rule validateOrigin applies). Without this fallback the
			// Referer check below would never run.
			scheme := req.URL.Scheme
			if scheme == "" && req.TLS != nil {
				scheme = "https"
			}

			if origin := req.Header.Get("Origin"); origin != "" {
				if err := validateOrigin(c, origin, trustedOrigins); err != nil {
					return err
				}
			} else if scheme == "https" {
				// If the Origin header wasn't provided, reject HTTPS
				// requests whose Referer doesn't match an allowed value.
				//
				// Suppose a user visits http://example.com/. An active
				// network attacker (man-in-the-middle) can inject a form
				// that posts to https://example.com/ and submit it via
				// JavaScript. For the HTTPS site to be safe it has to treat
				// its plain-HTTP counterpart as completely untrusted. Under
				// HTTPS, Barth et al. found that the Referer header is
				// missing for same-domain requests in only about 0.2% of
				// cases or less, so strict Referer checking is practical.
				//
				// The request's own host is always acceptable, in addition
				// to the configured trusted origins.
				allowed := make([]*url.URL, 0, len(trustedOrigins)+1)
				allowed = append(allowed, trustedOrigins...)
				allowed = append(allowed, &url.URL{Scheme: "https", Host: req.Host})
				if err := validateReferer(c, allowed); err != nil {
					return err
				}
			}

			requestToken := getRequestToken(req, cfg.FieldName, cfg.HeaderName)
			if requestToken == "" {
				return ErrCSRFTokenMissing
			}

			unmasked, err := unmask(requestToken, cfg.TokenSize)
			if err != nil {
				c.Logger().Debugf("CSRF: error unmasking request token: %v", err)
				return ErrCSRFTokenInvalid
			}

			sessionToken, err := base64.StdEncoding.DecodeString(sess.GetCSRFSecret())
			if err != nil {
				c.Logger().Debugf("CSRF: error decoding session token: %v", err)
				return ErrCSRFTokenInvalid
			}

			// Constant-time comparison avoids leaking how many leading
			// bytes of the token matched.
			if subtle.ConstantTimeCompare(unmasked, sessionToken) != 1 {
				c.Logger().Debug("CSRF: session token and request token are not equal")
				return ErrCSRFTokenInvalid
			}

			c.Response().Header().Add("Vary", "Cookie")

			return next(c)
		}
	}
}

// getRequestToken extracts the CSRF token submitted with r. Sources are
// tried in order and the first non-empty value wins: the headerName
// request header, the fieldName POST form value, and finally the first
// fieldName value of an already-parsed multipart form. It returns "" when
// none of them yields a token. The form body is only parsed when the
// header is empty.
func getRequestToken(r *http.Request, fieldName, headerName string) string {
	if fromHeader := r.Header.Get(headerName); fromHeader != "" {
		return fromHeader
	}

	if fromForm := r.PostFormValue(fieldName); fromForm != "" {
		return fromForm
	}

	if r.MultipartForm == nil {
		return ""
	}
	if candidates := r.MultipartForm.Value[fieldName]; len(candidates) > 0 {
		return candidates[0]
	}
	return ""
}

// validateReferer checks the Referer header of the request in c against
// the trusted origins.
//
// It returns ErrMissingReferer when the header is absent, and
// ErrInvalidReferer when the header cannot be parsed, has no host, does not
// use the https scheme, or matches none of the entries in trusted (see
// isSameDomain). Only the entries in trusted are consulted; the request's
// own host is not implicitly accepted, so callers must include it in
// trusted if it should be. It returns nil on the first match.
func validateReferer(c echo.Context, trusted []*url.URL) error {
	referer := c.Request().Referer()
	if referer == "" {
		c.Logger().Debugf("CSRF: no referer in request")
		return ErrMissingReferer
	}

	u, err := url.Parse(referer)
	if err != nil {
		// %v, not %w: the logger formats like Sprintf, where %w is not a
		// valid verb.
		c.Logger().Debugf("CSRF: error parsing referer: %v", err)
		return ErrInvalidReferer
	}

	if u.Host == "" {
		c.Logger().Debugf("CSRF: no host in referer")
		return ErrInvalidReferer
	}

	if u.Scheme != "https" {
		c.Logger().Debugf("CSRF: referer is not https")
		return ErrInvalidReferer
	}

	for _, o := range trusted {
		if isSameDomain(u, o) {
			return nil
		}
	}

	return ErrInvalidReferer
}

// validateOrigin checks the value of the Origin header against the
// request's own origin and the trusted origins.
//
// The request's own origin is built from its Host and from its URL scheme;
// when the URL carries no scheme, "https" is assumed if the request has TLS
// state and "http" otherwise. It returns nil when origin matches that
// origin or any entry of trusted (see isSameDomain), and ErrInvalidOrigin
// when origin cannot be parsed or matches nothing.
func validateOrigin(c echo.Context, origin string, trusted []*url.URL) error {
	parsedOrigin, err := url.Parse(origin)
	if err != nil {
		// %v, not %w: the logger formats like Sprintf, where %w is not a
		// valid verb.
		c.Logger().Debugf("CSRF: Error parsing origin: %v", err)
		return ErrInvalidOrigin
	}

	thisHost := &url.URL{
		Scheme: c.Request().URL.Scheme,
		Host:   c.Request().Host,
	}

	if thisHost.Scheme == "" {
		if c.Request().TLS != nil {
			thisHost.Scheme = "https"
		} else {
			thisHost.Scheme = "http"
		}
	}

	if isSameDomain(parsedOrigin, thisHost) {
		return nil
	}

	for _, to := range trusted {
		if isSameDomain(parsedOrigin, to) {
			return nil
		}
	}

	return ErrInvalidOrigin
}

// isSameDomain reports whether host matches pattern.
//
// Both URLs must be non-nil and have identical schemes. Hosts are compared
// case-insensitively and match when they are equal, or, if the pattern
// host starts with "*", when the host ends with the rest of the pattern
// (e.g. "*.example.com" matches "foo.example.com") or equals the pattern
// with its first two characters removed (so "*.example.com" also matches
// the bare "example.com").
//
// The wildcard is a plain suffix match: a pattern written without the dot,
// such as "*example.com", also matches "evilexample.com". Write wildcard
// patterns as "*.example.com".
//
// Neither argument is modified. Patterns are typically shared between
// concurrently served requests, so writing to them would be a data race.
func isSameDomain(host, pattern *url.URL) bool {
	if host == nil || pattern == nil {
		return false
	}

	if host.Scheme != pattern.Scheme {
		return false
	}

	// Work on lower-cased copies instead of writing back into the URLs.
	h := strings.ToLower(host.Host)
	p := strings.ToLower(pattern.Host)

	if p == h {
		return true
	}

	if !strings.HasPrefix(p, "*") {
		return false
	}

	suffix := p[1:] // pattern without the leading "*"
	if strings.HasSuffix(h, suffix) {
		return true
	}

	// suffix is non-empty here (an empty suffix matches above); dropping
	// its first character lets "*.example.com" match "example.com".
	return len(suffix) > 0 && h == suffix[1:]
}