diff options
author | Mike Crute <mike@crute.us> | 2022-01-02 11:49:07 -0800 |
---|---|---|
committer | Mike Crute <mike@crute.us> | 2022-01-02 11:49:07 -0800 |
commit | 51949f8dc563c7c1ce03d8862abbee4cc1e20943 (patch) | |
tree | de8c07df1a0fd4dd47e947ea6d74275ee2c3daa3 | |
parent | 219288c98477e392242e5dcca300d20062b3c670 (diff) | |
download | websocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.tar.bz2 websocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.tar.xz websocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.zip |
WIP: add contexts
-rw-r--r-- | client.go | 4 | ||||
-rw-r--r-- | main.go | 6 | ||||
-rw-r--r-- | server.go | 2 | ||||
-rw-r--r-- | sockets.go | 27 |
4 files changed, 30 insertions, 9 deletions
@@ -1,6 +1,7 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "context" | ||
4 | "log" | 5 | "log" |
5 | "net" | 6 | "net" |
6 | 7 | ||
@@ -10,6 +11,7 @@ import ( | |||
10 | type ClientHandler struct { | 11 | type ClientHandler struct { |
11 | SocketListenOn string | 12 | SocketListenOn string |
12 | WebsocketServer string | 13 | WebsocketServer string |
14 | Context context.Context | ||
13 | } | 15 | } |
14 | 16 | ||
15 | func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { | 17 | func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { |
@@ -24,7 +26,7 @@ func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { | |||
24 | 26 | ||
25 | log.Println("Connected to server") | 27 | log.Println("Connected to server") |
26 | 28 | ||
27 | serviceBoth(wsconn, proxyconn) | 29 | serviceBoth(wsconn, proxyconn, h.Context) |
28 | } | 30 | } |
29 | 31 | ||
30 | func (h *ClientHandler) Run() { | 32 | func (h *ClientHandler) Run() { |
@@ -1,6 +1,7 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "context" | ||
4 | "errors" | 5 | "errors" |
5 | "fmt" | 6 | "fmt" |
6 | "log" | 7 | "log" |
@@ -34,9 +35,14 @@ var clientCmd = &cobra.Command{ | |||
34 | Run: func(cmd *cobra.Command, args []string) { | 35 | Run: func(cmd *cobra.Command, args []string) { |
35 | listenOn := cmd.Flag("listen").Value.String() | 36 | listenOn := cmd.Flag("listen").Value.String() |
36 | 37 | ||
38 | // TODO: Handle signals | ||
39 | ctx, cancel := context.WithCancel(context.Background()) | ||
40 | defer cancel() | ||
41 | |||
37 | h := &ClientHandler{ | 42 | h := &ClientHandler{ |
38 | SocketListenOn: listenOn, | 43 | SocketListenOn: listenOn, |
39 | WebsocketServer: args[0], | 44 | WebsocketServer: args[0], |
45 | Context: ctx, | ||
40 | } | 46 | } |
41 | 47 | ||
42 | log.Printf("Serving on %s", listenOn) | 48 | log.Printf("Serving on %s", listenOn) |
@@ -42,5 +42,5 @@ func (h *ServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
42 | 42 | ||
43 | log.Println("Connected to SSH server") | 43 | log.Println("Connected to SSH server") |
44 | 44 | ||
45 | serviceBoth(wsconn, proxyconn) | 45 | serviceBoth(wsconn, proxyconn, r.Context()) |
46 | } | 46 | } |
@@ -1,6 +1,7 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "context" | ||
4 | "io" | 5 | "io" |
5 | "log" | 6 | "log" |
6 | "net" | 7 | "net" |
@@ -8,7 +9,7 @@ import ( | |||
8 | "github.com/gorilla/websocket" | 9 | "github.com/gorilla/websocket" |
9 | ) | 10 | ) |
10 | 11 | ||
11 | func wsReader(wsconn *websocket.Conn, out chan []byte) { | 12 | func wsReader(wsconn *websocket.Conn, out chan []byte, ctx context.Context) { |
12 | for { | 13 | for { |
13 | messageType, p, err := wsconn.ReadMessage() | 14 | messageType, p, err := wsconn.ReadMessage() |
14 | if err != nil { | 15 | if err != nil { |
@@ -19,11 +20,16 @@ func wsReader(wsconn *websocket.Conn, out chan []byte) { | |||
19 | log.Println("error: wsReader: only binary messages are supported") | 20 | log.Println("error: wsReader: only binary messages are supported") |
20 | continue | 21 | continue |
21 | } | 22 | } |
22 | out <- p | 23 | select { |
24 | case out <- p: | ||
25 | continue | ||
26 | case <-ctx.Done(): | ||
27 | return | ||
28 | } | ||
23 | } | 29 | } |
24 | } | 30 | } |
25 | 31 | ||
26 | func socketReader(proxyconn net.Conn, out chan []byte) { | 32 | func socketReader(proxyconn net.Conn, out chan []byte, ctx context.Context) { |
27 | for { | 33 | for { |
28 | readBuffer := make([]byte, 2048) | 34 | readBuffer := make([]byte, 2048) |
29 | 35 | ||
@@ -37,16 +43,21 @@ func socketReader(proxyconn net.Conn, out chan []byte) { | |||
37 | return | 43 | return |
38 | } | 44 | } |
39 | 45 | ||
40 | out <- readBuffer[:i] | 46 | select { |
47 | case out <- readBuffer[:i]: | ||
48 | continue | ||
49 | case <-ctx.Done(): | ||
50 | return | ||
51 | } | ||
41 | } | 52 | } |
42 | } | 53 | } |
43 | 54 | ||
44 | func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) { | 55 | func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn, ctx context.Context) { |
45 | sc := make(chan []byte) | 56 | sc := make(chan []byte) |
46 | wsc := make(chan []byte) | 57 | wsc := make(chan []byte) |
47 | 58 | ||
48 | go socketReader(proxyconn, sc) | 59 | go socketReader(proxyconn, sc, ctx) |
49 | go wsReader(wsconn, wsc) | 60 | go wsReader(wsconn, wsc, ctx) |
50 | 61 | ||
51 | for { | 62 | for { |
52 | select { | 63 | select { |
@@ -61,6 +72,8 @@ func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) { | |||
61 | log.Printf("error: serviceBoth: %s", err) | 72 | log.Printf("error: serviceBoth: %s", err) |
62 | return | 73 | return |
63 | } | 74 | } |
75 | case <-ctx.Done(): | ||
76 | return | ||
64 | } | 77 | } |
65 | } | 78 | } |
66 | } | 79 | } |