From a3adc1d396b6ac32828746c699feefa27ab072a7 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 25 Mar 2026 13:36:39 +0000 Subject: [PATCH] Allow to override upstream user The syntax is <upstream_id>[_<user>], e.g., * log in as the default job user: 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com * log in as `root`: 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com --- internal/sshproxy/server.go | 92 +++++++++++++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 10 deletions(-) diff --git a/internal/sshproxy/server.go b/internal/sshproxy/server.go index 99cd8dd..4ccc177 100644 --- a/internal/sshproxy/server.go +++ b/internal/sshproxy/server.go @@ -126,6 +126,21 @@ var ( errUnknownPublicKey = errors.New("unknown public key") ) +// upstreamAuthFailureError represents an SSH client auth failure (that is, SSH_MSG_USERAUTH_FAILURE) when connecting +// to any host in the Upstream.hosts chain (either a jump host or a target host) +type upstreamAuthFailureError struct { + sshErr error + isTargetHost bool +} + +func (e *upstreamAuthFailureError) Error() string { + return fmt.Sprintf("auth failure: %s", e.sshErr.Error()) +} + +func (e *upstreamAuthFailureError) Unwrap() error { + return e.sshErr +} + type Server struct { address string @@ -246,7 +261,7 @@ func (s *Server) Close(ctx context.Context) error { } func (s *Server) publicKeyCallback(conn ssh.ConnMetadata, publicKey ssh.PublicKey) (*ssh.Permissions, error) { - upstreamID := conn.User() + upstreamID, _ := parseAuthUser(conn.User()) logger := log.GetLogger(s.serveCtx).WithField("id", upstreamID) upstream, found := s.upstreamCache.Get(upstreamID) @@ -312,9 +327,22 @@ func handleConnection(ctx context.Context, conn net.Conn, config *ssh.ServerConf logger.Debug("client logged in") upstream := clientConn.Permissions.ExtraData[upstreamExtraDataKey].(Upstream) - upstreamConn, upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream) + _, user := parseAuthUser(clientConn.User()) + upstreamConn, 
upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream, user) if err != nil { - logger.WithError(err).Error("failed to connect to upstream") + logger = logger.WithError(err) + if user != "" { + logger = logger.WithField("user", user) + } + + const msg = "failed to connect to upstream" + // Don't log as an error if it is a client auth error on the last host in the chain and the user is overridden + // to avoid log noise in case a non-existent user is requested + if authErr, ok := errors.AsType[*upstreamAuthFailureError](err); ok && authErr.isTargetHost && user != "" { + logger.Debug(msg) + } else { + logger.Error(msg) + } return } @@ -420,22 +448,40 @@ func handleConnectionError(ctx context.Context, err error) { logger.WithError(err).Error("failed to handshake client") } +// parseAuthUser extracts upstreamID and optional upstreamUser (overrides the default upstream user) +// from the "user name" field of the SSH_MSG_USERAUTH_REQUEST request (the `user` in the `ssh user@hostname` command) +// The optional user is appended to the upstreamID after the `_` delimiter: +// 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com - log in as the default job user +// 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com - log in as `root` +func parseAuthUser(user string) (upstreamID string, upstreamUser string) { + upstreamID, upstreamUser, _ = strings.Cut(user, "_") + return upstreamID, upstreamUser +} + func connectToUpstream( - ctx context.Context, - upstream Upstream, + ctx context.Context, upstream Upstream, user string, ) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + logger := log.GetLogger(ctx) + var conn ssh.Conn var chans <-chan ssh.NewChannel var reqs <-chan *ssh.Request - for i, host := range upstream.hosts { + // A target host is the last host in the Upstream.hosts chain. All other hosts are jump hosts. 
+ targetHostIdx := len(upstream.hosts) - 1 + for hostIdx, host := range upstream.hosts { + isTargetHost := hostIdx == targetHostIdx + hostUser := host.user + if isTargetHost && user != "" { + hostUser = user + } config := &ssh.ClientConfig{ Config: ssh.Config{ KeyExchanges: allowedKeyExchanges, Ciphers: allowedCiphers, MACs: allowedMACs, }, - User: host.user, + User: hostUser, Auth: []ssh.AuthMethod{ ssh.PublicKeys(host.privateKey), }, @@ -446,7 +492,7 @@ func connectToUpstream( var netConn net.Conn var err error - if i == 0 { + if hostIdx == 0 { d := net.Dialer{ Timeout: upstreamDialTimeout, } @@ -457,14 +503,31 @@ func connectToUpstream( netConn, err = client.Dial("tcp", host.address) } + var hostType string + if isTargetHost { + hostType = "target" + } else { + hostType = "jump" + } + if err != nil { - return nil, nil, nil, fmt.Errorf("dial upstream %d %s: %w", i, host.address, err) + return nil, nil, nil, fmt.Errorf("dial %s host #%d %s: %w", hostType, hostIdx, host.address, err) } conn, chans, reqs, err = ssh.NewClientConn(netConn, host.address, config) if err != nil { - return nil, nil, nil, fmt.Errorf("create SSH connection %d %s: %w", i, host.address, err) + if isClientAuthFailureError(err) { + err = &upstreamAuthFailureError{ + sshErr: err, + isTargetHost: isTargetHost, + } + } + + return nil, nil, nil, fmt.Errorf( + "create SSH connection to %s host #%d %s: %w", hostType, hostIdx, host.address, err) } + + logger.Tracef("connected to %s host #%d %s", hostType, hostIdx, host.address) } return conn, chans, reqs, nil @@ -611,3 +674,12 @@ func getSSHError(err error) error { return nil } + +func isClientAuthFailureError(err error) bool { + sshErr := getSSHError(err) + if sshErr == nil { + return false + } + // https://github.com/golang/crypto/blob/982eaa62dfb7273603b97fc1835561450096f3bd/ssh/client_auth.go#L118 + return strings.Contains(sshErr.Error(), "unable to authenticate") +}