From 267ee2d8a78fa7425af765eb583cab3248995a31 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Sat, 19 Aug 2023 17:25:09 -0700 Subject: Add server-driven time limits --- app/controllers/proxy.go | 3 +++ cmd/client/client.go | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/app/controllers/proxy.go b/app/controllers/proxy.go index 9e3ec13..0777c93 100644 --- a/app/controllers/proxy.go +++ b/app/controllers/proxy.go @@ -72,6 +72,9 @@ func (h *ProxyHandler) Handle(c echo.Context) error { return c.NoContent(http.StatusUnauthorized) } + // TODO: Set Access-Control-Max-Age header if policy requires time-limited + // sessions then terminate the session once that timer expires + wsconn, err := h.Upgrader.Upgrade(c.Response(), c.Request(), nil) if err != nil { return err diff --git a/cmd/client/client.go b/cmd/client/client.go index 54d7190..b72003d 100644 --- a/cmd/client/client.go +++ b/cmd/client/client.go @@ -11,6 +11,8 @@ import ( "net" "net/http" "os" + "strconv" + "time" "code.crute.us/mcrute/ssh-proxy/app" "code.crute.us/mcrute/ssh-proxy/proxy" @@ -139,18 +141,30 @@ func addCertificateToAgent(conn agent.ExtendedAgent, private any, cert *ssh.Cert }) } -func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, error) { +func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, time.Duration, error) { addr := fmt.Sprintf("wss://%s/proxy-to/%s/%s", proxyHost, host, port) hdr := http.Header{} hdr.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken)) - conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr) + conn, resp, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr) if err != nil { - return nil, err + return nil, 0, err + } + + // Extract the connection TTL in seconds from the header if present, + // otherwise there's no TTL. If we fail to terminate by the TTL the server + // will do it for us. + var ttl time.Duration + if th := resp.Header.Get("Access-Control-Max-Age"); th != "" { + tp, err := strconv.Atoi(th) + if err != nil { + return nil, 0, err + } + ttl = time.Duration(tp) * time.Second } - return &proxy.WebsocketReadWriter{W: conn}, nil + return &proxy.WebsocketReadWriter{W: conn}, ttl, nil } func fetchOauthToken(ctx context.Context, clientId, proxyHost string) (string, error) { @@ -218,7 +232,7 @@ func clientMain(cfg app.Config, host, port, username string) { log.Fatalf("Error adding certificate to agent: %s", err) } - ws, err := dialProxyHost(ctx, oauthToken, cfg.ClientHost, host, port) + ws, ttl, err := dialProxyHost(ctx, oauthToken, cfg.ClientHost, host, port) if err != nil { log.Fatalf("Error dialing proxy host: %s", err) } @@ -229,13 +243,25 @@ func clientMain(cfg app.Config, host, port, username string) { errc := make(chan error) + // The server will also force the connection closed but if we do it here we + // can give a slightly more friendly error message to the client. + if ttl != 0 { + log.Printf("Time limited connection, will expire at %s", time.Now().Add(ttl)) + time.AfterFunc(ttl, func() { + ws.Close() + errc <- fmt.Errorf("Connection time limit has expired") + }) + } + go proxy.CopyWithErrors(os.Stdout, ws, errc) go proxy.CopyWithErrors(ws, os.Stdin, errc) err = <-errc if err != nil { log.Printf("Closing client connection: %s", <-errc) + os.Exit(1) } else { log.Printf("Closing client connection") + os.Exit(0) } } -- cgit v1.2.3