From 51949f8dc563c7c1ce03d8862abbee4cc1e20943 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Sun, 2 Jan 2022 11:49:07 -0800 Subject: WIP: add contexts --- client.go | 4 +++- main.go | 6 ++++++ server.go | 2 +- sockets.go | 27 ++++++++++++++++++++------- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 00455ab..d0716b1 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "net" @@ -10,6 +11,7 @@ import ( type ClientHandler struct { SocketListenOn string WebsocketServer string + Context context.Context } func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { @@ -24,7 +26,7 @@ func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { log.Println("Connected to server") - serviceBoth(wsconn, proxyconn) + serviceBoth(wsconn, proxyconn, h.Context) } func (h *ClientHandler) Run() { diff --git a/main.go b/main.go index ec4170b..00fe063 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "log" @@ -34,9 +35,14 @@ var clientCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { listenOn := cmd.Flag("listen").Value.String() + // TODO: Handle signals + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h := &ClientHandler{ SocketListenOn: listenOn, WebsocketServer: args[0], + Context: ctx, } log.Printf("Serving on %s", listenOn) diff --git a/server.go b/server.go index 47eb3e5..1faf33d 100644 --- a/server.go +++ b/server.go @@ -42,5 +42,5 @@ func (h *ServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Println("Connected to SSH server") - serviceBoth(wsconn, proxyconn) + serviceBoth(wsconn, proxyconn, r.Context()) } diff --git a/sockets.go b/sockets.go index 0ebbe43..2ded258 100644 --- a/sockets.go +++ b/sockets.go @@ -1,6 +1,7 @@ package main import ( + "context" "io" "log" "net" @@ -8,7 +9,7 @@ import ( "github.com/gorilla/websocket" ) -func wsReader(wsconn *websocket.Conn, out chan []byte) { +func wsReader(wsconn *websocket.Conn, out chan []byte, ctx context.Context) { for { messageType, p, err := wsconn.ReadMessage() if err != nil { @@ -19,11 +20,16 @@ func wsReader(wsconn *websocket.Conn, out chan []byte) { log.Println("error: wsReader: only binary messages are supported") continue } - out <- p + select { + case out <- p: + continue + case <-ctx.Done(): + return + } } } -func socketReader(proxyconn net.Conn, out chan []byte) { +func socketReader(proxyconn net.Conn, out chan []byte, ctx context.Context) { for { readBuffer := make([]byte, 2048) @@ -37,16 +43,21 @@ func socketReader(proxyconn net.Conn, out chan []byte) { return } - out <- readBuffer[:i] + select { + case out <- readBuffer[:i]: + continue + case <-ctx.Done(): + return + } } } -func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) { +func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn, ctx context.Context) { sc := make(chan []byte) wsc := make(chan []byte) - go socketReader(proxyconn, sc) - go wsReader(wsconn, wsc) + go socketReader(proxyconn, sc, ctx) + go wsReader(wsconn, wsc, ctx) for { select { @@ -61,6 +72,8 @@ func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) { log.Printf("error: serviceBoth: %s", err) return } + case <-ctx.Done(): + return } } } -- cgit v1.2.3