diff options
author | Mike Crute <mike@crute.us> | 2022-11-15 20:58:53 -0800 |
---|---|---|
committer | Mike Crute <mike@crute.us> | 2022-11-15 20:58:53 -0800 |
commit | 6e4a03e9cac2e774208a9189a4af646e69a658a8 (patch) | |
tree | 438f28960dd3baa403a3de1a12de5c3b89e646a5 | |
parent | 600378d36f109a5eccffb804134fb78f465768e9 (diff) | |
download | golib-echo/v0.8.0.tar.bz2 golib-echo/v0.8.0.tar.xz golib-echo/v0.8.0.zip |
echo: add session store and middlewareecho/v0.8.0
-rw-r--r-- | echo/session/middleware.go | 60 | ||||
-rw-r--r-- | echo/session/store.go | 248 | ||||
-rw-r--r-- | echo/session/store_test.go | 67 |
3 files changed, 375 insertions, 0 deletions
diff --git a/echo/session/middleware.go b/echo/session/middleware.go new file mode 100644 index 0000000..7c1d1d2 --- /dev/null +++ b/echo/session/middleware.go | |||
@@ -0,0 +1,60 @@ | |||
1 | package session | ||
2 | |||
3 | import ( | ||
4 | "github.com/labstack/echo/v4" | ||
5 | "github.com/labstack/echo/v4/middleware" | ||
6 | ) | ||
7 | |||
8 | const ( | ||
9 | sessionStoreContextKey = "__golib_echo_sessionStore" | ||
10 | ) | ||
11 | |||
12 | type Config[T Session] struct { | ||
13 | Skipper middleware.Skipper | ||
14 | Store Store[T] | ||
15 | } | ||
16 | |||
17 | func GetSessionStore[T Session](c echo.Context) Store[T] { | ||
18 | store, ok := c.Get(sessionStoreContextKey).(Store[T]) | ||
19 | if !ok { | ||
20 | // This should never happen if the middleware is run | ||
21 | panic("GetSessionStore: error fetching session store") | ||
22 | } | ||
23 | return store | ||
24 | } | ||
25 | |||
26 | func Middleware[T Session](store Store[T]) echo.MiddlewareFunc { | ||
27 | return MiddlewareWithConfig(Config[T]{ | ||
28 | Skipper: middleware.DefaultSkipper, | ||
29 | Store: store, | ||
30 | }) | ||
31 | } | ||
32 | |||
33 | func MiddlewareWithConfig[T Session](config Config[T]) echo.MiddlewareFunc { | ||
34 | if config.Skipper == nil { | ||
35 | config.Skipper = middleware.DefaultSkipper | ||
36 | } | ||
37 | if config.Store == nil { | ||
38 | panic("session/MiddlewareWithConfig: Store is required for middleware") | ||
39 | } | ||
40 | |||
41 | return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
42 | return func(c echo.Context) error { | ||
43 | if config.Skipper(c) { | ||
44 | return next(c) | ||
45 | } | ||
46 | |||
47 | c.Set(sessionStoreContextKey, config.Store) | ||
48 | |||
49 | // This is executed by echo before headers are written to the client | ||
50 | // and can manipulate the headers first. | ||
51 | c.Response().Before(func() { | ||
52 | if err := config.Store.Commit(c); err != nil { | ||
53 | c.Logger().Error("Error committing session: %s", err) | ||
54 | } | ||
55 | }) | ||
56 | |||
57 | return next(c) | ||
58 | } | ||
59 | } | ||
60 | } | ||
diff --git a/echo/session/store.go b/echo/session/store.go new file mode 100644 index 0000000..f82b3b6 --- /dev/null +++ b/echo/session/store.go | |||
@@ -0,0 +1,248 @@ | |||
1 | package session | ||
2 | |||
3 | import ( | ||
4 | "crypto/rsa" | ||
5 | "encoding/json" | ||
6 | "errors" | ||
7 | "fmt" | ||
8 | "net/http" | ||
9 | "time" | ||
10 | |||
11 | "github.com/labstack/echo/v4" | ||
12 | "gopkg.in/square/go-jose.v2" | ||
13 | ) | ||
14 | |||
15 | const ( | ||
16 | sessionCookieName = "sinfo" | ||
17 | sessionContextKey = "__golib_echo_session" | ||
18 | sessionDirtyContextKey = "__golib_echo_sessionDirty" | ||
19 | ) | ||
20 | |||
21 | type Session interface { | ||
22 | Expires() time.Time | ||
23 | } | ||
24 | |||
25 | type Store[T Session] interface { | ||
26 | // New creates a new session and attaches it to the request context. | ||
27 | New(echo.Context) T | ||
28 | |||
29 | // Get will get a session. If the session exists and has not been | ||
30 | // decoded it will first decoded it, validate it, and cache it on the | ||
31 | // request context. Subsequent requests to Get will not parse the | ||
32 | // session again. If the session does not exist or can not be decoded | ||
33 | // a new session will be returned. | ||
34 | Get(echo.Context) T | ||
35 | |||
36 | // GetStrict is like Get but will return errors if anything goes wrong. | ||
37 | // This is not normally what you want but can be useful for debugging | ||
38 | // session store issues. Unlike Get, GetStrict will never return a new | ||
39 | // session when there is an error. | ||
40 | GetStrict(echo.Context) (T, error) | ||
41 | |||
42 | // Update will update the session in the request context and mark | ||
43 | // the session as dirty. Only dirty sessions can be committed to the | ||
44 | // client. | ||
45 | Update(echo.Context, T) error | ||
46 | |||
47 | // Delete will mark the session as deleted which will result in removal | ||
48 | // of the session from storage during commit. | ||
49 | Delete(echo.Context) error | ||
50 | |||
51 | // Commit causes the store to persist changes to the session to the | ||
52 | // storage mechanism and add a session identifier cookie to the client | ||
53 | // response. Only dirty or new sessions will result in changes to the | ||
54 | // client and storage. | ||
55 | Commit(echo.Context) error | ||
56 | } | ||
57 | |||
58 | type CookieStore[T Session] struct { | ||
59 | key *rsa.PrivateKey | ||
60 | enc jose.Encrypter | ||
61 | factory func(echo.Context) T | ||
62 | } | ||
63 | |||
64 | // NewCookieStore creates a new cookie-based session store. It requires | ||
65 | // an RSA private key that is used to encrypt to contents fo the cookie | ||
66 | // as well as a factory function that will create a new session if none | ||
67 | // can be found in the user request. | ||
68 | func NewCookieStore[T Session](key *rsa.PrivateKey, fact func(echo.Context) T) (Store[T], error) { | ||
69 | enc, err := jose.NewEncrypter( | ||
70 | jose.A128GCM, | ||
71 | jose.Recipient{ | ||
72 | Algorithm: jose.RSA_OAEP_256, | ||
73 | Key: key.Public(), | ||
74 | }, | ||
75 | nil, | ||
76 | ) | ||
77 | if err != nil { | ||
78 | return nil, err | ||
79 | } | ||
80 | |||
81 | return &CookieStore[T]{ | ||
82 | key: key, | ||
83 | enc: enc, | ||
84 | factory: fact, | ||
85 | }, nil | ||
86 | } | ||
87 | |||
88 | func (s *CookieStore[T]) New(c echo.Context) T { | ||
89 | session := s.factory(c) | ||
90 | c.Set(sessionContextKey, session) | ||
91 | c.Set(sessionDirtyContextKey, true) | ||
92 | return session | ||
93 | } | ||
94 | |||
95 | func (s *CookieStore[T]) Get(c echo.Context) T { | ||
96 | ses, err := s.GetStrict(c) | ||
97 | if err != nil { | ||
98 | return s.New(c) | ||
99 | } | ||
100 | return ses | ||
101 | } | ||
102 | |||
103 | func (s *CookieStore[T]) GetStrict(c echo.Context) (T, error) { | ||
104 | var empty T | ||
105 | |||
106 | // There is a decoded session already in the context, cast and use it | ||
107 | if cs := c.Get(sessionContextKey); cs != nil { | ||
108 | if res, ok := cs.(T); ok { | ||
109 | return res, nil | ||
110 | } | ||
111 | return empty, fmt.Errorf("CookieStore.Get: Error casting context session") | ||
112 | } | ||
113 | |||
114 | // No session in the context and context is marked as dirty. Must | ||
115 | // have been deleted, in which case create a new session and keep the | ||
116 | // context dirty so that the new session is written instead of the old | ||
117 | // one. | ||
118 | if dirty, ok := c.Get(sessionDirtyContextKey).(bool); ok && dirty { | ||
119 | return s.New(c), nil | ||
120 | } | ||
121 | |||
122 | // Otherwise the session should be stored in a cookie, if so then get | ||
123 | // it and decode it. Any failure here should result in a blank session | ||
124 | // being created. | ||
125 | cookie, err := c.Cookie(sessionCookieName) | ||
126 | if err != nil { | ||
127 | if errors.Is(err, http.ErrNoCookie) { | ||
128 | return s.New(c), nil | ||
129 | } | ||
130 | return empty, fmt.Errorf("CookieStore.Get: error getting session cookie: %w", err) | ||
131 | } | ||
132 | |||
133 | session := s.factory(c) | ||
134 | |||
135 | // If the cookie fails to decode then return a new session | ||
136 | if err := s.decodeCookie(cookie, session); err != nil { | ||
137 | return empty, fmt.Errorf("CookieStore.Get: error decoding session cookie: %w", err) | ||
138 | } | ||
139 | |||
140 | // If the session has expired, return a new session | ||
141 | if time.Now().After(session.Expires()) { | ||
142 | return empty, fmt.Errorf("CookieStore.Get: session expired at: %s", session.Expires()) | ||
143 | } | ||
144 | |||
145 | c.Set(sessionContextKey, session) | ||
146 | c.Set(sessionDirtyContextKey, false) | ||
147 | |||
148 | return session, nil | ||
149 | } | ||
150 | |||
151 | func (s *CookieStore[T]) decodeCookie(c *http.Cookie, out T) error { | ||
152 | o, err := jose.ParseEncrypted(c.Value) | ||
153 | if err != nil { | ||
154 | return fmt.Errorf("CookieStore.decodeCookie: error parsing encrypted JWT: %w", err) | ||
155 | } | ||
156 | |||
157 | d, err := o.Decrypt(s.key) | ||
158 | if err != nil { | ||
159 | return fmt.Errorf("CookieStore.decodeCookie: error decrypting JWT: %w", err) | ||
160 | } | ||
161 | |||
162 | if err := json.Unmarshal(d, out); err != nil { | ||
163 | return fmt.Errorf("CookieStore.decodeCookie: error unmarshalling JWT: %w", err) | ||
164 | } | ||
165 | |||
166 | return nil | ||
167 | } | ||
168 | |||
169 | func (s *CookieStore[T]) Update(c echo.Context, session T) error { | ||
170 | c.Set(sessionContextKey, session) | ||
171 | c.Set(sessionDirtyContextKey, true) | ||
172 | return nil | ||
173 | } | ||
174 | |||
175 | func (s *CookieStore[T]) Delete(c echo.Context) error { | ||
176 | c.Set(sessionContextKey, nil) | ||
177 | c.Set(sessionDirtyContextKey, true) | ||
178 | return nil | ||
179 | } | ||
180 | |||
181 | func (s *CookieStore[T]) Commit(c echo.Context) error { | ||
182 | // Don't re-write the session to the client if it hasn't changed. Saves | ||
183 | // some time in request processing to avoid crypto and serialization | ||
184 | // overhead. | ||
185 | dirty, ok := c.Get(sessionDirtyContextKey).(bool) | ||
186 | if ok && !dirty { | ||
187 | return nil | ||
188 | } | ||
189 | |||
190 | // If the context is dirty and the session is nil then the session was | ||
191 | // deleted so delete the cookie on the client as well. | ||
192 | cs := c.Get(sessionContextKey) | ||
193 | if cs == nil && dirty { | ||
194 | c.SetCookie(&http.Cookie{ | ||
195 | Name: sessionCookieName, | ||
196 | Secure: true, | ||
197 | HttpOnly: true, | ||
198 | Path: "/", | ||
199 | SameSite: http.SameSiteLaxMode, | ||
200 | MaxAge: -1, | ||
201 | }) | ||
202 | return nil | ||
203 | } else if cs == nil { | ||
204 | return nil | ||
205 | } | ||
206 | |||
207 | session, ok := cs.(T) | ||
208 | if !ok { | ||
209 | return fmt.Errorf("CookieStore.Commit: unable to cast session") | ||
210 | } | ||
211 | |||
212 | cookie := &http.Cookie{ | ||
213 | Name: sessionCookieName, | ||
214 | Secure: true, | ||
215 | HttpOnly: true, | ||
216 | Path: "/", | ||
217 | SameSite: http.SameSiteLaxMode, | ||
218 | Expires: session.Expires(), | ||
219 | } | ||
220 | |||
221 | if err := s.encodeCookie(c, cookie); err != nil { | ||
222 | return err | ||
223 | } | ||
224 | |||
225 | c.SetCookie(cookie) | ||
226 | return nil | ||
227 | } | ||
228 | |||
229 | func (s *CookieStore[T]) encodeCookie(c echo.Context, cookie *http.Cookie) error { | ||
230 | md, err := json.Marshal(c.Get(sessionContextKey)) | ||
231 | if err != nil { | ||
232 | return fmt.Errorf("CookieStore.encodeCookie: failed to encode json: %w", err) | ||
233 | } | ||
234 | |||
235 | o, err := s.enc.Encrypt(md) | ||
236 | if err != nil { | ||
237 | return fmt.Errorf("CookieStore.encodeCookie: error encrypting JWT: %w", err) | ||
238 | } | ||
239 | |||
240 | ser, err := o.CompactSerialize() | ||
241 | if err != nil { | ||
242 | return fmt.Errorf("CookieStore.encodeCookie: error serialzing JWT: %w", err) | ||
243 | } | ||
244 | |||
245 | cookie.Value = ser | ||
246 | |||
247 | return nil | ||
248 | } | ||
diff --git a/echo/session/store_test.go b/echo/session/store_test.go new file mode 100644 index 0000000..561caca --- /dev/null +++ b/echo/session/store_test.go | |||
@@ -0,0 +1,67 @@ | |||
1 | package session | ||
2 | |||
3 | import ( | ||
4 | "crypto/rand" | ||
5 | "crypto/rsa" | ||
6 | "net/http" | ||
7 | "net/http/httptest" | ||
8 | "testing" | ||
9 | "time" | ||
10 | |||
11 | "github.com/labstack/echo/v4" | ||
12 | "github.com/stretchr/testify/assert" | ||
13 | ) | ||
14 | |||
15 | type thing struct { | ||
16 | Name string | ||
17 | Value int | ||
18 | } | ||
19 | |||
20 | type testSession struct { | ||
21 | Name string | ||
22 | Thing *thing | ||
23 | Exp time.Time | ||
24 | } | ||
25 | |||
26 | func (s *testSession) Expires() time.Time { | ||
27 | return s.Exp | ||
28 | } | ||
29 | |||
30 | // TODO: This is ugly, make it better and test more edge cases | ||
31 | func TestStore(t *testing.T) { | ||
32 | e := echo.New() | ||
33 | |||
34 | pk, err := rsa.GenerateKey(rand.Reader, 2048) | ||
35 | assert.NoError(t, err) | ||
36 | |||
37 | st, err := NewCookieStore[*testSession](pk, func(echo.Context) *testSession { | ||
38 | return &testSession{} | ||
39 | }) | ||
40 | assert.NoError(t, err) | ||
41 | |||
42 | req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
43 | res := httptest.NewRecorder() | ||
44 | |||
45 | c := e.NewContext(req, res) | ||
46 | |||
47 | s := st.New(c) | ||
48 | s.Name = "foo" | ||
49 | s.Thing = &thing{"bar", 10} | ||
50 | s.Exp = time.Now().Add(time.Hour) | ||
51 | |||
52 | assert.NoError(t, st.Commit(c)) | ||
53 | |||
54 | c2 := e.NewContext(&http.Request{ | ||
55 | Header: http.Header{ | ||
56 | "Cookie": []string{res.Header().Get("Set-Cookie")}, | ||
57 | }, | ||
58 | }, httptest.NewRecorder()) | ||
59 | |||
60 | nt, err := st.GetStrict(c2) | ||
61 | assert.NoError(t, err) | ||
62 | |||
63 | assert.Equal(t, nt.Name, "foo") | ||
64 | assert.Equal(t, nt.Thing.Name, "bar") | ||
65 | assert.Equal(t, nt.Thing.Value, 10) | ||
66 | assert.Equal(t, nt.Exp.Unix(), s.Exp.Unix()) | ||
67 | } | ||