summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2023-08-19 17:25:09 -0700
committerMike Crute <mike@crute.us>2023-08-19 17:25:09 -0700
commit267ee2d8a78fa7425af765eb583cab3248995a31 (patch)
tree38f0dd8b2b85a3570988201297563726598efbf3
parentafbec26c42528ad810fbc08c6a233c7ea75239d4 (diff)
downloadwebsocket_proxy-267ee2d8a78fa7425af765eb583cab3248995a31.tar.bz2
websocket_proxy-267ee2d8a78fa7425af765eb583cab3248995a31.tar.xz
websocket_proxy-267ee2d8a78fa7425af765eb583cab3248995a31.zip
Add server-driven time limits
-rw-r--r--app/controllers/proxy.go3
-rw-r--r--cmd/client/client.go36
2 files changed, 34 insertions, 5 deletions
diff --git a/app/controllers/proxy.go b/app/controllers/proxy.go
index 9e3ec13..0777c93 100644
--- a/app/controllers/proxy.go
+++ b/app/controllers/proxy.go
@@ -72,6 +72,9 @@ func (h *ProxyHandler) Handle(c echo.Context) error {
72 return c.NoContent(http.StatusUnauthorized) 72 return c.NoContent(http.StatusUnauthorized)
73 } 73 }
74 74
75 // TODO: Set Access-Control-Max-Age header if policy requires time-limited
76 // sessions then terminate the session once that timer expires
77
75 wsconn, err := h.Upgrader.Upgrade(c.Response(), c.Request(), nil) 78 wsconn, err := h.Upgrader.Upgrade(c.Response(), c.Request(), nil)
76 if err != nil { 79 if err != nil {
77 return err 80 return err
diff --git a/cmd/client/client.go b/cmd/client/client.go
index 54d7190..b72003d 100644
--- a/cmd/client/client.go
+++ b/cmd/client/client.go
@@ -11,6 +11,8 @@ import (
11 "net" 11 "net"
12 "net/http" 12 "net/http"
13 "os" 13 "os"
14 "strconv"
15 "time"
14 16
15 "code.crute.us/mcrute/ssh-proxy/app" 17 "code.crute.us/mcrute/ssh-proxy/app"
16 "code.crute.us/mcrute/ssh-proxy/proxy" 18 "code.crute.us/mcrute/ssh-proxy/proxy"
@@ -139,18 +141,30 @@ func addCertificateToAgent(conn agent.ExtendedAgent, private any, cert *ssh.Cert
139 }) 141 })
140} 142}
141 143
142func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, error) { 144func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, time.Duration, error) {
143 addr := fmt.Sprintf("wss://%s/proxy-to/%s/%s", proxyHost, host, port) 145 addr := fmt.Sprintf("wss://%s/proxy-to/%s/%s", proxyHost, host, port)
144 146
145 hdr := http.Header{} 147 hdr := http.Header{}
146 hdr.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken)) 148 hdr.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken))
147 149
148 conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr) 150 conn, resp, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr)
149 if err != nil { 151 if err != nil {
150 return nil, err 152 return nil, 0, err
153 }
154
155 // Extract the connection TTL in seconds from the header if present,
156 // otherwise there's no TTL. If we fail to terminate by the TTL the server
157 // will do it for us.
158 var ttl time.Duration
159 if th := resp.Header.Get("Access-Control-Max-Age"); th != "" {
160 tp, err := strconv.Atoi(th)
161 if err != nil {
162 return nil, 0, err
163 }
164 ttl = time.Duration(tp) * time.Second
151 } 165 }
152 166
153 return &proxy.WebsocketReadWriter{W: conn}, nil 167 return &proxy.WebsocketReadWriter{W: conn}, ttl, nil
154} 168}
155 169
156func fetchOauthToken(ctx context.Context, clientId, proxyHost string) (string, error) { 170func fetchOauthToken(ctx context.Context, clientId, proxyHost string) (string, error) {
@@ -218,7 +232,7 @@ func clientMain(cfg app.Config, host, port, username string) {
218 log.Fatalf("Error adding certificate to agent: %s", err) 232 log.Fatalf("Error adding certificate to agent: %s", err)
219 } 233 }
220 234
221 ws, err := dialProxyHost(ctx, oauthToken, cfg.ClientHost, host, port) 235 ws, ttl, err := dialProxyHost(ctx, oauthToken, cfg.ClientHost, host, port)
222 if err != nil { 236 if err != nil {
223 log.Fatalf("Error dialing proxy host: %s", err) 237 log.Fatalf("Error dialing proxy host: %s", err)
224 } 238 }
@@ -229,13 +243,25 @@ func clientMain(cfg app.Config, host, port, username string) {
229 243
230 errc := make(chan error) 244 errc := make(chan error)
231 245
246 // The server will also force the connection closed but if we do it here we
247 // can give a slightly more friendly error message to the client.
248 if ttl != 0 {
249 log.Printf("Time limited connection, will expire at %s", time.Now().Add(ttl))
250 time.AfterFunc(ttl, func() {
251 ws.Close()
252 errc <- fmt.Errorf("Connection time limit has expired")
253 })
254 }
255
232 go proxy.CopyWithErrors(os.Stdout, ws, errc) 256 go proxy.CopyWithErrors(os.Stdout, ws, errc)
233 go proxy.CopyWithErrors(ws, os.Stdin, errc) 257 go proxy.CopyWithErrors(ws, os.Stdin, errc)
234 258
235 err = <-errc 259 err = <-errc
236 if err != nil { 260 if err != nil {
237 log.Printf("Closing client connection: %s", <-errc) 261 log.Printf("Closing client connection: %s", <-errc)
262 os.Exit(1)
238 } else { 263 } else {
239 log.Printf("Closing client connection") 264 log.Printf("Closing client connection")
265 os.Exit(0)
240 } 266 }
241} 267}