Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 82 additions & 10 deletions internal/sshproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ var (
errUnknownPublicKey = errors.New("unknown public key")
)

// upstreamAuthFailureError represents an SSH client auth failure (that is, SSH_MSG_USERAUTH_FAILURE) when connecting
// to any host in the Upstream.hosts chain (either a jump host or a target host)
type upstreamAuthFailureError struct {
sshErr error
isTargetHost bool
}

func (e *upstreamAuthFailureError) Error() string {
return fmt.Sprintf("auth failure: %s", e.sshErr.Error())
}

func (e *upstreamAuthFailureError) Unwrap() error {
return e.sshErr
}

type Server struct {
address string

Expand Down Expand Up @@ -246,7 +261,7 @@ func (s *Server) Close(ctx context.Context) error {
}

func (s *Server) publicKeyCallback(conn ssh.ConnMetadata, publicKey ssh.PublicKey) (*ssh.Permissions, error) {
upstreamID := conn.User()
upstreamID, _ := parseAuthUser(conn.User())
logger := log.GetLogger(s.serveCtx).WithField("id", upstreamID)

upstream, found := s.upstreamCache.Get(upstreamID)
Expand Down Expand Up @@ -312,9 +327,22 @@ func handleConnection(ctx context.Context, conn net.Conn, config *ssh.ServerConf
logger.Debug("client logged in")

upstream := clientConn.Permissions.ExtraData[upstreamExtraDataKey].(Upstream)
upstreamConn, upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream)
_, user := parseAuthUser(clientConn.User())
upstreamConn, upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream, user)
if err != nil {
logger.WithError(err).Error("failed to connect to upstream")
logger = logger.WithError(err)
if user != "" {
logger = logger.WithField("user", user)
}

const msg = "failed to connect to upstream"
// Don't log as an error if it is a client auth error on the last host in the chain and the user is overridden
// to avoid log noise in case a non-existent user is requested
if authErr, ok := errors.AsType[*upstreamAuthFailureError](err); ok && authErr.isTargetHost && user != "" {
logger.Debug(msg)
} else {
logger.Error(msg)
}

return
}
Expand Down Expand Up @@ -420,22 +448,40 @@ func handleConnectionError(ctx context.Context, err error) {
logger.WithError(err).Error("failed to handshake client")
}

// parseAuthUser extracts upstreamID and optional upstreamUser (overrides the default upstream user)
// from the "user name" field of the SSH_MSG_USERAUTH_REQUEST request (the `user` in the `ssh user@hostname` command)
// The optional user is appended to the upstreamID after the `_` delimiter:
// 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com - log in as the default job user
// 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com - log in as `root`
func parseAuthUser(user string) (upstreamID string, upstreamUser string) {
upstreamID, upstreamUser, _ = strings.Cut(user, "_")
return upstreamID, upstreamUser
}

func connectToUpstream(
ctx context.Context,
upstream Upstream,
ctx context.Context, upstream Upstream, user string,
) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) {
logger := log.GetLogger(ctx)

var conn ssh.Conn
var chans <-chan ssh.NewChannel
var reqs <-chan *ssh.Request

for i, host := range upstream.hosts {
// A target host is the last host in the Upstream.hosts chanin. All other hosts are jump hosts.
targetHostIdx := len(upstream.hosts) - 1
for hostIdx, host := range upstream.hosts {
isTargetHost := hostIdx == targetHostIdx
hostUser := host.user
if isTargetHost && user != "" {
hostUser = user
}
config := &ssh.ClientConfig{
Config: ssh.Config{
KeyExchanges: allowedKeyExchanges,
Ciphers: allowedCiphers,
MACs: allowedMACs,
},
User: host.user,
User: hostUser,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(host.privateKey),
},
Expand All @@ -446,7 +492,7 @@ func connectToUpstream(
var netConn net.Conn
var err error

if i == 0 {
if hostIdx == 0 {
d := net.Dialer{
Timeout: upstreamDialTimeout,
}
Expand All @@ -457,14 +503,31 @@ func connectToUpstream(
netConn, err = client.Dial("tcp", host.address)
}

var hostType string
if isTargetHost {
hostType = "target"
} else {
hostType = "jump"
}

if err != nil {
return nil, nil, nil, fmt.Errorf("dial upstream %d %s: %w", i, host.address, err)
return nil, nil, nil, fmt.Errorf("dial %s host #%d %s: %w", hostType, hostIdx, host.address, err)
}

conn, chans, reqs, err = ssh.NewClientConn(netConn, host.address, config)
if err != nil {
return nil, nil, nil, fmt.Errorf("create SSH connection %d %s: %w", i, host.address, err)
if isClientAuthFailureError(err) {
err = &upstreamAuthFailureError{
sshErr: err,
isTargetHost: isTargetHost,
}
}

return nil, nil, nil, fmt.Errorf(
"create SSH connection to %s host #%d %s: %w", hostType, hostIdx, host.address, err)
}

logger.Tracef("connected to %s host #%d %s", hostType, hostIdx, host.address)
}

return conn, chans, reqs, nil
Expand Down Expand Up @@ -611,3 +674,12 @@ func getSSHError(err error) error {

return nil
}

func isClientAuthFailureError(err error) bool {
sshErr := getSSHError(err)
if sshErr == nil {
return false
}
// https://github.com/golang/crypto/blob/982eaa62dfb7273603b97fc1835561450096f3bd/ssh/client_auth.go#L118
return strings.Contains(sshErr.Error(), "unable to authenticate")
}