247 lines
8.0 KiB
Python
247 lines
8.0 KiB
Python
import binascii
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
from .._crypto import AEAD, CryptoError, HeaderProtection
|
|
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
|
|
from .packet import (
|
|
QuicProtocolVersion,
|
|
decode_packet_number,
|
|
is_long_header,
|
|
)
|
|
|
|
# Maps each supported TLS cipher suite to a pair of cipher names, unpacked by
# CryptoContext.setup() as (header protection cipher, AEAD cipher).
CIPHER_SUITES = {
    CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
    CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
    CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
}

# Cipher suite that CryptoPair.setup_initial() uses for the initial keys.
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
|
|
# Salts passed to hkdf_extract() by CryptoPair.setup_initial(); the second one
# is selected when the version is QuicProtocolVersion.VERSION_2.
INITIAL_SALT_VERSION_1 = binascii.unhexlify(
    "38762cf7f55934b34d179ae6a4c80cadccbb7f0a"
)
INITIAL_SALT_VERSION_2 = binascii.unhexlify(
    "0dede3def700a6db819381be6e269dcbf9bd2ed9"
)

# NOTE(review): not referenced in this module; presumably the number of bytes
# sampled for header protection -- confirm against the callers.
SAMPLE_SIZE = 16

# Signature of the setup / teardown hooks: called with a trigger name,
# returning nothing.
Callback = Callable[[str], None]
|
|
|
|
|
|
def NoCallback(trigger: str) -> None:
    """Do nothing; the default value for the setup and teardown callbacks."""
|
|
|
|
|
|
class KeyUnavailableError(CryptoError):
    """Raised by CryptoContext.decrypt_packet() when no decryption key is set."""
|
|
|
|
|
|
def derive_key_iv_hp(
    *, cipher_suite: CipherSuite, secret: bytes, version: int
) -> Tuple[bytes, bytes, bytes]:
    """
    Expand a secret into the material used to protect packets.

    :param cipher_suite: Cipher suite, which selects the hash and key length.
    :param secret: Secret to expand.
    :param version: QUIC version, which selects the HKDF label family.
    :return: A ``(key, iv, hp)`` tuple. The key and the header protection key
        are 32 bytes long for AES-256-GCM and ChaCha20-Poly1305 and 16 bytes
        long otherwise; the IV is always 12 bytes long.
    """
    hash_algorithm = cipher_suite_hash(cipher_suite)

    needs_long_key = cipher_suite in (
        CipherSuite.AES_256_GCM_SHA384,
        CipherSuite.CHACHA20_POLY1305_SHA256,
    )
    key_length = 32 if needs_long_key else 16

    # QUIC version 2 uses its own set of labels.
    if version == QuicProtocolVersion.VERSION_2:
        key_label, iv_label, hp_label = b"quicv2 key", b"quicv2 iv", b"quicv2 hp"
    else:
        key_label, iv_label, hp_label = b"quic key", b"quic iv", b"quic hp"

    return (
        hkdf_expand_label(hash_algorithm, secret, key_label, b"", key_length),
        hkdf_expand_label(hash_algorithm, secret, iv_label, b"", 12),
        hkdf_expand_label(hash_algorithm, secret, hp_label, b"", key_length),
    )
|
|
|
|
|
|
class CryptoContext:
    """
    Packet protection state: an AEAD, a header protection cipher and the
    secret, cipher suite, version and key phase they were set up with.

    The context is unusable until :meth:`setup` has been called.
    """

    def __init__(
        self,
        key_phase: int = 0,
        setup_cb: Callback = NoCallback,
        teardown_cb: Callback = NoCallback,
    ) -> None:
        """
        :param key_phase: Initial value of the key phase bit.
        :param setup_cb: Called with a trigger name once keys are installed.
        :param teardown_cb: Called with a trigger name once keys are discarded.
        """
        self.key_phase = key_phase
        self._setup_cb = setup_cb
        self._teardown_cb = teardown_cb

        # Filled in by setup(); teardown() resets all of these except `version`.
        self.aead: Optional[AEAD] = None
        self.cipher_suite: Optional[CipherSuite] = None
        self.hp: Optional[HeaderProtection] = None
        self.secret: Optional[bytes] = None
        self.version: Optional[int] = None

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int, bool]:
        """
        Remove header and payload protection from a packet.

        :param packet: The protected packet.
        :param encrypted_offset: Offset handed to the header protection
            removal, together with the packet.
        :param expected_packet_number: Reference value used to expand the
            truncated packet number.
        :return: ``(plain_header, payload, packet_number, key_phase_changed)``
            where the last item is True when the payload was decrypted with
            the keys of the next key phase rather than this context's keys.
        :raises KeyUnavailableError: If no AEAD has been set up.
        """
        if self.aead is None:
            raise KeyUnavailableError("Decryption key is not available")

        # Undo header protection first: it hides the first byte's low bits.
        plain_header, truncated_pn = self.hp.remove(packet, encrypted_offset)
        first_byte = plain_header[0]

        # The two low bits of the first byte hold the packet number length
        # in bytes, minus one.
        pn_bits = ((first_byte & 0x03) + 1) * 8
        packet_number = decode_packet_number(
            truncated_pn, pn_bits, expected_packet_number
        )

        # For short header packets, bit 0x04 is the key phase; when it
        # differs from ours, decrypt with the next generation of keys.
        context = self
        if not is_long_header(first_byte):
            packet_key_phase = (first_byte & 4) >> 2
            if packet_key_phase != self.key_phase:
                context = next_key_phase(self)

        protected_payload = packet[len(plain_header) :]
        payload = context.aead.decrypt(protected_payload, plain_header, packet_number)

        return plain_header, payload, packet_number, context is not self

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Apply payload protection, then header protection, to a packet.

        :param plain_header: The unprotected header, also passed to the AEAD
            alongside the payload.
        :param plain_payload: The payload to encrypt.
        :param packet_number: The packet number passed to the AEAD.
        :return: The result of applying header protection to the header and
            the encrypted payload.
        """
        assert self.is_valid(), "Encryption key is not available"

        protected_payload = self.aead.encrypt(plain_payload, plain_header, packet_number)
        return self.hp.apply(plain_header, protected_payload)

    def is_valid(self) -> bool:
        """Return True if an AEAD is currently installed."""
        return self.aead is not None

    def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
        """
        Derive keys from `secret` and install them, then invoke the setup
        callback with the trigger ``"tls"``.

        :param cipher_suite: Cipher suite; must be a key of `CIPHER_SUITES`.
        :param secret: Secret from which the key, IV and header protection
            key are derived.
        :param version: QUIC version, forwarded to the key derivation.
        """
        hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
        key, iv, hp_key = derive_key_iv_hp(
            cipher_suite=cipher_suite, secret=secret, version=version
        )

        self.aead = AEAD(aead_cipher_name, key, iv)
        self.cipher_suite = cipher_suite
        self.hp = HeaderProtection(hp_cipher_name, hp_key)
        self.secret = secret
        self.version = version

        self._setup_cb("tls")

    def teardown(self) -> None:
        """
        Discard the keys and the secret, then invoke the teardown callback
        with the trigger ``"tls"``.
        """
        self.aead = self.cipher_suite = self.hp = self.secret = None

        self._teardown_cb("tls")
|
|
|
|
|
|
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
    """
    Copy the AEAD, key phase and secret of `crypto` onto `self`, then invoke
    the setup callback of `self` with `trigger`.

    The header protection cipher, cipher suite and version of `self` are left
    untouched.
    """
    self.aead, self.key_phase, self.secret = (
        crypto.aead,
        crypto.key_phase,
        crypto.secret,
    )

    self._setup_cb(trigger)
|
|
|
|
|
|
def next_key_phase(self: CryptoContext) -> CryptoContext:
    """
    Build the context holding the keys for the key phase following `self`.

    The new secret is expanded from the current one using the key update
    label, and the key phase bit is flipped. `self` is not modified, and the
    returned context uses the default no-op callbacks.

    :param self: A context which has been set up (its cipher suite, secret
        and version are read).
    :return: A new, fully set up context for the next key phase.
    """
    algorithm = cipher_suite_hash(self.cipher_suite)

    # QUIC version 2 renames every HKDF label, including the key update one
    # (RFC 9369, section 3.3.2). Using the version 1 label there would derive
    # keys which do not match those of a compliant peer.
    if self.version == QuicProtocolVersion.VERSION_2:
        key_update_label = b"quicv2 ku"
    else:
        key_update_label = b"quic ku"

    crypto = CryptoContext(key_phase=int(not self.key_phase))
    crypto.setup(
        cipher_suite=self.cipher_suite,
        secret=hkdf_expand_label(
            algorithm, self.secret, key_update_label, b"", algorithm.digest_size
        ),
        version=self.version,
    )
    return crypto
|
|
|
|
|
|
class CryptoPair:
    """
    A receive context and a send context, whose key updates are kept in step.
    """

    def __init__(
        self,
        recv_setup_cb: Callback = NoCallback,
        recv_teardown_cb: Callback = NoCallback,
        send_setup_cb: Callback = NoCallback,
        send_teardown_cb: Callback = NoCallback,
    ) -> None:
        """
        :param recv_setup_cb: Setup callback for the receive context.
        :param recv_teardown_cb: Teardown callback for the receive context.
        :param send_setup_cb: Setup callback for the send context.
        :param send_teardown_cb: Teardown callback for the send context.
        """
        self.aead_tag_size = 16
        self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
        self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
        # Set by update_key(), consumed by the next encrypt_packet().
        self._update_key_requested = False

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int]:
        """
        Decrypt a packet using the receive context.

        If the receive context reports that the packet used the next key
        phase, both contexts are moved to that phase with the trigger
        ``"remote_update"``.

        :return: ``(plain_header, payload, packet_number)``.
        """
        decrypted = self.recv.decrypt_packet(
            packet, encrypted_offset, expected_packet_number
        )
        plain_header, payload, packet_number, key_phase_changed = decrypted
        if key_phase_changed:
            self._update_key("remote_update")
        return plain_header, payload, packet_number

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt a packet using the send context.

        A key update requested through :meth:`update_key` is carried out
        first, with the trigger ``"local_update"``.
        """
        if self._update_key_requested:
            self._update_key("local_update")
        return self.send.encrypt_packet(plain_header, plain_payload, packet_number)

    def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
        """
        Set up both contexts with initial keys derived from a connection ID.

        :param cid: Connection ID from which the initial secret is extracted.
        :param is_client: True to send with the client label and receive with
            the server label, False for the opposite.
        :param version: QUIC version, which selects the initial salt.
        """
        if is_client:
            recv_label, send_label = b"server in", b"client in"
        else:
            recv_label, send_label = b"client in", b"server in"

        initial_salt = (
            INITIAL_SALT_VERSION_2
            if version == QuicProtocolVersion.VERSION_2
            else INITIAL_SALT_VERSION_1
        )

        algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
        initial_secret = hkdf_extract(algorithm, initial_salt, cid)

        # Receive context first, then send context.
        for context, label in ((self.recv, recv_label), (self.send, send_label)):
            context.setup(
                cipher_suite=INITIAL_CIPHER_SUITE,
                secret=hkdf_expand_label(
                    algorithm, initial_secret, label, b"", algorithm.digest_size
                ),
                version=version,
            )

    def teardown(self) -> None:
        """Tear down the receive context, then the send context."""
        self.recv.teardown()
        self.send.teardown()

    def update_key(self) -> None:
        """Request a key update, to be applied when the next packet is encrypted."""
        self._update_key_requested = True

    @property
    def key_phase(self) -> int:
        """
        The receive context's key phase, flipped if a key update has been
        requested but not yet applied.
        """
        current = self.recv.key_phase
        return int(not current) if self._update_key_requested else current

    def _update_key(self, trigger: str) -> None:
        """
        Move both contexts to their next key phase, passing `trigger` to
        their setup callbacks, and clear any pending update request.
        """
        for context in (self.recv, self.send):
            apply_key_phase(context, next_key_phase(context), trigger=trigger)
        self._update_key_requested = False
|