Skip to content

Commit 76c14dd

Browse files
committed
When possible, defer calling passphrase callback
When loading encrypted keys, AsyncSSH allows a passphrase to be specified either directly as a str/bytes valaue or as a callable or coroutine which takes an argument of the name of the file being loaded, allowing multiple keys to be loaded with different passphrases in a single operation. However, until now, the callbacks were always called at the time the keys were loaded. With this change, the callable/coroutine for a key is only called when the key is actually used to perform a signing operation. This can be useful if the callback prompts the end user for the passphrase and you want to minimize the number of such requests. This feature only applies when the public key associated with an encrypted private key is available. This is alwyays the case when using OpenSSH fotmat, but for older PEM/DER formats, a separate public key might need to be provided.
1 parent 4a83745 commit 76c14dd

File tree

4 files changed

+157
-107
lines changed

4 files changed

+157
-107
lines changed

asyncssh/public_key.py

Lines changed: 147 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others.
1+
# Copyright (c) 2013-2025 by Ron Frederick <[email protected]> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -97,6 +97,8 @@
9797
_KeyPairArg = Union['SSHKeyPair', _KeyArg, Tuple[_KeyArg, _CertArg]]
9898
KeyPairListArg = Union[_KeyPairArg, Sequence[_KeyPairArg]]
9999

100+
_PassphraseCallable = Callable[[str], BytesOrStr]
101+
_PassphraseArg = Optional[Union[_PassphraseCallable, BytesOrStr]]
100102

101103
# Default file names in .ssh directory to read private keys from
102104
_DEFAULT_KEY_FILES = (
@@ -192,6 +194,51 @@ def _wrap_base64(data: bytes, wrap: int = 64) -> bytes:
192194
for i in range(0, len(data), wrap)) + b'\n'
193195

194196

197+
def _resolve_passphrase(
198+
passphrase: _PassphraseArg, filename: str,
199+
loop: Optional[asyncio.AbstractEventLoop]) -> Optional[BytesOrStr]:
200+
"""Resolve a passphrase used to encrypt/decrypt SSH private keys"""
201+
202+
resolved_passphrase: Optional[BytesOrStr]
203+
204+
if callable(passphrase):
205+
resolved_passphrase = passphrase(filename)
206+
else:
207+
resolved_passphrase = passphrase
208+
209+
if loop and inspect.isawaitable(resolved_passphrase):
210+
resolved_passphrase = asyncio.run_coroutine_threadsafe(
211+
resolved_passphrase, loop).result()
212+
213+
return resolved_passphrase
214+
215+
216+
class _EncryptedKey:
217+
"""Encrypted SSH private key, decrypted just prior to use"""
218+
219+
def __init__(self, key_data: bytes, filename: str,
220+
passphrase: _PassphraseArg,
221+
loop: Optional[asyncio.AbstractEventLoop],
222+
unsafe_skip_rsa_key_validation: bool):
223+
self._key_data = key_data
224+
self._filename = filename
225+
self._passphrase = passphrase
226+
self._loop = loop
227+
self._unsafe_skip_rsa_key_validation = unsafe_skip_rsa_key_validation
228+
229+
def decrypt(self) -> 'SSHKey':
230+
"""Decrypt this encrypted key data and return an SSH private key"""
231+
232+
resolved_passphrase = _resolve_passphrase(self._passphrase,
233+
self._filename, self._loop)
234+
235+
key = import_private_key(self._key_data, resolved_passphrase,
236+
self._unsafe_skip_rsa_key_validation)
237+
key.set_filename(self._filename)
238+
239+
return key
240+
241+
195242
class KeyGenerationError(ValueError):
196243
"""Key generation error
197244
@@ -2238,8 +2285,9 @@ class SSHLocalKeyPair(SSHKeyPair):
22382285

22392286
_key_type = 'local'
22402287

2241-
def __init__(self, key: SSHKey, pubkey: Optional[SSHKey] = None,
2242-
cert: Optional[SSHCertificate] = None):
2288+
def __init__(self, key: SSHKey, pubkey: Optional[SSHKey],
2289+
cert: Optional[SSHCertificate],
2290+
enc_key: Optional[_EncryptedKey]):
22432291
if pubkey and pubkey.public_data != key.public_data:
22442292
raise ValueError('Public key mismatch')
22452293

@@ -2254,10 +2302,11 @@ def __init__(self, key: SSHKey, pubkey: Optional[SSHKey] = None,
22542302

22552303
super().__init__(key.algorithm, key.algorithm, key.sig_algorithms,
22562304
key.sig_algorithms, key.public_data, comment,
2257-
cert, key.get_filename(), key.use_executor,
2258-
key.use_webauthn)
2305+
cert, key.get_filename(), key.use_executor or
2306+
bool(enc_key), key.use_webauthn)
22592307

22602308
self._key = key
2309+
self._enc_key = enc_key
22612310

22622311
def get_agent_private_key(self) -> bytes:
22632312
"""Return binary encoding of keypair for upload to SSH agent"""
@@ -2273,6 +2322,12 @@ def get_agent_private_key(self) -> bytes:
22732322
def sign(self, data: bytes) -> bytes:
22742323
"""Sign a block of data with this private key"""
22752324

2325+
if self._enc_key:
2326+
self._key = self._enc_key.decrypt()
2327+
self._enc_key = None
2328+
2329+
self.use_executor = self._key.use_executor
2330+
22762331
return self._key.sign(data, self.sig_algorithm)
22772332

22782333

@@ -2368,7 +2423,7 @@ def _match_block(data: bytes, start: int, header: bytes,
23682423
"""Match a block of data wrapped in a header/footer"""
23692424

23702425
match = re.compile(b'^' + header[:5] + b'END' + header[10:] +
2371-
rb'[ \t\r\f\v]*$', re.M).search(data, start)
2426+
rb'[ \t\n\r\f\v]*$', re.M).search(data, start)
23722427

23732428
if not match:
23742429
raise KeyImportError(f'Missing {fmt} footer')
@@ -3203,21 +3258,6 @@ def import_private_key(
32033258
raise KeyImportError('Invalid private key')
32043259

32053260

3206-
def import_private_key_and_certs(
3207-
data: bytes, passphrase: Optional[BytesOrStr] = None,
3208-
unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \
3209-
Tuple[SSHKey, Optional[SSHX509CertificateChain]]:
3210-
"""Import a private key and optional certificate chain"""
3211-
3212-
key, end = _decode_private(data, passphrase,
3213-
unsafe_skip_rsa_key_validation)
3214-
3215-
if key:
3216-
return key, import_certificate_chain(data[end:])
3217-
else:
3218-
raise KeyImportError('Invalid private key')
3219-
3220-
32213261
def import_public_key(data: BytesOrStr) -> SSHKey:
32223262
"""Import a public key
32233263
@@ -3339,20 +3379,6 @@ def read_private_key(
33393379
return key
33403380

33413381

3342-
def read_private_key_and_certs(
3343-
filename: FilePath, passphrase: Optional[BytesOrStr] = None,
3344-
unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \
3345-
Tuple[SSHKey, Optional[SSHX509CertificateChain]]:
3346-
"""Read a private key and optional certificate chain from a file"""
3347-
3348-
key, cert = import_private_key_and_certs(read_file(filename), passphrase,
3349-
unsafe_skip_rsa_key_validation)
3350-
3351-
key.set_filename(filename)
3352-
3353-
return key, cert
3354-
3355-
33563382
def read_public_key(filename: FilePath) -> SSHKey:
33573383
"""Read a public key from a file
33583384
@@ -3512,31 +3538,37 @@ def load_keypairs(
35123538
"""
35133539

35143540
keys_to_load: Sequence[_KeyPairArg]
3541+
key_data: Optional[bytes]
3542+
key: Union['SSHKey', 'SSHKeyPair']
35153543
result: List[SSHKeyPair] = []
35163544

35173545
certlist = load_certificates(certlist)
35183546
certdict = {cert.key.public_data: cert for cert in certlist}
35193547

35203548
if isinstance(keylist, (PurePath, str)):
3521-
try:
3522-
if callable(passphrase):
3523-
resolved_passphrase = passphrase(str(keylist))
3524-
else:
3525-
resolved_passphrase = passphrase
3549+
data = read_file(keylist)
3550+
key_data_list: List[bytes] = []
35263551

3527-
if loop and inspect.isawaitable(resolved_passphrase):
3528-
resolved_passphrase = asyncio.run_coroutine_threadsafe(
3529-
resolved_passphrase, loop).result()
3552+
while data:
3553+
fmt, _, end = _match_next(data, b'PRIVATE KEY')
3554+
if fmt:
3555+
key_data_list.append(data[:end])
35303556

3531-
priv_keys = read_private_key_list(keylist, resolved_passphrase,
3532-
unsafe_skip_rsa_key_validation)
3557+
data = data[end:]
35333558

3534-
if len(priv_keys) <= 1:
3535-
keys_to_load = [keylist]
3536-
passphrase = resolved_passphrase
3537-
else:
3538-
keys_to_load = priv_keys
3539-
except KeyImportError:
3559+
if len(key_data_list) > 1:
3560+
resolved_passphrase = _resolve_passphrase(passphrase,
3561+
str(keylist), loop)
3562+
3563+
keys_to_load = []
3564+
3565+
for key_data in key_data_list:
3566+
key = import_private_key(key_data, resolved_passphrase,
3567+
unsafe_skip_rsa_key_validation)
3568+
key.set_filename(keylist)
3569+
3570+
keys_to_load.append(key)
3571+
else:
35403572
keys_to_load = [keylist]
35413573
elif isinstance(keylist, (tuple, bytes, SSHKey, SSHKeyPair)):
35423574
keys_to_load = [cast(_KeyPairArg, keylist)]
@@ -3545,61 +3577,37 @@ def load_keypairs(
35453577

35463578
for key_to_load in keys_to_load:
35473579
allow_certs = False
3548-
key_prefix = None
3549-
saved_exc = None
3580+
key_data = None
3581+
key_prefix = ''
35503582
pubkey_or_certs = None
3551-
pubkey_to_load: Optional[_KeyArg] = None
35523583
certs_to_load: Optional[_CertArg] = None
3553-
key: Union['SSHKey', 'SSHKeyPair']
3584+
pubkey_to_load: Optional[_KeyArg] = None
3585+
saved_exc = None
3586+
enc_key: Optional[_EncryptedKey] = None
35543587

35553588
if isinstance(key_to_load, (PurePath, str, bytes)):
35563589
allow_certs = True
35573590
elif isinstance(key_to_load, tuple):
35583591
key_to_load, pubkey_or_certs = key_to_load
35593592

3560-
try:
3561-
if isinstance(key_to_load, (PurePath, str)):
3562-
key_prefix = str(key_to_load)
3593+
if isinstance(key_to_load, (PurePath, str)):
3594+
key_prefix = str(key_to_load)
3595+
key_data = read_file(key_to_load)
3596+
elif isinstance(key_to_load, bytes):
3597+
key_data = key_to_load
35633598

3564-
if callable(passphrase):
3565-
resolved_passphrase = passphrase(key_prefix)
3566-
else:
3567-
resolved_passphrase = passphrase
3599+
certs: Optional[Sequence[SSHCertificate]]
35683600

3569-
if loop and inspect.isawaitable(resolved_passphrase):
3570-
resolved_passphrase = asyncio.run_coroutine_threadsafe(
3571-
resolved_passphrase, loop).result()
3601+
if allow_certs:
3602+
assert key_data is not None
35723603

3573-
if allow_certs:
3574-
key, certs_to_load = read_private_key_and_certs(
3575-
key_to_load, resolved_passphrase,
3576-
unsafe_skip_rsa_key_validation)
3604+
_, _, end = _match_next(key_data, b'PRIVATE KEY')
35773605

3578-
if not certs_to_load:
3579-
certs_to_load = key_prefix + '-cert.pub'
3580-
else:
3581-
key = read_private_key(key_to_load, resolved_passphrase,
3582-
unsafe_skip_rsa_key_validation)
3583-
3584-
pubkey_to_load = key_prefix + '.pub'
3585-
elif isinstance(key_to_load, bytes):
3586-
if allow_certs:
3587-
key, certs_to_load = import_private_key_and_certs(
3588-
key_to_load, passphrase,
3589-
unsafe_skip_rsa_key_validation)
3590-
else:
3591-
key = import_private_key(key_to_load, passphrase,
3592-
unsafe_skip_rsa_key_validation)
3593-
else:
3594-
key = key_to_load
3595-
except KeyImportError as exc:
3596-
if skip_public or \
3597-
(ignore_encrypted and str(exc).startswith('Passphrase')):
3598-
continue
3599-
3600-
raise
3606+
certs_to_load = import_certificate_chain(key_data[end:])
3607+
key_data = key_data[:end]
36013608

3602-
certs: Optional[Sequence[SSHCertificate]]
3609+
if not certs_to_load:
3610+
certs_to_load = key_prefix + '-cert.pub'
36033611

36043612
if pubkey_or_certs:
36053613
try:
@@ -3613,7 +3621,7 @@ def load_keypairs(
36133621
elif certs_to_load:
36143622
try:
36153623
certs = load_certificates(certs_to_load)
3616-
except (OSError, KeyImportError):
3624+
except (OSError, KeyImportError) as exc:
36173625
certs = None
36183626
else:
36193627
certs = None
@@ -3628,16 +3636,58 @@ def load_keypairs(
36283636
pubkey = import_public_key(pubkey_to_load)
36293637
else:
36303638
pubkey = pubkey_to_load
3639+
3640+
saved_exc = None
36313641
except (OSError, KeyImportError):
36323642
pubkey = None
3633-
else:
3643+
elif key_prefix:
3644+
try:
3645+
pubkey = read_public_key(key_prefix + '.pub')
36343646
saved_exc = None
3647+
except (OSError, KeyImportError):
3648+
try:
3649+
pubkey = read_public_key(key_prefix)
3650+
saved_exc = None
3651+
except (OSError, KeyImportError):
3652+
pubkey = None
36353653
else:
36363654
pubkey = None
36373655

36383656
if saved_exc:
36393657
raise saved_exc # pylint: disable=raising-bad-type
36403658

3659+
if key_data is not None:
3660+
try:
3661+
unencrypted_key = import_private_key(
3662+
key_data, None, unsafe_skip_rsa_key_validation)
3663+
unencrypted_key.set_filename(key_prefix)
3664+
except KeyImportError:
3665+
unencrypted_key = None
3666+
3667+
if unencrypted_key:
3668+
key = unencrypted_key
3669+
elif callable(passphrase) and key_prefix and (certs or pubkey):
3670+
enc_key = _EncryptedKey(key_data, key_prefix, passphrase, loop,
3671+
unsafe_skip_rsa_key_validation)
3672+
3673+
key = certs[0].key if certs else pubkey
3674+
else:
3675+
try:
3676+
resolved_passphrase = _resolve_passphrase(passphrase,
3677+
key_prefix, loop)
3678+
3679+
key = import_private_key(key_data, passphrase,
3680+
unsafe_skip_rsa_key_validation)
3681+
key.set_filename(key_prefix)
3682+
except KeyImportError as exc:
3683+
if skip_public or (ignore_encrypted and
3684+
str(exc).startswith('Passphrase')):
3685+
continue
3686+
3687+
raise
3688+
else:
3689+
key = cast(Union[SSHKey, SSHKeyPair], key_to_load)
3690+
36413691
if not certs:
36423692
if isinstance(key, SSHKeyPair):
36433693
pubdata = key.key_public_data
@@ -3660,9 +3710,9 @@ def load_keypairs(
36603710
result.append(key)
36613711
else:
36623712
if cert:
3663-
result.append(SSHLocalKeyPair(key, pubkey, cert))
3713+
result.append(SSHLocalKeyPair(key, pubkey, cert, enc_key))
36643714

3665-
result.append(SSHLocalKeyPair(key, pubkey))
3715+
result.append(SSHLocalKeyPair(key, pubkey, None, enc_key))
36663716

36673717
return result
36683718

tests/test_agent.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2016-2024 by Ron Frederick <[email protected]> and others.
1+
# Copyright (c) 2016-2025 by Ron Frederick <[email protected]> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -321,9 +321,8 @@ async def test_add_sk_keys(self):
321321
async with agent:
322322
self.assertIsNone(await agent.add_keys([keypair]))
323323

324-
async with agent:
325-
with self.assertRaises(asyncssh.KeyExportError):
326-
await agent.add_keys([key.convert_to_public()])
324+
with self.assertRaises(asyncssh.KeyExportError):
325+
await agent.add_keys([key.convert_to_public()])
327326

328327
await mock_agent.stop()
329328

0 commit comments

Comments
 (0)