summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2022-01-02 11:49:07 -0800
committerMike Crute <mike@crute.us>2022-01-02 11:49:07 -0800
commit51949f8dc563c7c1ce03d8862abbee4cc1e20943 (patch)
treede8c07df1a0fd4dd47e947ea6d74275ee2c3daa3
parent219288c98477e392242e5dcca300d20062b3c670 (diff)
downloadwebsocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.tar.bz2
websocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.tar.xz
websocket_proxy-51949f8dc563c7c1ce03d8862abbee4cc1e20943.zip
WIP: add contexts
-rw-r--r--client.go4
-rw-r--r--main.go6
-rw-r--r--server.go2
-rw-r--r--sockets.go27
4 files changed, 30 insertions, 9 deletions
diff --git a/client.go b/client.go
index 00455ab..d0716b1 100644
--- a/client.go
+++ b/client.go
@@ -1,6 +1,7 @@
1package main 1package main
2 2
3import ( 3import (
4 "context"
4 "log" 5 "log"
5 "net" 6 "net"
6 7
@@ -10,6 +11,7 @@ import (
10type ClientHandler struct { 11type ClientHandler struct {
11 SocketListenOn string 12 SocketListenOn string
12 WebsocketServer string 13 WebsocketServer string
14 Context context.Context
13} 15}
14 16
15func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) { 17func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) {
@@ -24,7 +26,7 @@ func (h *ClientHandler) ServiceConnection(proxyconn net.Conn) {
24 26
25 log.Println("Connected to server") 27 log.Println("Connected to server")
26 28
27 serviceBoth(wsconn, proxyconn) 29 serviceBoth(wsconn, proxyconn, h.Context)
28} 30}
29 31
30func (h *ClientHandler) Run() { 32func (h *ClientHandler) Run() {
diff --git a/main.go b/main.go
index ec4170b..00fe063 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,7 @@
1package main 1package main
2 2
3import ( 3import (
4 "context"
4 "errors" 5 "errors"
5 "fmt" 6 "fmt"
6 "log" 7 "log"
@@ -34,9 +35,14 @@ var clientCmd = &cobra.Command{
34 Run: func(cmd *cobra.Command, args []string) { 35 Run: func(cmd *cobra.Command, args []string) {
35 listenOn := cmd.Flag("listen").Value.String() 36 listenOn := cmd.Flag("listen").Value.String()
36 37
38 // TODO: Handle signals
39 ctx, cancel := context.WithCancel(context.Background())
40 defer cancel()
41
37 h := &ClientHandler{ 42 h := &ClientHandler{
38 SocketListenOn: listenOn, 43 SocketListenOn: listenOn,
39 WebsocketServer: args[0], 44 WebsocketServer: args[0],
45 Context: ctx,
40 } 46 }
41 47
42 log.Printf("Serving on %s", listenOn) 48 log.Printf("Serving on %s", listenOn)
diff --git a/server.go b/server.go
index 47eb3e5..1faf33d 100644
--- a/server.go
+++ b/server.go
@@ -42,5 +42,5 @@ func (h *ServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
42 42
43 log.Println("Connected to SSH server") 43 log.Println("Connected to SSH server")
44 44
45 serviceBoth(wsconn, proxyconn) 45 serviceBoth(wsconn, proxyconn, r.Context())
46} 46}
diff --git a/sockets.go b/sockets.go
index 0ebbe43..2ded258 100644
--- a/sockets.go
+++ b/sockets.go
@@ -1,6 +1,7 @@
1package main 1package main
2 2
3import ( 3import (
4 "context"
4 "io" 5 "io"
5 "log" 6 "log"
6 "net" 7 "net"
@@ -8,7 +9,7 @@ import (
8 "github.com/gorilla/websocket" 9 "github.com/gorilla/websocket"
9) 10)
10 11
11func wsReader(wsconn *websocket.Conn, out chan []byte) { 12func wsReader(wsconn *websocket.Conn, out chan []byte, ctx context.Context) {
12 for { 13 for {
13 messageType, p, err := wsconn.ReadMessage() 14 messageType, p, err := wsconn.ReadMessage()
14 if err != nil { 15 if err != nil {
@@ -19,11 +20,16 @@ func wsReader(wsconn *websocket.Conn, out chan []byte) {
19 log.Println("error: wsReader: only binary messages are supported") 20 log.Println("error: wsReader: only binary messages are supported")
20 continue 21 continue
21 } 22 }
22 out <- p 23 select {
24 case out <- p:
25 continue
26 case <-ctx.Done():
27 return
28 }
23 } 29 }
24} 30}
25 31
26func socketReader(proxyconn net.Conn, out chan []byte) { 32func socketReader(proxyconn net.Conn, out chan []byte, ctx context.Context) {
27 for { 33 for {
28 readBuffer := make([]byte, 2048) 34 readBuffer := make([]byte, 2048)
29 35
@@ -37,16 +43,21 @@ func socketReader(proxyconn net.Conn, out chan []byte) {
37 return 43 return
38 } 44 }
39 45
40 out <- readBuffer[:i] 46 select {
47 case out <- readBuffer[:i]:
48 continue
49 case <-ctx.Done():
50 return
51 }
41 } 52 }
42} 53}
43 54
44func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) { 55func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn, ctx context.Context) {
45 sc := make(chan []byte) 56 sc := make(chan []byte)
46 wsc := make(chan []byte) 57 wsc := make(chan []byte)
47 58
48 go socketReader(proxyconn, sc) 59 go socketReader(proxyconn, sc, ctx)
49 go wsReader(wsconn, wsc) 60 go wsReader(wsconn, wsc, ctx)
50 61
51 for { 62 for {
52 select { 63 select {
@@ -61,6 +72,8 @@ func serviceBoth(wsconn *websocket.Conn, proxyconn net.Conn) {
61 log.Printf("error: serviceBoth: %s", err) 72 log.Printf("error: serviceBoth: %s", err)
62 return 73 return
63 } 74 }
75 case <-ctx.Done():
76 return
64 } 77 }
65 } 78 }
66} 79}