summaryrefslogtreecommitdiff
path: root/cmd/client/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/client/client.go')
-rw-r--r--cmd/client/client.go226
1 files changed, 226 insertions, 0 deletions
diff --git a/cmd/client/client.go b/cmd/client/client.go
new file mode 100644
index 0000000..62f1f48
--- /dev/null
+++ b/cmd/client/client.go
@@ -0,0 +1,226 @@
1package client
2
3import (
4 "bytes"
5 "context"
6 "crypto/ed25519"
7 "crypto/rand"
8 "fmt"
9 "io"
10 "log"
11 "net"
12 "net/http"
13 "os"
14
15 "code.crute.us/mcrute/ssh-proxy/app"
16 "code.crute.us/mcrute/ssh-proxy/proxy"
17 "golang.org/x/crypto/ssh"
18 "golang.org/x/crypto/ssh/agent"
19
20 "code.crute.us/mcrute/golib/cli"
21 "github.com/gorilla/websocket"
22 "github.com/mdp/qrterminal"
23 "github.com/spf13/cobra"
24)
25
26// This should be compiled into the binary
27var clientId string
28
29func Register(root *cobra.Command) {
30 clientCmd := &cobra.Command{
31 Use: "client proxy-host ssh-to-host ssh-port username",
32 Short: "Run websocket client",
33 Args: cobra.ExactArgs(4),
34 Run: func(c *cobra.Command, args []string) {
35 cfg := app.Config{}
36 cli.MustGetConfig(c, &cfg)
37 clientMain(cfg, args[0], args[1], args[2], args[3])
38 },
39 }
40 cli.AddFlags(clientCmd, &app.Config{}, app.DefaultConfig, "client")
41 root.AddCommand(clientCmd)
42}
43
44func generateCertificateRequest(username, host string) (ed25519.PrivateKey, []byte, error) {
45 pub, priv, err := ed25519.GenerateKey(rand.Reader)
46 if err != nil {
47 return nil, nil, err
48 }
49
50 pubKey, err := ssh.NewPublicKey(pub)
51 if err != nil {
52 return nil, nil, err
53 }
54
55 cert := &ssh.Certificate{
56 Key: pubKey,
57 CertType: ssh.UserCert,
58 ValidPrincipals: []string{username},
59 Permissions: ssh.Permissions{
60 Extensions: map[string]string{
61 // Used for CA policy checks, removed by the CA server
62 // Server supports a comma separated list without spaces
63 "allowed-hosts": host,
64 },
65 },
66 }
67
68 signer, err := ssh.NewSignerFromKey(priv)
69 if err != nil {
70 return nil, nil, err
71 }
72
73 // Signatures are required to un/marshal to ASCII. The server will
74 // discard this anyhow and replace it with its own signature.
75 if err := cert.SignCert(rand.Reader, signer); err != nil {
76 return nil, nil, err
77 }
78
79 return priv, ssh.MarshalAuthorizedKey(cert), nil
80}
81
82func getCertificateFromCA(ctx context.Context, oauthToken string, certRequest []byte, host string) (*ssh.Certificate, error) {
83 req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://%s/ca/issue", host), bytes.NewReader(certRequest))
84 if err != nil {
85 return nil, err
86 }
87
88 req.Header.Add("Content-Type", "application/x-ssh-certificate")
89 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken))
90
91 resp, err := http.DefaultClient.Do(req)
92 if err != nil {
93 return nil, err
94 }
95
96 res, err := io.ReadAll(resp.Body)
97 if err != nil {
98 return nil, err
99 }
100 defer resp.Body.Close()
101
102 if resp.StatusCode != http.StatusOK {
103 return nil, fmt.Errorf("CA returned error: %s", res)
104 }
105
106 pubkey, _, _, _, err := ssh.ParseAuthorizedKey(res)
107 if err != nil {
108 return nil, err
109 }
110
111 cert, ok := pubkey.(*ssh.Certificate)
112 if !ok {
113 return nil, fmt.Errorf("Parsed certificate is of incorrect type")
114 }
115
116 return cert, nil
117}
118
119func addCertificateToAgent(private any, cert *ssh.Certificate) error {
120 socket := os.Getenv("SSH_AUTH_SOCK")
121 conn, err := net.Dial("unix", socket)
122 if err != nil {
123 return err
124 }
125
126 agentConn := agent.NewClient(conn)
127
128 return agentConn.Add(agent.AddedKey{
129 PrivateKey: private,
130 Certificate: cert,
131 LifetimeSecs: 10,
132 })
133}
134
135func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, error) {
136 addr := fmt.Sprintf("wss://%s/proxy-to/%s/%s", proxyHost, host, port)
137
138 hdr := http.Header{}
139 hdr.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken))
140
141 conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr)
142 if err != nil {
143 return nil, err
144 }
145
146 return &proxy.WebsocketReadWriter{W: conn}, nil
147}
148
149func fetchOauthToken(ctx context.Context, clientId, proxyHost string) (string, error) {
150 client := &Oauth2PKCEDeviceClient{
151 Host: proxyHost,
152 ClientId: clientId,
153 Scope: "ssh:proxy ca:issue",
154 }
155
156 authResponse, err := client.Authorize(ctx)
157 if err != nil {
158 return "", err
159 }
160
161 fmt.Fprintf(os.Stderr,
162 "To authenticate, please visit: \n\n\t%s \n\nEnter code: %s\n\n",
163 authResponse.VerificationUri, authResponse.UserCode)
164
165 if authResponse.VerificationUriComplete != "" {
166 qrterminal.GenerateWithConfig(authResponse.VerificationUriComplete, qrterminal.Config{
167 Level: qrterminal.M,
168 Writer: os.Stderr,
169 BlackChar: "\033[7m \033[0m", // White
170 WhiteChar: "\033[0m \033[0m", // Black
171 QuietZone: 1,
172 })
173 fmt.Fprintf(os.Stderr, "\n")
174 }
175
176 tokenResponse, err := client.AwaitToken(ctx, authResponse.DeviceCode)
177 if err != nil {
178 return "", err
179 }
180
181 return tokenResponse.AccessToken, nil
182}
183
184func clientMain(cfg app.Config, proxyHost, host, port, username string) {
185 log.SetOutput(os.Stderr)
186
187 ctx, cancel := context.WithCancel(context.Background())
188 defer cancel()
189
190 oauthToken, err := fetchOauthToken(ctx, clientId, proxyHost)
191 if err != nil {
192 log.Fatalf("Error fetching oauth token: %s", err)
193 }
194
195 privateKey, certRequest, err := generateCertificateRequest(username, host)
196 if err != nil {
197 log.Fatalf("Error generating certificate request: %s", err)
198 }
199
200 certificate, err := getCertificateFromCA(ctx, oauthToken, certRequest, proxyHost)
201 if err != nil {
202 log.Fatalf("Error fetching certificate: %s", err)
203 }
204
205 if err := addCertificateToAgent(privateKey, certificate); err != nil {
206 log.Fatalf("Error adding certificate to agent: %s", err)
207 }
208
209 ws, err := dialProxyHost(ctx, oauthToken, proxyHost, host, port)
210 if err != nil {
211 log.Fatalf("Error dialing proxy host: %s", err)
212 }
213 defer ws.Close()
214
215 errc := make(chan error)
216
217 go proxy.CopyWithErrors(os.Stdout, ws, errc)
218 go proxy.CopyWithErrors(ws, os.Stdin, errc)
219
220 err = <-errc
221 if err != nil {
222 log.Printf("Closing client connection: %s", <-errc)
223 } else {
224 log.Printf("Closing client connection")
225 }
226}