aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2022-11-15 20:58:53 -0800
committerMike Crute <mike@crute.us>2022-11-15 20:58:53 -0800
commit6e4a03e9cac2e774208a9189a4af646e69a658a8 (patch)
tree438f28960dd3baa403a3de1a12de5c3b89e646a5
parent600378d36f109a5eccffb804134fb78f465768e9 (diff)
downloadgolib-6e4a03e9cac2e774208a9189a4af646e69a658a8.tar.bz2
golib-6e4a03e9cac2e774208a9189a4af646e69a658a8.tar.xz
golib-6e4a03e9cac2e774208a9189a4af646e69a658a8.zip
echo: add session store and middlewareecho/v0.8.0
-rw-r--r--echo/session/middleware.go60
-rw-r--r--echo/session/store.go248
-rw-r--r--echo/session/store_test.go67
3 files changed, 375 insertions, 0 deletions
diff --git a/echo/session/middleware.go b/echo/session/middleware.go
new file mode 100644
index 0000000..7c1d1d2
--- /dev/null
+++ b/echo/session/middleware.go
@@ -0,0 +1,60 @@
1package session
2
3import (
4 "github.com/labstack/echo/v4"
5 "github.com/labstack/echo/v4/middleware"
6)
7
8const (
9 sessionStoreContextKey = "__golib_echo_sessionStore"
10)
11
12type Config[T Session] struct {
13 Skipper middleware.Skipper
14 Store Store[T]
15}
16
17func GetSessionStore[T Session](c echo.Context) Store[T] {
18 store, ok := c.Get(sessionStoreContextKey).(Store[T])
19 if !ok {
20 // This should never happen if the middleware is run
21 panic("GetSessionStore: error fetching session store")
22 }
23 return store
24}
25
26func Middleware[T Session](store Store[T]) echo.MiddlewareFunc {
27 return MiddlewareWithConfig(Config[T]{
28 Skipper: middleware.DefaultSkipper,
29 Store: store,
30 })
31}
32
33func MiddlewareWithConfig[T Session](config Config[T]) echo.MiddlewareFunc {
34 if config.Skipper == nil {
35 config.Skipper = middleware.DefaultSkipper
36 }
37 if config.Store == nil {
38 panic("session/MiddlewareWithConfig: Store is required for middleware")
39 }
40
41 return func(next echo.HandlerFunc) echo.HandlerFunc {
42 return func(c echo.Context) error {
43 if config.Skipper(c) {
44 return next(c)
45 }
46
47 c.Set(sessionStoreContextKey, config.Store)
48
49 // This is executed by echo before headers are written to the client
50 // and can manipulate the headers first.
51 c.Response().Before(func() {
52 if err := config.Store.Commit(c); err != nil {
53 c.Logger().Error("Error committing session: %s", err)
54 }
55 })
56
57 return next(c)
58 }
59 }
60}
diff --git a/echo/session/store.go b/echo/session/store.go
new file mode 100644
index 0000000..f82b3b6
--- /dev/null
+++ b/echo/session/store.go
@@ -0,0 +1,248 @@
1package session
2
3import (
4 "crypto/rsa"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "net/http"
9 "time"
10
11 "github.com/labstack/echo/v4"
12 "gopkg.in/square/go-jose.v2"
13)
14
15const (
16 sessionCookieName = "sinfo"
17 sessionContextKey = "__golib_echo_session"
18 sessionDirtyContextKey = "__golib_echo_sessionDirty"
19)
20
21type Session interface {
22 Expires() time.Time
23}
24
25type Store[T Session] interface {
26 // New creates a new session and attaches it to the request context.
27 New(echo.Context) T
28
29 // Get will get a session. If the session exists and has not been
30 // decoded it will first decoded it, validate it, and cache it on the
31 // request context. Subsequent requests to Get will not parse the
32 // session again. If the session does not exist or can not be decoded
33 // a new session will be returned.
34 Get(echo.Context) T
35
36 // GetStrict is like Get but will return errors if anything goes wrong.
37 // This is not normally what you want but can be useful for debugging
38 // session store issues. Unlike Get, GetStrict will never return a new
39 // session when there is an error.
40 GetStrict(echo.Context) (T, error)
41
42 // Update will update the session in the request context and mark
43 // the session as dirty. Only dirty sessions can be committed to the
44 // client.
45 Update(echo.Context, T) error
46
47 // Delete will mark the session as deleted which will result in removal
48 // of the session from storage during commit.
49 Delete(echo.Context) error
50
51 // Commit causes the store to persist changes to the session to the
52 // storage mechanism and add a session identifier cookie to the client
53 // response. Only dirty or new sessions will result in changes to the
54 // client and storage.
55 Commit(echo.Context) error
56}
57
58type CookieStore[T Session] struct {
59 key *rsa.PrivateKey
60 enc jose.Encrypter
61 factory func(echo.Context) T
62}
63
64// NewCookieStore creates a new cookie-based session store. It requires
65// an RSA private key that is used to encrypt to contents fo the cookie
66// as well as a factory function that will create a new session if none
67// can be found in the user request.
68func NewCookieStore[T Session](key *rsa.PrivateKey, fact func(echo.Context) T) (Store[T], error) {
69 enc, err := jose.NewEncrypter(
70 jose.A128GCM,
71 jose.Recipient{
72 Algorithm: jose.RSA_OAEP_256,
73 Key: key.Public(),
74 },
75 nil,
76 )
77 if err != nil {
78 return nil, err
79 }
80
81 return &CookieStore[T]{
82 key: key,
83 enc: enc,
84 factory: fact,
85 }, nil
86}
87
88func (s *CookieStore[T]) New(c echo.Context) T {
89 session := s.factory(c)
90 c.Set(sessionContextKey, session)
91 c.Set(sessionDirtyContextKey, true)
92 return session
93}
94
95func (s *CookieStore[T]) Get(c echo.Context) T {
96 ses, err := s.GetStrict(c)
97 if err != nil {
98 return s.New(c)
99 }
100 return ses
101}
102
103func (s *CookieStore[T]) GetStrict(c echo.Context) (T, error) {
104 var empty T
105
106 // There is a decoded session already in the context, cast and use it
107 if cs := c.Get(sessionContextKey); cs != nil {
108 if res, ok := cs.(T); ok {
109 return res, nil
110 }
111 return empty, fmt.Errorf("CookieStore.Get: Error casting context session")
112 }
113
114 // No session in the context and context is marked as dirty. Must
115 // have been deleted, in which case create a new session and keep the
116 // context dirty so that the new session is written instead of the old
117 // one.
118 if dirty, ok := c.Get(sessionDirtyContextKey).(bool); ok && dirty {
119 return s.New(c), nil
120 }
121
122 // Otherwise the session should be stored in a cookie, if so then get
123 // it and decode it. Any failure here should result in a blank session
124 // being created.
125 cookie, err := c.Cookie(sessionCookieName)
126 if err != nil {
127 if errors.Is(err, http.ErrNoCookie) {
128 return s.New(c), nil
129 }
130 return empty, fmt.Errorf("CookieStore.Get: error getting session cookie: %w", err)
131 }
132
133 session := s.factory(c)
134
135 // If the cookie fails to decode then return a new session
136 if err := s.decodeCookie(cookie, session); err != nil {
137 return empty, fmt.Errorf("CookieStore.Get: error decoding session cookie: %w", err)
138 }
139
140 // If the session has expired, return a new session
141 if time.Now().After(session.Expires()) {
142 return empty, fmt.Errorf("CookieStore.Get: session expired at: %s", session.Expires())
143 }
144
145 c.Set(sessionContextKey, session)
146 c.Set(sessionDirtyContextKey, false)
147
148 return session, nil
149}
150
151func (s *CookieStore[T]) decodeCookie(c *http.Cookie, out T) error {
152 o, err := jose.ParseEncrypted(c.Value)
153 if err != nil {
154 return fmt.Errorf("CookieStore.decodeCookie: error parsing encrypted JWT: %w", err)
155 }
156
157 d, err := o.Decrypt(s.key)
158 if err != nil {
159 return fmt.Errorf("CookieStore.decodeCookie: error decrypting JWT: %w", err)
160 }
161
162 if err := json.Unmarshal(d, out); err != nil {
163 return fmt.Errorf("CookieStore.decodeCookie: error unmarshalling JWT: %w", err)
164 }
165
166 return nil
167}
168
169func (s *CookieStore[T]) Update(c echo.Context, session T) error {
170 c.Set(sessionContextKey, session)
171 c.Set(sessionDirtyContextKey, true)
172 return nil
173}
174
175func (s *CookieStore[T]) Delete(c echo.Context) error {
176 c.Set(sessionContextKey, nil)
177 c.Set(sessionDirtyContextKey, true)
178 return nil
179}
180
181func (s *CookieStore[T]) Commit(c echo.Context) error {
182 // Don't re-write the session to the client if it hasn't changed. Saves
183 // some time in request processing to avoid crypto and serialization
184 // overhead.
185 dirty, ok := c.Get(sessionDirtyContextKey).(bool)
186 if ok && !dirty {
187 return nil
188 }
189
190 // If the context is dirty and the session is nil then the session was
191 // deleted so delete the cookie on the client as well.
192 cs := c.Get(sessionContextKey)
193 if cs == nil && dirty {
194 c.SetCookie(&http.Cookie{
195 Name: sessionCookieName,
196 Secure: true,
197 HttpOnly: true,
198 Path: "/",
199 SameSite: http.SameSiteLaxMode,
200 MaxAge: -1,
201 })
202 return nil
203 } else if cs == nil {
204 return nil
205 }
206
207 session, ok := cs.(T)
208 if !ok {
209 return fmt.Errorf("CookieStore.Commit: unable to cast session")
210 }
211
212 cookie := &http.Cookie{
213 Name: sessionCookieName,
214 Secure: true,
215 HttpOnly: true,
216 Path: "/",
217 SameSite: http.SameSiteLaxMode,
218 Expires: session.Expires(),
219 }
220
221 if err := s.encodeCookie(c, cookie); err != nil {
222 return err
223 }
224
225 c.SetCookie(cookie)
226 return nil
227}
228
229func (s *CookieStore[T]) encodeCookie(c echo.Context, cookie *http.Cookie) error {
230 md, err := json.Marshal(c.Get(sessionContextKey))
231 if err != nil {
232 return fmt.Errorf("CookieStore.encodeCookie: failed to encode json: %w", err)
233 }
234
235 o, err := s.enc.Encrypt(md)
236 if err != nil {
237 return fmt.Errorf("CookieStore.encodeCookie: error encrypting JWT: %w", err)
238 }
239
240 ser, err := o.CompactSerialize()
241 if err != nil {
242 return fmt.Errorf("CookieStore.encodeCookie: error serialzing JWT: %w", err)
243 }
244
245 cookie.Value = ser
246
247 return nil
248}
diff --git a/echo/session/store_test.go b/echo/session/store_test.go
new file mode 100644
index 0000000..561caca
--- /dev/null
+++ b/echo/session/store_test.go
@@ -0,0 +1,67 @@
1package session
2
3import (
4 "crypto/rand"
5 "crypto/rsa"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10
11 "github.com/labstack/echo/v4"
12 "github.com/stretchr/testify/assert"
13)
14
15type thing struct {
16 Name string
17 Value int
18}
19
20type testSession struct {
21 Name string
22 Thing *thing
23 Exp time.Time
24}
25
26func (s *testSession) Expires() time.Time {
27 return s.Exp
28}
29
30// TODO: This is ugly, make it better and test more edge cases
31func TestStore(t *testing.T) {
32 e := echo.New()
33
34 pk, err := rsa.GenerateKey(rand.Reader, 2048)
35 assert.NoError(t, err)
36
37 st, err := NewCookieStore[*testSession](pk, func(echo.Context) *testSession {
38 return &testSession{}
39 })
40 assert.NoError(t, err)
41
42 req := httptest.NewRequest(http.MethodGet, "/", nil)
43 res := httptest.NewRecorder()
44
45 c := e.NewContext(req, res)
46
47 s := st.New(c)
48 s.Name = "foo"
49 s.Thing = &thing{"bar", 10}
50 s.Exp = time.Now().Add(time.Hour)
51
52 assert.NoError(t, st.Commit(c))
53
54 c2 := e.NewContext(&http.Request{
55 Header: http.Header{
56 "Cookie": []string{res.Header().Get("Set-Cookie")},
57 },
58 }, httptest.NewRecorder())
59
60 nt, err := st.GetStrict(c2)
61 assert.NoError(t, err)
62
63 assert.Equal(t, nt.Name, "foo")
64 assert.Equal(t, nt.Thing.Name, "bar")
65 assert.Equal(t, nt.Thing.Value, 10)
66 assert.Equal(t, nt.Exp.Unix(), s.Exp.Unix())
67}