package client import ( "bytes" "context" "crypto/ed25519" "crypto/rand" "fmt" "io" "log" "net" "net/http" "os" "strconv" "time" "code.crute.us/mcrute/ssh-proxy/app" "code.crute.us/mcrute/ssh-proxy/proxy" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "code.crute.us/mcrute/golib/cli" "github.com/gorilla/websocket" "github.com/mdp/qrterminal" "github.com/spf13/cobra" ) // This should be compiled into the binary var clientId string func NewClientCommand(appVersion string) *cobra.Command { clientCmd := &cobra.Command{ Use: "client proxy-host ssh-to-host ssh-port username", Short: "Run websocket client", Args: cobra.ExactArgs(3), Version: appVersion, Run: func(c *cobra.Command, args []string) { cfg := app.Config{} cli.MustGetConfig(c, &cfg) clientMain(cfg, args[0], args[1], args[2]) }, } cli.AddFlags(clientCmd, &app.Config{}, app.DefaultConfig, "client") return clientCmd } func Register(root *cobra.Command, appVersion string) { root.AddCommand(NewClientCommand(appVersion)) } func generateCertificateRequest(username, host string) (ed25519.PrivateKey, []byte, error) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, nil, err } pubKey, err := ssh.NewPublicKey(pub) if err != nil { return nil, nil, err } cert := &ssh.Certificate{ Key: pubKey, CertType: ssh.UserCert, ValidPrincipals: []string{username}, Permissions: ssh.Permissions{ Extensions: map[string]string{ // Used for CA policy checks, removed by the CA server // Server supports a comma separated list without spaces "allowed-hosts": host, }, }, } signer, err := ssh.NewSignerFromKey(priv) if err != nil { return nil, nil, err } // Signatures are required to un/marshal to ASCII. The server will // discard this anyhow and replace it with its own signature. if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } return priv, ssh.MarshalAuthorizedKey(cert), nil } func getCertificateFromCA(ctx context.Context, oauthToken string, certRequest []byte, host string) (*ssh.Certificate, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://%s/ca/issue", host), bytes.NewReader(certRequest)) if err != nil { return nil, err } req.Header.Add("Content-Type", "application/x-ssh-certificate") req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken)) resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } res, err := io.ReadAll(resp.Body) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("CA returned error: %s", res) } pubkey, _, _, _, err := ssh.ParseAuthorizedKey(res) if err != nil { return nil, err } cert, ok := pubkey.(*ssh.Certificate) if !ok { return nil, fmt.Errorf("Parsed certificate is of incorrect type") } return cert, nil } func connectToAgent() (agent.ExtendedAgent, error) { socket := os.Getenv("SSH_AUTH_SOCK") conn, err := net.Dial("unix", socket) if err != nil { return nil, err } return agent.NewClient(conn), nil } func addCertificateToAgent(conn agent.ExtendedAgent, private any, cert *ssh.Certificate) error { return conn.Add(agent.AddedKey{ PrivateKey: private, Certificate: cert, LifetimeSecs: 10, }) } func dialProxyHost(ctx context.Context, oauthToken, proxyHost, host, port string) (io.ReadWriteCloser, time.Duration, error) { addr := fmt.Sprintf("wss://%s/proxy-to/%s/%s", proxyHost, host, port) hdr := http.Header{} hdr.Add("Authorization", fmt.Sprintf("Bearer %s", oauthToken)) conn, resp, err := websocket.DefaultDialer.DialContext(ctx, addr, hdr) if err != nil { return nil, 0, err } // Extract the connection TTL in seconds from the header if present, // otherwise there's no TTL. If we fail to terminate by the TTL the server // will do it for us. var ttl time.Duration if th := resp.Header.Get("Access-Control-Max-Age"); th != "" { tp, err := strconv.Atoi(th) if err != nil { return nil, 0, err } ttl = time.Duration(tp) * time.Second } return &proxy.WebsocketReadWriter{W: conn}, ttl, nil } func fetchOauthToken(ctx context.Context, clientId, proxyHost string) (string, error) { client := &Oauth2PKCEDeviceClient{ Host: proxyHost, ClientId: clientId, Scope: "ssh:proxy ca:issue", } authResponse, err := client.Authorize(ctx) if err != nil { return "", err } fmt.Fprintf(os.Stderr, "To authenticate, please visit: \n\n\t%s \n\nEnter code: %s\n\n", authResponse.VerificationUri, authResponse.UserCode) if authResponse.VerificationUriComplete != "" { qrterminal.GenerateWithConfig(authResponse.VerificationUriComplete, qrterminal.Config{ Level: qrterminal.M, Writer: os.Stderr, BlackChar: "\033[7m \033[0m", // White WhiteChar: "\033[0m \033[0m", // Black QuietZone: 1, }) fmt.Fprintf(os.Stderr, "\n") } tokenResponse, err := client.AwaitToken(ctx, authResponse.DeviceCode) if err != nil { return "", err } return tokenResponse.AccessToken, nil } func clientMain(cfg app.Config, host, port, username string) { log.SetOutput(os.Stderr) ctx, cancel := context.WithCancel(context.Background()) defer cancel() agentConn, err := connectToAgent() if err != nil { log.Fatalf("Error connecting to agent, is it started?") } oauthToken, err := fetchOauthToken(ctx, clientId, cfg.ClientHost) if err != nil { log.Fatalf("Error fetching oauth token: %s", err) } privateKey, certRequest, err := generateCertificateRequest(username, host) if err != nil { log.Fatalf("Error generating certificate request: %s", err) } certificate, err := getCertificateFromCA(ctx, oauthToken, certRequest, cfg.ClientHost) if err != nil { log.Fatalf("Error fetching certificate: %s", err) } if err := addCertificateToAgent(agentConn, privateKey, certificate); err != nil { log.Fatalf("Error adding certificate to agent: %s", err) } ws, ttl, err := dialProxyHost(ctx, oauthToken, cfg.ClientHost, host, port) if err != nil { log.Fatalf("Error dialing proxy host: %s", err) } defer ws.Close() // Clear the terminal screen fmt.Fprintf(os.Stderr, "\033c") errc := make(chan error) // The server will also force the connection closed but if we do it here we // can give a slightly more friendly error message to the client. if ttl != 0 { log.Printf("Time limited connection, will expire at %s", time.Now().Add(ttl)) time.AfterFunc(ttl, func() { ws.Close() errc <- fmt.Errorf("Connection time limit has expired") }) } go proxy.CopyWithErrors(os.Stdout, ws, errc) go proxy.CopyWithErrors(ws, os.Stdin, errc) err = <-errc if err != nil { log.Printf("Closing client connection: %s", <-errc) os.Exit(1) } else { log.Printf("Closing client connection") os.Exit(0) } }