This commit is contained in:
mofixx
2025-08-08 10:41:30 +02:00
parent 4444be3799
commit a5df3861fd
1674 changed files with 234266 additions and 0 deletions

View File

@ -0,0 +1,163 @@
from dataclasses import dataclass, field
from os import PathLike
from re import split
from typing import Any, List, Optional, TextIO, Union
from ..tls import (
CipherSuite,
SessionTicket,
load_pem_private_key,
load_pem_x509_certificates,
)
from .logger import QuicLogger
from .packet import QuicProtocolVersion
SMALLEST_MAX_DATAGRAM_SIZE = 1200
@dataclass
class QuicConfiguration:
    """
    A QUIC configuration.
    """

    alpn_protocols: Optional[List[str]] = None
    """
    A list of supported ALPN protocols.
    """

    congestion_control_algorithm: str = "reno"
    """
    The name of the congestion control algorithm to use.

    Currently supported algorithms: `"reno"`, `"cubic"`.
    """

    connection_id_length: int = 8
    """
    The length in bytes of local connection IDs.
    """

    idle_timeout: float = 60.0
    """
    The idle timeout in seconds.

    The connection is terminated if nothing is received for the given duration.
    """

    is_client: bool = True
    """
    Whether this is the client side of the QUIC connection.
    """

    max_data: int = 1048576
    """
    Connection-wide flow control limit.
    """

    max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE
    """
    The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
    """

    max_stream_data: int = 1048576
    """
    Per-stream flow control limit.
    """

    quic_logger: Optional[QuicLogger] = None
    """
    The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
    """

    secrets_log_file: Optional[TextIO] = None
    """
    A file-like object in which to log traffic secrets.

    This is useful to analyze traffic captures with Wireshark.
    """

    server_name: Optional[str] = None
    """
    The server name to use when verifying the server's TLS certificate, which
    can either be a DNS name or an IP address.

    If it is a DNS name, it is also sent during the TLS handshake in the
    Server Name Indication (SNI) extension.

    .. note:: This is only used by clients.
    """

    session_ticket: Optional[SessionTicket] = None
    """
    The TLS session ticket which should be used for session resumption.
    """

    token: bytes = b""
    """
    The address validation token that can be used to validate future connections.

    .. note:: This is only used by clients.
    """

    # For internal purposes, not guaranteed to be stable.
    cadata: Optional[bytes] = None
    cafile: Optional[str] = None
    capath: Optional[str] = None
    certificate: Any = None
    certificate_chain: List[Any] = field(default_factory=list)
    cipher_suites: Optional[List[CipherSuite]] = None
    initial_rtt: float = 0.1
    max_datagram_frame_size: Optional[int] = None
    original_version: Optional[int] = None
    private_key: Any = None
    quantum_readiness_test: bool = False
    supported_versions: List[int] = field(
        default_factory=lambda: [
            QuicProtocolVersion.VERSION_1,
            QuicProtocolVersion.VERSION_2,
        ]
    )
    verify_mode: Optional[int] = None

    def load_cert_chain(
        self,
        certfile: PathLike,
        keyfile: Optional[PathLike] = None,
        password: Optional[Union[bytes, str]] = None,
    ) -> None:
        """
        Load a private key and the corresponding certificate.

        The certificate file may also contain the PEM-encoded private key;
        alternatively the key may be given in a separate `keyfile`,
        optionally protected by `password`.
        """
        with open(certfile, "rb") as fp:
            boundary = b"-----BEGIN PRIVATE KEY-----\n"
            chunks = split(b"\n" + boundary, fp.read())
            certificates = load_pem_x509_certificates(chunks[0])
            if len(chunks) == 2:
                # the certificate file also contained the private key
                private_key = boundary + chunks[1]
                self.private_key = load_pem_private_key(private_key)
        self.certificate = certificates[0]
        self.certificate_chain = certificates[1:]

        if keyfile is not None:
            with open(keyfile, "rb") as fp:
                self.private_key = load_pem_private_key(
                    fp.read(),
                    password=password.encode("utf8")
                    if isinstance(password, str)
                    else password,
                )

    def load_verify_locations(
        self,
        cafile: Optional[str] = None,
        capath: Optional[str] = None,
        cadata: Optional[bytes] = None,
    ) -> None:
        """
        Load a set of "certification authority" (CA) certificates used to
        validate other peers' certificates.
        """
        self.cafile = cafile
        self.capath = capath
        self.cadata = cadata

View File

@ -0,0 +1,128 @@
import abc
from typing import Any, Dict, Iterable, Optional, Protocol
from ..packet_builder import QuicSentPacket
K_GRANULARITY = 0.001 # seconds
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
class QuicCongestionControl(abc.ABC):
    """
    Base class for congestion control implementations.

    Concrete implementations are registered by name with
    :func:`register_congestion_control` and instantiated through
    :func:`create_congestion_control`.
    """

    # number of sent bytes not yet acknowledged, lost or expired
    bytes_in_flight: int = 0
    # current congestion window, in bytes
    congestion_window: int = 0
    # slow start threshold in bytes, or None while it has not been set
    ssthresh: Optional[int] = None

    def __init__(self, *, max_datagram_size: int) -> None:
        # start with an initial window of K_INITIAL_WINDOW datagrams
        self.congestion_window = K_INITIAL_WINDOW * max_datagram_size

    @abc.abstractmethod
    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        """Called when `packet` is acknowledged."""
        ...

    @abc.abstractmethod
    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        """Called when `packet` is sent."""
        ...

    @abc.abstractmethod
    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        """Called when `packets` are discarded without being acknowledged."""
        ...

    @abc.abstractmethod
    def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
        """Called when `packets` are declared lost."""
        ...

    @abc.abstractmethod
    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        """Called with each new roundtrip time measurement."""
        ...

    def get_log_data(self) -> Dict[str, Any]:
        """
        Return the congestion window, bytes in flight and - when set -
        the slow start threshold, for logging purposes.
        """
        data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
        if self.ssthresh is not None:
            data["ssthresh"] = self.ssthresh
        return data
class QuicCongestionControlFactory(Protocol):
    """
    Structural type for callables which build a
    :class:`QuicCongestionControl` for a given maximum datagram size.
    """

    def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
class QuicRttMonitor:
    """
    Roundtrip time monitor for HyStart.

    Keeps a small ring buffer of RTT samples and reports when the RTT has
    been rising for a sustained number of sampling intervals.
    """

    def __init__(self) -> None:
        self._increases = 0
        # kept for compatibility; not read within this class
        self._last_time = None
        self._ready = False
        self._size = 5

        self._filtered_min: Optional[float] = None

        self._sample_idx = 0
        self._sample_max: Optional[float] = None
        self._sample_min: Optional[float] = None
        self._sample_time = 0.0
        self._samples = [0.0] * self._size

    def add_rtt(self, *, rtt: float) -> None:
        """Store one RTT sample and refresh the windowed min/max."""
        self._samples[self._sample_idx] = rtt
        self._sample_idx = (self._sample_idx + 1) % self._size
        if self._sample_idx == 0:
            # the ring buffer has been filled at least once
            self._ready = True

        if self._ready:
            self._sample_min = min(self._samples)
            self._sample_max = max(self._samples)

    def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
        """Return True when the RTT keeps rising across the sample window."""
        # rate-limit sampling to one observation per K_GRANULARITY seconds
        if now <= self._sample_time + K_GRANULARITY:
            return False

        self.add_rtt(rtt=rtt)
        self._sample_time = now

        if not self._ready:
            return False

        # track the lowest windowed maximum seen so far as the baseline
        if self._filtered_min is None or self._filtered_min > self._sample_max:
            self._filtered_min = self._sample_max

        rise = self._sample_min - self._filtered_min
        if rise * 4 >= self._filtered_min:
            self._increases += 1
            if self._increases >= self._size:
                return True
        elif rise > 0:
            self._increases = 0
        return False
# registry of congestion control implementations, keyed by algorithm name
_factories: Dict[str, "QuicCongestionControlFactory"] = {}


def create_congestion_control(
    name: str, *, max_datagram_size: int
) -> "QuicCongestionControl":
    """
    Create an instance of the `name` congestion control algorithm.

    :raises ValueError: if no algorithm was registered under `name`.
    """
    try:
        factory = _factories[name]
    except KeyError as exc:
        # a specific, chained exception instead of a bare Exception; callers
        # catching Exception still work
        raise ValueError(f"Unknown congestion control algorithm: {name}") from exc
    return factory(max_datagram_size=max_datagram_size)


def register_congestion_control(
    name: str, factory: "QuicCongestionControlFactory"
) -> None:
    """
    Register a congestion control algorithm named `name`.
    """
    _factories[name] = factory

View File

@ -0,0 +1,212 @@
from typing import Any, Dict, Iterable
from ..packet_builder import QuicSentPacket
from .base import (
K_INITIAL_WINDOW,
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
def better_cube_root(x: float) -> float:
    """
    Return the real cube root of `x`.

    Raising a negative float to the power 1/3 would produce a complex
    result, so the root is taken on the magnitude and the sign re-applied.
    """
    magnitude = abs(x) ** (1.0 / 3.0)
    return -magnitude if x < 0 else magnitude
class CubicCongestionControl(QuicCongestionControl):
    """
    Cubic congestion control implementation for aioquic

    Terminology (W_max, W_est, K, epoch) follows
    https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions
    """

    def __init__(self, max_datagram_size: int) -> None:
        super().__init__(max_datagram_size=max_datagram_size)
        # increase by one segment
        self.additive_increase_factor: int = max_datagram_size
        self._max_datagram_size: int = max_datagram_size
        self._congestion_recovery_start_time = 0.0
        self._rtt_monitor = QuicRttMonitor()
        self.rtt = 0.02  # starting RTT is considered to be 20ms
        self.reset()
        self.last_ack = 0.0

    def W_cubic(self, t: float) -> int:
        """
        Window size in bytes given by the cubic function at time `t`
        (seconds since the start of the current congestion avoidance epoch).
        """
        W_max_segments = self._W_max / self._max_datagram_size
        target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
        return int(target_segments * self._max_datagram_size)

    def is_reno_friendly(self, t: float) -> bool:
        # the cubic curve is currently below the Reno-style estimate
        return self.W_cubic(t) < self._W_est

    def is_concave(self) -> bool:
        # below W_max the cubic curve is in its concave region
        return self.congestion_window < self._W_max

    def reset(self) -> None:
        """Reset the congestion state to its initial values."""
        self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
        self.ssthresh = None
        self._first_slow_start = True
        self._starting_congestion_avoidance = False
        self.K: float = 0.0
        self._W_est = 0
        self._cwnd_epoch = 0
        self._t_epoch = 0.0
        self._W_max = self.congestion_window

    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        """Update the congestion window for a newly acknowledged packet."""
        self.bytes_in_flight -= packet.sent_bytes
        self.last_ack = packet.sent_time

        if self.ssthresh is None or self.congestion_window < self.ssthresh:
            # slow start
            self.congestion_window += packet.sent_bytes
        else:
            # congestion avoidance
            if self._first_slow_start and not self._starting_congestion_avoidance:
                # exiting slow start without having a loss
                self._first_slow_start = False
                self._W_max = self.congestion_window
                self._t_epoch = now
                self._cwnd_epoch = self.congestion_window
                self._W_est = self._cwnd_epoch
                # calculate K
                W_max_segments = self._W_max / self._max_datagram_size
                cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
                self.K = better_cube_root(
                    (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
                )

            # initialize the variables used at start of congestion avoidance
            if self._starting_congestion_avoidance:
                self._starting_congestion_avoidance = False
                self._first_slow_start = False
                self._t_epoch = now
                self._cwnd_epoch = self.congestion_window
                self._W_est = self._cwnd_epoch
                # calculate K
                W_max_segments = self._W_max / self._max_datagram_size
                cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
                self.K = better_cube_root(
                    (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
                )

            # Reno-style estimate: grow by one additive-increase factor per
            # congestion window's worth of acked bytes
            self._W_est = int(
                self._W_est
                + self.additive_increase_factor
                * (packet.sent_bytes / self.congestion_window)
            )

            t = now - self._t_epoch

            # clamp the cubic target into [cwnd, 1.5 * cwnd]
            target: int = 0
            W_cubic = self.W_cubic(t + self.rtt)
            if W_cubic < self.congestion_window:
                target = self.congestion_window
            elif W_cubic > 1.5 * self.congestion_window:
                target = int(self.congestion_window * 1.5)
            else:
                target = W_cubic

            if self.is_reno_friendly(t):
                # reno friendly region of cubic
                # (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
                self.congestion_window = self._W_est
            elif self.is_concave():
                # concave region of cubic
                # (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
                self.congestion_window = int(
                    self.congestion_window
                    + (
                        (target - self.congestion_window)
                        * (self._max_datagram_size / self.congestion_window)
                    )
                )
            else:
                # convex region of cubic
                # (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
                self.congestion_window = int(
                    self.congestion_window
                    + (
                        (target - self.congestion_window)
                        * (self._max_datagram_size / self.congestion_window)
                    )
                )

    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        self.bytes_in_flight += packet.sent_bytes
        if self.last_ack == 0.0:
            return
        elapsed_idle = packet.sent_time - self.last_ack
        # restart from a fresh window after a long idle period
        if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
            self.reset()

    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes

    def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
        lost_largest_time = 0.0
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes
            lost_largest_time = packet.sent_time

        # start a new congestion event if packet was sent after the
        # start of the previous congestion recovery period.
        if lost_largest_time > self._congestion_recovery_start_time:
            self._congestion_recovery_start_time = now

            # Normal congestion handle, can't be used in same time as fast convergence
            # self._W_max = self.congestion_window

            # fast convergence
            if self._W_max is not None and self.congestion_window < self._W_max:
                self._W_max = int(
                    self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
                )
            else:
                self._W_max = self.congestion_window

            # normal congestion MD
            flight_size = self.bytes_in_flight
            new_ssthresh = max(
                int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
                K_MINIMUM_WINDOW * self._max_datagram_size,
            )
            self.ssthresh = new_ssthresh
            self.congestion_window = max(
                self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
            )

            # restart a new congestion avoidance phase
            self._starting_congestion_avoidance = True

    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        self.rtt = rtt

        # check whether we should exit slow start
        if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
            rtt=rtt, now=now
        ):
            self.ssthresh = self.congestion_window

    def get_log_data(self) -> Dict[str, Any]:
        data = super().get_log_data()
        data["cubic-wmax"] = int(self._W_max)
        return data
# expose this implementation under the "cubic" algorithm name
register_congestion_control("cubic", CubicCongestionControl)

View File

@ -0,0 +1,77 @@
from typing import Iterable
from ..packet_builder import QuicSentPacket
from .base import (
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)
K_LOSS_REDUCTION_FACTOR = 0.5
class RenoCongestionControl(QuicCongestionControl):
    """
    New Reno congestion control.
    """

    def __init__(self, *, max_datagram_size: int) -> None:
        super().__init__(max_datagram_size=max_datagram_size)
        self._max_datagram_size = max_datagram_size
        self._congestion_recovery_start_time = 0.0
        # acked bytes accumulated during congestion avoidance which have not
        # yet been converted into window growth
        self._congestion_stash = 0
        self._rtt_monitor = QuicRttMonitor()

    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        """Grow the congestion window for an acknowledged packet."""
        self.bytes_in_flight -= packet.sent_bytes

        # packets sent before the current recovery period do not grow the window
        if packet.sent_time <= self._congestion_recovery_start_time:
            return

        if self.ssthresh is None or self.congestion_window < self.ssthresh:
            # slow start: grow by the full amount of acked bytes
            self.congestion_window += packet.sent_bytes
            return

        # congestion avoidance: add one datagram to the window for every
        # full window's worth of acked bytes
        self._congestion_stash += packet.sent_bytes
        full_windows = self._congestion_stash // self.congestion_window
        if full_windows:
            self._congestion_stash -= full_windows * self.congestion_window
            self.congestion_window += full_windows * self._max_datagram_size

    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        self.bytes_in_flight += packet.sent_bytes

    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        self.bytes_in_flight -= sum(expired.sent_bytes for expired in packets)

    def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
        """Shrink the window when a loss starts a new congestion event."""
        newest_loss_time = 0.0
        for lost in packets:
            self.bytes_in_flight -= lost.sent_bytes
            newest_loss_time = lost.sent_time

        # only a packet sent after the start of the previous congestion
        # recovery period begins a new congestion event
        if newest_loss_time > self._congestion_recovery_start_time:
            self._congestion_recovery_start_time = now
            self.congestion_window = max(
                int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
                K_MINIMUM_WINDOW * self._max_datagram_size,
            )
            self.ssthresh = self.congestion_window

        # TODO : collapse congestion window if persistent congestion

    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        # leave slow start once the RTT monitor reports a sustained increase
        if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
            now=now, rtt=rtt
        ):
            self.ssthresh = self.congestion_window
# expose this implementation under the "reno" algorithm name
register_congestion_control("reno", RenoCongestionControl)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,246 @@
import binascii
from typing import Callable, Optional, Tuple
from .._crypto import AEAD, CryptoError, HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import (
QuicProtocolVersion,
decode_packet_number,
is_long_header,
)
CIPHER_SUITES = {
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
}
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
SAMPLE_SIZE = 16
Callback = Callable[[str], None]
def NoCallback(trigger: str) -> None:
    """Default callback which ignores its trigger argument."""
    return None
class KeyUnavailableError(CryptoError):
    """
    Raised when decryption is attempted before the required key material
    has been set up.
    """

    pass
def derive_key_iv_hp(
    *, cipher_suite: CipherSuite, secret: bytes, version: int
) -> Tuple[bytes, bytes, bytes]:
    """
    Derive the packet protection key, IV and header protection key from
    `secret` for the given cipher suite and QUIC version.
    """
    algorithm = cipher_suite_hash(cipher_suite)

    # AES-256-GCM and ChaCha20-Poly1305 use 32-byte keys, AES-128-GCM 16-byte
    if cipher_suite in (
        CipherSuite.AES_256_GCM_SHA384,
        CipherSuite.CHACHA20_POLY1305_SHA256,
    ):
        key_size = 32
    else:
        key_size = 16

    # QUIC v2 uses distinct HKDF labels from v1
    if version == QuicProtocolVersion.VERSION_2:
        key_label, iv_label, hp_label = b"quicv2 key", b"quicv2 iv", b"quicv2 hp"
    else:
        key_label, iv_label, hp_label = b"quic key", b"quic iv", b"quic hp"

    return (
        hkdf_expand_label(algorithm, secret, key_label, b"", key_size),
        hkdf_expand_label(algorithm, secret, iv_label, b"", 12),
        hkdf_expand_label(algorithm, secret, hp_label, b"", key_size),
    )
class CryptoContext:
    """
    Packet protection state for one direction of a connection.

    Bundles the AEAD used for payload protection, the header protection
    instance and the secret they were derived from.
    """

    def __init__(
        self,
        key_phase: int = 0,
        setup_cb: Callback = NoCallback,
        teardown_cb: Callback = NoCallback,
    ) -> None:
        self.aead: Optional[AEAD] = None
        self.cipher_suite: Optional[CipherSuite] = None
        self.hp: Optional[HeaderProtection] = None
        self.key_phase = key_phase
        self.secret: Optional[bytes] = None
        self.version: Optional[int] = None
        self._setup_cb = setup_cb
        self._teardown_cb = teardown_cb

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int, bool]:
        """
        Remove header protection and decrypt the payload of `packet`.

        Returns the plain header, the plain payload, the decoded packet
        number and whether a key phase change was detected.

        :raises KeyUnavailableError: if no decryption key has been set up.
        """
        if self.aead is None:
            raise KeyUnavailableError("Decryption key is not available")

        # header protection
        plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
        first_byte = plain_header[0]

        # packet number
        pn_length = (first_byte & 0x03) + 1
        packet_number = decode_packet_number(
            packet_number, pn_length * 8, expected_packet_number
        )

        # detect key phase change; only short header packets carry the
        # key phase bit
        crypto = self
        if not is_long_header(first_byte):
            key_phase = (first_byte & 4) >> 2
            if key_phase != self.key_phase:
                crypto = next_key_phase(self)

        # payload protection
        payload = crypto.aead.decrypt(
            packet[len(plain_header) :], plain_header, packet_number
        )

        return plain_header, payload, packet_number, crypto != self

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt the payload, then apply header protection.
        """
        assert self.is_valid(), "Encryption key is not available"

        # payload protection
        protected_payload = self.aead.encrypt(
            plain_payload, plain_header, packet_number
        )

        # header protection
        return self.hp.apply(plain_header, protected_payload)

    def is_valid(self) -> bool:
        """Return True when key material has been set up."""
        return self.aead is not None

    def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
        """
        Derive key, IV and header protection key from `secret` and
        initialize the AEAD and header protection instances.
        """
        hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
        key, iv, hp = derive_key_iv_hp(
            cipher_suite=cipher_suite,
            secret=secret,
            version=version,
        )
        self.aead = AEAD(aead_cipher_name, key, iv)
        self.cipher_suite = cipher_suite
        self.hp = HeaderProtection(hp_cipher_name, hp)
        self.secret = secret
        self.version = version

        # trigger callback
        self._setup_cb("tls")

    def teardown(self) -> None:
        """Discard all key material."""
        self.aead = None
        self.cipher_suite = None
        self.hp = None
        self.secret = None

        # trigger callback
        self._teardown_cb("tls")
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
    """
    Adopt the AEAD, key phase and secret from `crypto` into `self`,
    then invoke the setup callback with `trigger`.
    """
    for attribute in ("aead", "key_phase", "secret"):
        setattr(self, attribute, getattr(crypto, attribute))

    # trigger callback
    self._setup_cb(trigger)
def next_key_phase(self: CryptoContext) -> CryptoContext:
    """
    Build a new :class:`CryptoContext` for the opposite key phase, with its
    secret derived from the current one via the "quic ku" label.
    """
    hash_alg = cipher_suite_hash(self.cipher_suite)
    updated_secret = hkdf_expand_label(
        hash_alg, self.secret, b"quic ku", b"", hash_alg.digest_size
    )

    updated = CryptoContext(key_phase=int(not self.key_phase))
    updated.setup(
        cipher_suite=self.cipher_suite,
        secret=updated_secret,
        version=self.version,
    )
    return updated
class CryptoPair:
    """
    The receive and send :class:`CryptoContext` for one encryption level,
    with key update handling.
    """

    def __init__(
        self,
        recv_setup_cb: Callback = NoCallback,
        recv_teardown_cb: Callback = NoCallback,
        send_setup_cb: Callback = NoCallback,
        send_teardown_cb: Callback = NoCallback,
    ) -> None:
        # size in bytes of the AEAD authentication tag
        self.aead_tag_size = 16
        self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
        self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
        self._update_key_requested = False

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int]:
        """
        Decrypt `packet`, applying a key update if the peer initiated one.
        """
        plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
            packet, encrypted_offset, expected_packet_number
        )
        if update_key:
            self._update_key("remote_update")
        return plain_header, payload, packet_number

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt a packet, first applying a locally requested key update if any.
        """
        if self._update_key_requested:
            self._update_key("local_update")
        return self.send.encrypt_packet(plain_header, plain_payload, packet_number)

    def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
        """
        Derive the Initial secrets from the destination connection ID `cid`
        and set up both directions.
        """
        if is_client:
            recv_label, send_label = b"server in", b"client in"
        else:
            recv_label, send_label = b"client in", b"server in"
        # QUIC v2 uses a different initial salt from v1
        if version == QuicProtocolVersion.VERSION_2:
            initial_salt = INITIAL_SALT_VERSION_2
        else:
            initial_salt = INITIAL_SALT_VERSION_1
        algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
        initial_secret = hkdf_extract(algorithm, initial_salt, cid)
        self.recv.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, recv_label, b"", algorithm.digest_size
            ),
            version=version,
        )
        self.send.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, send_label, b"", algorithm.digest_size
            ),
            version=version,
        )

    def teardown(self) -> None:
        """Discard key material for both directions."""
        self.recv.teardown()
        self.send.teardown()

    def update_key(self) -> None:
        """Request a key update; it is applied on the next `encrypt_packet`."""
        self._update_key_requested = True

    @property
    def key_phase(self) -> int:
        # report the phase which will be in effect once a pending update applies
        if self._update_key_requested:
            return int(not self.recv.key_phase)
        else:
            return self.recv.key_phase

    def _update_key(self, trigger: str) -> None:
        # advance both directions to the next key phase
        apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
        apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
        self._update_key_requested = False

View File

@ -0,0 +1,126 @@
from dataclasses import dataclass
from typing import Optional
class QuicEvent:
    """
    Base class for QUIC events.

    Concrete event types subclass this marker class.
    """

    pass
@dataclass
class ConnectionIdIssued(QuicEvent):
    """
    The ConnectionIdIssued event is fired when a new connection ID is issued.
    """

    # the connection ID which was issued
    connection_id: bytes
@dataclass
class ConnectionIdRetired(QuicEvent):
    """
    The ConnectionIdRetired event is fired when a connection ID is retired.
    """

    # the connection ID which was retired
    connection_id: bytes
@dataclass
class ConnectionTerminated(QuicEvent):
    """
    The ConnectionTerminated event is fired when the QUIC connection is terminated.
    """

    error_code: int
    "The error code which was specified when closing the connection."

    frame_type: Optional[int]
    "The frame type which caused the connection to be closed, or `None` for an application close."

    reason_phrase: str
    "The human-readable reason for which the connection was closed."
@dataclass
class DatagramFrameReceived(QuicEvent):
    """
    The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
    """

    data: bytes
    "The payload of the DATAGRAM frame which was received."
@dataclass
class HandshakeCompleted(QuicEvent):
    """
    The HandshakeCompleted event is fired when the TLS handshake completes.
    """

    alpn_protocol: Optional[str]
    "The protocol which was negotiated using ALPN, or `None`."

    early_data_accepted: bool
    "Whether early (0-RTT) data was accepted by the remote peer."

    session_resumed: bool
    "Whether a TLS session was resumed."
@dataclass
class PingAcknowledged(QuicEvent):
    """
    The PingAcknowledged event is fired when a PING frame is acknowledged.
    """

    uid: int
    "The unique ID of the PING which was acknowledged."
@dataclass
class ProtocolNegotiated(QuicEvent):
    """
    The ProtocolNegotiated event is fired when ALPN negotiation completes.
    """

    alpn_protocol: Optional[str]
    "The protocol which was negotiated using ALPN, or `None`."
@dataclass
class StopSendingReceived(QuicEvent):
    """
    The StopSendingReceived event is fired when the remote peer requests
    stopping data transmission on a stream.
    """

    error_code: int
    "The error code that was sent from the peer."

    stream_id: int
    "The ID of the stream that the peer requested stopping data transmission."
@dataclass
class StreamDataReceived(QuicEvent):
    """
    The StreamDataReceived event is fired whenever data is received on a
    stream.
    """

    data: bytes
    "The data which was received."

    end_stream: bool
    "Whether the STREAM frame had the FIN bit set."

    stream_id: int
    "The ID of the stream the data was received for."
@dataclass
class StreamReset(QuicEvent):
    """
    The StreamReset event is fired when the remote peer resets a stream.
    """

    error_code: int
    "The error code that triggered the reset."

    stream_id: int
    "The ID of the stream that was reset."

View File

@ -0,0 +1,329 @@
import binascii
import json
import os
import time
from collections import deque
from typing import Any, Deque, Dict, List, Optional
from ..h3.events import Headers
from .packet import (
QuicFrameType,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
)
from .rangeset import RangeSet
PACKET_TYPE_NAMES = {
QuicPacketType.INITIAL: "initial",
QuicPacketType.HANDSHAKE: "handshake",
QuicPacketType.ZERO_RTT: "0RTT",
QuicPacketType.ONE_RTT: "1RTT",
QuicPacketType.RETRY: "retry",
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
QLOG_VERSION = "0.3"
def hexdump(data: bytes) -> str:
    """Return `data` as a lowercase hexadecimal string."""
    return data.hex()
class QuicLoggerTrace:
    """
    A QUIC event trace.

    Events are logged in the format defined by qlog.

    See:
    - https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
    - https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
    - https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
    """

    def __init__(self, *, is_client: bool, odcid: bytes) -> None:
        # original destination connection ID, used to identify the trace
        self._odcid = odcid
        self._events: Deque[Dict[str, Any]] = deque()
        self._vantage_point = {
            "name": "aioquic",
            "type": "client" if is_client else "server",
        }

    # QUIC

    def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
        return {
            "ack_delay": self.encode_time(delay),
            # ranges are half-open; qlog wants inclusive bounds
            "acked_ranges": [[x.start, x.stop - 1] for x in ranges],
            "frame_type": "ack",
        }

    def encode_connection_close_frame(
        self, error_code: int, frame_type: Optional[int], reason_phrase: str
    ) -> Dict:
        attrs = {
            "error_code": error_code,
            # a missing frame type indicates an application-initiated close
            "error_space": "application" if frame_type is None else "transport",
            "frame_type": "connection_close",
            "raw_error_code": error_code,
            "reason": reason_phrase,
        }
        if frame_type is not None:
            attrs["trigger_frame_type"] = frame_type

        return attrs

    def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict:
        if frame_type == QuicFrameType.MAX_DATA:
            return {"frame_type": "max_data", "maximum": maximum}
        else:
            return {
                "frame_type": "max_streams",
                "maximum": maximum,
                "stream_type": "unidirectional"
                if frame_type == QuicFrameType.MAX_STREAMS_UNI
                else "bidirectional",
            }

    def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
        return {
            "frame_type": "crypto",
            "length": len(frame.data),
            "offset": frame.offset,
        }

    def encode_data_blocked_frame(self, limit: int) -> Dict:
        return {"frame_type": "data_blocked", "limit": limit}

    def encode_datagram_frame(self, length: int) -> Dict:
        return {"frame_type": "datagram", "length": length}

    def encode_handshake_done_frame(self) -> Dict:
        return {"frame_type": "handshake_done"}

    def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
        return {
            "frame_type": "max_stream_data",
            "maximum": maximum,
            "stream_id": stream_id,
        }

    def encode_new_connection_id_frame(
        self,
        connection_id: bytes,
        retire_prior_to: int,
        sequence_number: int,
        stateless_reset_token: bytes,
    ) -> Dict:
        return {
            "connection_id": hexdump(connection_id),
            "frame_type": "new_connection_id",
            "length": len(connection_id),
            "reset_token": hexdump(stateless_reset_token),
            "retire_prior_to": retire_prior_to,
            "sequence_number": sequence_number,
        }

    def encode_new_token_frame(self, token: bytes) -> Dict:
        return {
            "frame_type": "new_token",
            "length": len(token),
            "token": hexdump(token),
        }

    def encode_padding_frame(self) -> Dict:
        return {"frame_type": "padding"}

    def encode_path_challenge_frame(self, data: bytes) -> Dict:
        return {"data": hexdump(data), "frame_type": "path_challenge"}

    def encode_path_response_frame(self, data: bytes) -> Dict:
        return {"data": hexdump(data), "frame_type": "path_response"}

    def encode_ping_frame(self) -> Dict:
        return {"frame_type": "ping"}

    def encode_reset_stream_frame(
        self, error_code: int, final_size: int, stream_id: int
    ) -> Dict:
        return {
            "error_code": error_code,
            "final_size": final_size,
            "frame_type": "reset_stream",
            "stream_id": stream_id,
        }

    def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
        return {
            "frame_type": "retire_connection_id",
            "sequence_number": sequence_number,
        }

    def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
        return {
            "frame_type": "stream_data_blocked",
            "limit": limit,
            "stream_id": stream_id,
        }

    def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
        return {
            "frame_type": "stop_sending",
            "error_code": error_code,
            "stream_id": stream_id,
        }

    def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
        return {
            "fin": frame.fin,
            "frame_type": "stream",
            "length": len(frame.data),
            "offset": frame.offset,
            "stream_id": stream_id,
        }

    def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
        return {
            "frame_type": "streams_blocked",
            "limit": limit,
            "stream_type": "unidirectional" if is_unidirectional else "bidirectional",
        }

    def encode_time(self, seconds: float) -> float:
        """
        Convert a time to milliseconds.
        """
        return seconds * 1000

    def encode_transport_parameters(
        self, owner: str, parameters: QuicTransportParameters
    ) -> Dict[str, Any]:
        data: Dict[str, Any] = {"owner": owner}
        for param_name, param_value in parameters.__dict__.items():
            # check bool before int: bool is a subclass of int
            if isinstance(param_value, bool):
                data[param_name] = param_value
            elif isinstance(param_value, bytes):
                data[param_name] = hexdump(param_value)
            elif isinstance(param_value, int):
                data[param_name] = param_value
        return data

    def packet_type(self, packet_type: QuicPacketType) -> str:
        """Return the qlog name for `packet_type`."""
        return PACKET_TYPE_NAMES[packet_type]

    # HTTP/3

    def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict:
        return {
            "frame": {"frame_type": "data"},
            "length": length,
            "stream_id": stream_id,
        }

    def encode_http3_headers_frame(
        self, length: int, headers: Headers, stream_id: int
    ) -> Dict:
        return {
            "frame": {
                "frame_type": "headers",
                "headers": self._encode_http3_headers(headers),
            },
            "length": length,
            "stream_id": stream_id,
        }

    def encode_http3_push_promise_frame(
        self, length: int, headers: Headers, push_id: int, stream_id: int
    ) -> Dict:
        return {
            "frame": {
                "frame_type": "push_promise",
                "headers": self._encode_http3_headers(headers),
                "push_id": push_id,
            },
            "length": length,
            "stream_id": stream_id,
        }

    def _encode_http3_headers(self, headers: Headers) -> List[Dict]:
        return [
            {"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
        ]

    # CORE

    def log_event(self, *, category: str, event: str, data: Dict) -> None:
        """Append an event; the timestamp is taken from the wall clock."""
        self._events.append(
            {
                "data": data,
                "name": category + ":" + event,
                "time": self.encode_time(time.time()),
            }
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the trace as a dictionary which can be written as JSON.
        """
        return {
            "common_fields": {
                "ODCID": hexdump(self._odcid),
            },
            "events": list(self._events),
            "vantage_point": self._vantage_point,
        }
class QuicLogger:
    """
    A QUIC event logger which stores traces in memory.
    """

    def __init__(self) -> None:
        self._traces: List[QuicLoggerTrace] = []

    def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
        """Create a new trace, remember it and return it to the caller."""
        new_trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
        self._traces.append(new_trace)
        return new_trace

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Mark `trace` as finished; in-memory traces are kept for `to_dict`."""
        assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the traces as a dictionary which can be written as JSON.
        """
        return {
            "qlog_format": "JSON",
            "qlog_version": QLOG_VERSION,
            "traces": [recorded.to_dict() for recorded in self._traces],
        }
class QuicFileLogger(QuicLogger):
    """
    A QUIC event logger which writes one trace per file.
    """

    def __init__(self, path: str) -> None:
        # Fail fast: the output directory must already exist.
        if not os.path.isdir(path):
            raise ValueError("QUIC log output directory '%s' does not exist" % path)
        self.path = path
        super().__init__()

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Write the trace to `<ODCID>.qlog` and release it from memory."""
        trace_dict = trace.to_dict()
        # The file is named after the original destination connection ID.
        trace_path = os.path.join(
            self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
        )
        with open(trace_path, "w") as logger_fp:
            json.dump(
                {
                    "qlog_format": "JSON",
                    "qlog_version": QLOG_VERSION,
                    "traces": [trace_dict],
                },
                logger_fp,
            )
        # Unlike the base class, drop the trace once it has been persisted.
        self._traces.remove(trace)

View File

@ -0,0 +1,640 @@
import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import List, Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ..buffer import Buffer
from .rangeset import RangeSet
# First-byte flags; see RFC 9000 section 17.
PACKET_LONG_HEADER = 0x80  # header form: long header when set
PACKET_FIXED_BIT = 0x40  # must be set in all valid packets
PACKET_SPIN_BIT = 0x20  # latency spin bit (short header packets)

CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4

# AEAD key/nonce pairs used to compute the Retry integrity tag
# (one pair per protocol version).
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
class QuicErrorCode(IntEnum):
    """
    QUIC transport error codes; see RFC 9000 section 20.1.
    """

    NO_ERROR = 0x0
    INTERNAL_ERROR = 0x1
    CONNECTION_REFUSED = 0x2
    FLOW_CONTROL_ERROR = 0x3
    STREAM_LIMIT_ERROR = 0x4
    STREAM_STATE_ERROR = 0x5
    FINAL_SIZE_ERROR = 0x6
    FRAME_ENCODING_ERROR = 0x7
    TRANSPORT_PARAMETER_ERROR = 0x8
    CONNECTION_ID_LIMIT_ERROR = 0x9
    PROTOCOL_VIOLATION = 0xA
    INVALID_TOKEN = 0xB
    APPLICATION_ERROR = 0xC
    CRYPTO_BUFFER_EXCEEDED = 0xD
    KEY_UPDATE_ERROR = 0xE
    AEAD_LIMIT_REACHED = 0xF
    VERSION_NEGOTIATION_ERROR = 0x11
    # Base value; the TLS alert code is added to this.
    CRYPTO_ERROR = 0x100
class QuicPacketType(Enum):
    """
    The QUIC packet types; see RFC 9000 section 17.
    """

    INITIAL = 0
    ZERO_RTT = 1
    HANDSHAKE = 2
    RETRY = 3
    VERSION_NEGOTIATION = 4
    ONE_RTT = 5


# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
# Inverse mapping: two-bit wire encoding -> packet type.
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
    wire: packet_type
    for packet_type, wire in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}

# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
# Inverse mapping: two-bit wire encoding -> packet type.
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
    wire: packet_type
    for packet_type, wire in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}
class QuicProtocolVersion(IntEnum):
    """
    Known QUIC protocol version numbers.
    """

    NEGOTIATION = 0  # reserved for Version Negotiation packets
    VERSION_1 = 0x00000001  # RFC 9000
    VERSION_2 = 0x6B3343CF  # RFC 9369
@dataclass
class QuicHeader:
    """
    A parsed QUIC packet header.
    """

    version: Optional[int]
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: List[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
    """
    Recover a full packet number from a truncated packet number.

    See: RFC 9000, Appendix A - Sample Packet Number Decoding Algorithm
    """
    pn_window = 1 << num_bits
    pn_half_window = pn_window >> 1
    candidate = (expected & ~(pn_window - 1)) | truncated
    # The candidate may be one window off; shift it towards `expected`
    # while keeping the result inside the 62-bit packet number space.
    if candidate <= expected - pn_half_window and candidate < (1 << 62) - pn_window:
        candidate += pn_window
    elif candidate > expected + pn_half_window and candidate >= pn_window:
        candidate -= pn_window
    return candidate
def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Calculate the integrity tag for a RETRY packet.

    :param packet_without_tag: the RETRY packet bytes, minus the trailing tag.
    :param original_destination_cid: the client's original destination CID.
    :param version: the protocol version, which selects the AEAD key / nonce.
    """
    # build Retry pseudo packet
    buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
    buf.push_uint8(len(original_destination_cid))
    buf.push_bytes(original_destination_cid)
    buf.push_bytes(packet_without_tag)
    assert buf.eof()

    # Key material is version-specific.
    if version == QuicProtocolVersion.VERSION_2:
        aead_key = RETRY_AEAD_KEY_VERSION_2
        aead_nonce = RETRY_AEAD_NONCE_VERSION_2
    else:
        aead_key = RETRY_AEAD_KEY_VERSION_1
        aead_nonce = RETRY_AEAD_NONCE_VERSION_1

    # run AES-128-GCM over the pseudo packet; the tag is the only output
    # because the plaintext is empty.
    aead = AESGCM(aead_key)
    integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
    assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
    return integrity_tag
def get_spin_bit(first_byte: int) -> bool:
    """Return the latency spin bit from a short-header first byte."""
    return (first_byte & PACKET_SPIN_BIT) != 0
def is_long_header(first_byte: int) -> bool:
    """Return whether the first byte indicates a long header packet."""
    return (first_byte & PACKET_LONG_HEADER) != 0
def pretty_protocol_version(version: int) -> str:
    """
    Return a user-friendly representation of a protocol version.
    """
    try:
        name = QuicProtocolVersion(version).name
    except ValueError:
        name = "UNKNOWN"
    return "0x%08x (%s)" % (version, name)
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
    """
    Parse a QUIC packet header from `buf`.

    :param buf: a buffer positioned at the start of the packet.
    :param host_cid_length: the length of our own connection IDs, needed to
        parse short header packets, which do not carry a CID length field.
    :raises ValueError: if the header is malformed.
    """
    packet_start = buf.tell()

    version = None
    integrity_tag = b""
    supported_versions = []
    token = b""

    first_byte = buf.pull_uint8()
    if is_long_header(first_byte):
        # Long Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
        version = buf.pull_uint32()

        destination_cid_length = buf.pull_uint8()
        if destination_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(
                "Destination CID is too long (%d bytes)" % destination_cid_length
            )
        destination_cid = buf.pull_bytes(destination_cid_length)

        source_cid_length = buf.pull_uint8()
        if source_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
        source_cid = buf.pull_bytes(source_cid_length)

        if version == QuicProtocolVersion.NEGOTIATION:
            # Version Negotiation Packet.
            # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
            packet_type = QuicPacketType.VERSION_NEGOTIATION
            while not buf.eof():
                supported_versions.append(buf.pull_uint32())
            packet_end = buf.tell()
        else:
            if not (first_byte & PACKET_FIXED_BIT):
                raise ValueError("Packet fixed bit is zero")

            # The two long-type bits (0x30) are version-specific.
            if version == QuicProtocolVersion.VERSION_2:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
                    (first_byte & 0x30) >> 4
                ]
            else:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
                    (first_byte & 0x30) >> 4
                ]

            if packet_type == QuicPacketType.INITIAL:
                token_length = buf.pull_uint_var()
                token = buf.pull_bytes(token_length)
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.ZERO_RTT:
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.HANDSHAKE:
                rest_length = buf.pull_uint_var()
            else:
                # RETRY: the token extends up to the trailing integrity tag.
                token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                token = buf.pull_bytes(token_length)
                integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
                rest_length = 0

            # Check remainder length.
            packet_end = buf.tell() + rest_length
            if packet_end > buf.capacity:
                raise ValueError("Packet payload is truncated")
    else:
        # Short Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
        if not (first_byte & PACKET_FIXED_BIT):
            raise ValueError("Packet fixed bit is zero")

        version = None
        packet_type = QuicPacketType.ONE_RTT
        destination_cid = buf.pull_bytes(host_cid_length)
        source_cid = b""
        packet_end = buf.capacity

    return QuicHeader(
        version=version,
        packet_type=packet_type,
        packet_length=packet_end - packet_start,
        destination_cid=destination_cid,
        source_cid=source_cid,
        token=token,
        integrity_tag=integrity_tag,
        supported_versions=supported_versions,
    )
def encode_long_header_first_byte(
    version: int, packet_type: QuicPacketType, bits: int
) -> int:
    """
    Encode the first byte of a long header packet.
    """
    # The two long-type bits are version-specific.
    if version == QuicProtocolVersion.VERSION_2:
        type_bits = PACKET_LONG_TYPE_ENCODE_VERSION_2[packet_type]
    else:
        type_bits = PACKET_LONG_TYPE_ENCODE_VERSION_1[packet_type]
    return PACKET_LONG_HEADER | PACKET_FIXED_BIT | (type_bits << 4) | bits
def encode_quic_retry(
    version: int,
    source_cid: bytes,
    destination_cid: bytes,
    original_destination_cid: bytes,
    retry_token: bytes,
    unused: int = 0,
) -> bytes:
    """
    Encode a RETRY packet, including its trailing integrity tag.

    :param unused: the value of the four low "unused" bits of the first byte.
    """
    # Capacity: first byte (1) + version (4) + two CID length bytes (2) = 7.
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + len(retry_token)
        + RETRY_INTEGRITY_TAG_SIZE
    )
    buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
    buf.push_uint32(version)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    buf.push_bytes(retry_token)
    # The tag covers everything written so far, prefixed (inside
    # get_retry_integrity_tag) with the original destination CID.
    buf.push_bytes(
        get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
    )
    assert buf.eof()
    return buf.data
def encode_quic_version_negotiation(
    source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
) -> bytes:
    """
    Encode a Version Negotiation packet listing `supported_versions`.
    """
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + 4 * len(supported_versions)
    )
    # The first byte is random apart from the long header bit.
    buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
    # Version Negotiation packets use the reserved version zero.
    buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    for version in supported_versions:
        buf.push_uint32(version)
    return buf.data
# TLS EXTENSION
@dataclass
class QuicPreferredAddress:
    """
    The value of the preferred_address transport parameter.
    """

    ipv4_address: Optional[Tuple[str, int]]  # (host, port) or None if absent
    ipv6_address: Optional[Tuple[str, int]]  # (host, port) or None if absent
    connection_id: bytes
    stateless_reset_token: bytes
@dataclass
class QuicVersionInformation:
    """
    The value of the version_information transport parameter (RFC 9368).
    """

    chosen_version: int
    available_versions: List[int]
@dataclass
class QuicTransportParameters:
    """
    QUIC transport parameters; see RFC 9000 section 18.2.

    A value of None means the parameter was not sent / received.
    """

    original_destination_connection_id: Optional[bytes] = None
    max_idle_timeout: Optional[int] = None
    stateless_reset_token: Optional[bytes] = None
    max_udp_payload_size: Optional[int] = None
    initial_max_data: Optional[int] = None
    initial_max_stream_data_bidi_local: Optional[int] = None
    initial_max_stream_data_bidi_remote: Optional[int] = None
    initial_max_stream_data_uni: Optional[int] = None
    initial_max_streams_bidi: Optional[int] = None
    initial_max_streams_uni: Optional[int] = None
    ack_delay_exponent: Optional[int] = None
    max_ack_delay: Optional[int] = None
    disable_active_migration: Optional[bool] = False
    preferred_address: Optional[QuicPreferredAddress] = None
    active_connection_id_limit: Optional[int] = None
    initial_source_connection_id: Optional[bytes] = None
    retry_source_connection_id: Optional[bytes] = None
    version_information: Optional[QuicVersionInformation] = None
    # extensions
    max_datagram_frame_size: Optional[int] = None
    quantum_readiness: Optional[bytes] = None
# Maps each transport parameter ID to (attribute name, value type); used by
# pull_quic_transport_parameters() and push_quic_transport_parameters().
PARAMS = {
    0x00: ("original_destination_connection_id", bytes),
    0x01: ("max_idle_timeout", int),
    0x02: ("stateless_reset_token", bytes),
    0x03: ("max_udp_payload_size", int),
    0x04: ("initial_max_data", int),
    0x05: ("initial_max_stream_data_bidi_local", int),
    0x06: ("initial_max_stream_data_bidi_remote", int),
    0x07: ("initial_max_stream_data_uni", int),
    0x08: ("initial_max_streams_bidi", int),
    0x09: ("initial_max_streams_uni", int),
    0x0A: ("ack_delay_exponent", int),
    0x0B: ("max_ack_delay", int),
    0x0C: ("disable_active_migration", bool),
    0x0D: ("preferred_address", QuicPreferredAddress),
    0x0E: ("active_connection_id_limit", int),
    0x0F: ("initial_source_connection_id", bytes),
    0x10: ("retry_source_connection_id", bytes),
    # https://datatracker.ietf.org/doc/html/rfc9368#section-3
    0x11: ("version_information", QuicVersionInformation),
    # extensions
    0x0020: ("max_datagram_frame_size", int),
    0x0C37: ("quantum_readiness", bytes),
}
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
    """
    Parse a preferred_address transport parameter value.

    An all-zero host address is treated as "absent" and left as None.
    """
    ipv4_address = None
    ipv4_host = buf.pull_bytes(4)
    ipv4_port = buf.pull_uint16()
    if ipv4_host != bytes(4):
        ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)

    ipv6_address = None
    ipv6_host = buf.pull_bytes(16)
    ipv6_port = buf.pull_uint16()
    if ipv6_host != bytes(16):
        ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)

    connection_id_length = buf.pull_uint8()
    connection_id = buf.pull_bytes(connection_id_length)
    stateless_reset_token = buf.pull_bytes(16)

    return QuicPreferredAddress(
        ipv4_address=ipv4_address,
        ipv6_address=ipv6_address,
        connection_id=connection_id,
        stateless_reset_token=stateless_reset_token,
    )
def push_quic_preferred_address(
    buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
    """
    Serialize a preferred_address transport parameter value.

    Absent addresses are written as zero-filled placeholders.
    """
    if preferred_address.ipv4_address is not None:
        buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
        buf.push_uint16(preferred_address.ipv4_address[1])
    else:
        # 4-byte address + 2-byte port, all zeroes.
        buf.push_bytes(bytes(6))

    if preferred_address.ipv6_address is not None:
        buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
        buf.push_uint16(preferred_address.ipv6_address[1])
    else:
        # 16-byte address + 2-byte port, all zeroes.
        buf.push_bytes(bytes(18))

    buf.push_uint8(len(preferred_address.connection_id))
    buf.push_bytes(preferred_address.connection_id)
    buf.push_bytes(preferred_address.stateless_reset_token)
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
    """
    Parse a version_information transport parameter of `length` bytes.

    :raises ValueError: if any version equals zero, as mandated by RFC 9368.
    """
    chosen_version = buf.pull_uint32()
    # The remaining bytes carry the list of Available Versions.
    available_versions = [buf.pull_uint32() for _ in range(length // 4 - 1)]

    # If an endpoint receives a Chosen Version equal to zero, or any Available
    # Version equal to zero, it MUST treat it as a parsing failure.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if chosen_version == 0 or 0 in available_versions:
        raise ValueError("Version Information must not contain version 0")

    return QuicVersionInformation(
        chosen_version=chosen_version,
        available_versions=available_versions,
    )
def push_quic_version_information(
    buf: Buffer, version_information: QuicVersionInformation
) -> None:
    """
    Serialize a version_information transport parameter value.
    """
    buf.push_uint32(version_information.chosen_version)
    for version in version_information.available_versions:
        buf.push_uint32(version)
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
    """
    Parse a sequence of transport parameters from `buf`.

    Unknown parameter IDs are skipped without error.

    :raises ValueError: if a parameter's declared length does not match the
        number of bytes actually consumed.
    """
    params = QuicTransportParameters()
    while not buf.eof():
        param_id = buf.pull_uint_var()
        param_len = buf.pull_uint_var()
        param_start = buf.tell()
        if param_id in PARAMS:
            # Parse known parameter.
            param_name, param_type = PARAMS[param_id]
            if param_type is int:
                setattr(params, param_name, buf.pull_uint_var())
            elif param_type is bytes:
                setattr(params, param_name, buf.pull_bytes(param_len))
            elif param_type is QuicPreferredAddress:
                setattr(params, param_name, pull_quic_preferred_address(buf))
            elif param_type is QuicVersionInformation:
                setattr(
                    params,
                    param_name,
                    pull_quic_version_information(buf, param_len),
                )
            else:
                # bool parameters are zero-length; presence means True.
                setattr(params, param_name, True)
        else:
            # Skip unknown parameter.
            buf.pull_bytes(param_len)
        if buf.tell() != param_start + param_len:
            raise ValueError("Transport parameter length does not match")
    return params
def push_quic_transport_parameters(
    buf: Buffer, params: QuicTransportParameters
) -> None:
    """
    Serialize transport parameters, omitting values which are None or False.
    """
    for param_id, (param_name, param_type) in PARAMS.items():
        param_value = getattr(params, param_name)
        if param_value is not None and param_value is not False:
            # Serialize into a scratch buffer first to learn the value length.
            param_buf = Buffer(capacity=65536)
            if param_type is int:
                param_buf.push_uint_var(param_value)
            elif param_type is bytes:
                param_buf.push_bytes(param_value)
            elif param_type is QuicPreferredAddress:
                push_quic_preferred_address(param_buf, param_value)
            elif param_type is QuicVersionInformation:
                push_quic_version_information(param_buf, param_value)
            # bool parameters are encoded as zero-length values.
            buf.push_uint_var(param_id)
            buf.push_uint_var(param_buf.tell())
            buf.push_bytes(param_buf.data)
# FRAMES
class QuicFrameType(IntEnum):
    """
    QUIC frame types; see RFC 9000 section 19.
    """

    PADDING = 0x00
    PING = 0x01
    ACK = 0x02
    ACK_ECN = 0x03
    RESET_STREAM = 0x04
    STOP_SENDING = 0x05
    CRYPTO = 0x06
    NEW_TOKEN = 0x07
    # STREAM frames occupy 0x08-0x0F; the low bits carry flags.
    STREAM_BASE = 0x08
    MAX_DATA = 0x10
    MAX_STREAM_DATA = 0x11
    MAX_STREAMS_BIDI = 0x12
    MAX_STREAMS_UNI = 0x13
    DATA_BLOCKED = 0x14
    STREAM_DATA_BLOCKED = 0x15
    STREAMS_BLOCKED_BIDI = 0x16
    STREAMS_BLOCKED_UNI = 0x17
    NEW_CONNECTION_ID = 0x18
    RETIRE_CONNECTION_ID = 0x19
    PATH_CHALLENGE = 0x1A
    PATH_RESPONSE = 0x1B
    TRANSPORT_CLOSE = 0x1C
    APPLICATION_CLOSE = 0x1D
    HANDSHAKE_DONE = 0x1E
    DATAGRAM = 0x30
    DATAGRAM_WITH_LENGTH = 0x31
# Frame types which do not elicit an acknowledgement from the peer.
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.PADDING,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

# Frame types which do not count towards bytes in flight.
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

# Frame types considered "probing" frames; see RFC 9000 section 9.1.
PROBING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.PATH_CHALLENGE,
        QuicFrameType.PATH_RESPONSE,
        QuicFrameType.PADDING,
        QuicFrameType.NEW_CONNECTION_ID,
    ]
)
@dataclass
class QuicResetStreamFrame:
    """The contents of a RESET_STREAM frame."""

    error_code: int
    final_size: int
    stream_id: int


@dataclass
class QuicStopSendingFrame:
    """The contents of a STOP_SENDING frame."""

    error_code: int
    stream_id: int


@dataclass
class QuicStreamFrame:
    """A chunk of stream data with its offset and FIN flag."""

    data: bytes = b""
    fin: bool = False
    offset: int = 0
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
    """
    Parse an ACK frame body; see RFC 9000 section 19.3.

    :return: the acknowledged packet number ranges and the ACK delay.
    """
    rangeset = RangeSet()
    end = buf.pull_uint_var()  # largest acknowledged
    delay = buf.pull_uint_var()
    ack_range_count = buf.pull_uint_var()
    ack_count = buf.pull_uint_var()  # first ack range
    rangeset.add(end - ack_count, end + 1)
    end -= ack_count
    # Each further range is a (gap, length) pair, both encoded minus one,
    # hence the "+ 2" when stepping down to the next range's top.
    for _ in range(ack_range_count):
        end -= buf.pull_uint_var() + 2
        ack_count = buf.pull_uint_var()
        rangeset.add(end - ack_count, end + 1)
        end -= ack_count
    return rangeset, delay
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
    """
    Serialize an ACK frame body; see RFC 9000 section 19.3.

    :return: the number of ranges written.
    """
    ranges = len(rangeset)
    index = ranges - 1
    r = rangeset[index]
    # Largest Acknowledged, ACK Delay, ACK Range Count, First ACK Range.
    buf.push_uint_var(r.stop - 1)
    buf.push_uint_var(delay)
    buf.push_uint_var(index)
    buf.push_uint_var(r.stop - 1 - r.start)
    start = r.start
    # Walk the remaining ranges from highest to lowest, encoding each as a
    # (gap, length) pair, both minus one.
    while index > 0:
        index -= 1
        r = rangeset[index]
        buf.push_uint_var(start - r.stop - 1)
        buf.push_uint_var(r.stop - r.start - 1)
        start = r.start
    return ranges

View File

@ -0,0 +1,384 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from ..buffer import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
NON_ACK_ELICITING_FRAME_TYPES,
NON_IN_FLIGHT_FRAME_TYPES,
PACKET_FIXED_BIT,
PACKET_NUMBER_MAX_SIZE,
QuicFrameType,
QuicPacketType,
encode_long_header_first_byte,
)
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2
QuicDeliveryHandler = Callable[..., None]
class QuicDeliveryState(Enum):
ACKED = 0
LOST = 1
@dataclass
class QuicSentPacket:
    """A packet we sent, tracked for loss detection and delivery callbacks."""

    # Encryption epoch the packet belongs to.
    epoch: Epoch
    # Whether the packet counts towards bytes in flight.
    in_flight: bool
    # Whether the packet elicits an acknowledgement from the peer.
    is_ack_eliciting: bool
    # Whether the packet carries a CRYPTO frame.
    is_crypto_packet: bool
    packet_number: int
    packet_type: QuicPacketType
    # Filled in when the packet is actually sent.
    sent_time: Optional[float] = None
    sent_bytes: int = 0
    # (handler, args) pairs to invoke when the packet is acked or lost.
    delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
        default_factory=list
    )
    quic_logger_frames: List[Dict] = field(default_factory=list)
class QuicPacketBuilderStop(Exception):
    """Raised when no more frames can fit in the current packet / datagram."""

    pass
class QuicPacketBuilder:
    """
    Helper for building QUIC packets.
    """

    def __init__(
        self,
        *,
        host_cid: bytes,
        peer_cid: bytes,
        version: int,
        is_client: bool,
        max_datagram_size: int,
        packet_number: int = 0,
        peer_token: bytes = b"",
        quic_logger: Optional[QuicLoggerTrace] = None,
        spin_bit: bool = False,
    ):
        # Optional send limits, set by the caller before building.
        self.max_flight_bytes: Optional[int] = None
        self.max_total_bytes: Optional[int] = None
        # Frame log for the packet being built (exposed for qlog).
        self.quic_logger_frames: Optional[List[Dict]] = None

        self._host_cid = host_cid
        self._is_client = is_client
        self._peer_cid = peer_cid
        self._peer_token = peer_token
        self._quic_logger = quic_logger
        self._spin_bit = spin_bit
        self._version = version

        # assembled datagrams and packets
        self._datagrams: List[bytes] = []
        self._datagram_flight_bytes = 0
        self._datagram_init = True
        self._datagram_needs_padding = False
        self._packets: List[QuicSentPacket] = []
        self._flight_bytes = 0
        self._total_bytes = 0

        # current packet
        self._header_size = 0
        self._packet: Optional[QuicSentPacket] = None
        self._packet_crypto: Optional[CryptoPair] = None
        self._packet_number = packet_number
        self._packet_start = 0
        self._packet_type: Optional[QuicPacketType] = None

        self._buffer = Buffer(max_datagram_size)
        self._buffer_capacity = max_datagram_size
        self._flight_capacity = max_datagram_size

    @property
    def packet_is_empty(self) -> bool:
        """
        Returns `True` if the current packet is empty.
        """
        assert self._packet is not None
        packet_size = self._buffer.tell() - self._packet_start
        return packet_size <= self._header_size

    @property
    def packet_number(self) -> int:
        """
        Returns the packet number for the next packet.
        """
        return self._packet_number

    @property
    def remaining_buffer_space(self) -> int:
        """
        Returns the remaining number of bytes which can be used in
        the current packet.
        """
        return (
            self._buffer_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    @property
    def remaining_flight_space(self) -> int:
        """
        Returns the remaining number of bytes which can be used in
        the current packet.
        """
        return (
            self._flight_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
        """
        Returns the assembled datagrams.
        """
        if self._packet is not None:
            self._end_packet()
        self._flush_current_datagram()
        datagrams = self._datagrams
        packets = self._packets
        self._datagrams = []
        self._packets = []
        return datagrams, packets

    def start_frame(
        self,
        frame_type: int,
        capacity: int = 1,
        handler: Optional[QuicDeliveryHandler] = None,
        handler_args: Sequence[Any] = [],
        # NOTE(review): the mutable default above is never mutated here, only
        # stored; still worth confirming callers do not rely on identity.
    ) -> Buffer:
        """
        Starts a new frame.

        :raises QuicPacketBuilderStop: if the frame cannot fit in the
            remaining buffer or flight space.
        """
        if self.remaining_buffer_space < capacity or (
            frame_type not in NON_IN_FLIGHT_FRAME_TYPES
            and self.remaining_flight_space < capacity
        ):
            raise QuicPacketBuilderStop

        self._buffer.push_uint_var(frame_type)
        # Update the packet's bookkeeping flags based on the frame type.
        if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
            self._packet.is_ack_eliciting = True
        if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
            self._packet.in_flight = True
        if frame_type == QuicFrameType.CRYPTO:
            self._packet.is_crypto_packet = True
        if handler is not None:
            self._packet.delivery_handlers.append((handler, handler_args))
        return self._buffer

    def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
        """
        Starts a new packet.
        """
        assert packet_type in (
            QuicPacketType.INITIAL,
            QuicPacketType.HANDSHAKE,
            QuicPacketType.ZERO_RTT,
            QuicPacketType.ONE_RTT,
        ), "Invalid packet type"
        buf = self._buffer

        # finish previous datagram
        if self._packet is not None:
            self._end_packet()

        # if there is too little space remaining, start a new datagram
        # FIXME: the limit is arbitrary!
        packet_start = buf.tell()
        if self._buffer_capacity - packet_start < 128:
            self._flush_current_datagram()
            packet_start = 0

        # initialize datagram if needed
        if self._datagram_init:
            # Shrink the capacities to honour the caller-imposed limits.
            if self.max_total_bytes is not None:
                remaining_total_bytes = self.max_total_bytes - self._total_bytes
                if remaining_total_bytes < self._buffer_capacity:
                    self._buffer_capacity = remaining_total_bytes

            self._flight_capacity = self._buffer_capacity
            if self.max_flight_bytes is not None:
                remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
                if remaining_flight_bytes < self._flight_capacity:
                    self._flight_capacity = remaining_flight_bytes
            self._datagram_flight_bytes = 0
            self._datagram_init = False
            self._datagram_needs_padding = False

        # calculate header size
        if packet_type != QuicPacketType.ONE_RTT:
            # long header: first byte (1) + version (4) + CID lengths (2)
            # + length (2) + packet number (2) = 11, plus the CIDs.
            header_size = 11 + len(self._peer_cid) + len(self._host_cid)
            if packet_type == QuicPacketType.INITIAL:
                token_length = len(self._peer_token)
                header_size += size_uint_var(token_length) + token_length
        else:
            # short header: first byte (1) + packet number (2), plus the CID.
            header_size = 3 + len(self._peer_cid)

        # check we have enough space
        if packet_start + header_size >= self._buffer_capacity:
            raise QuicPacketBuilderStop

        # determine ack epoch
        if packet_type == QuicPacketType.INITIAL:
            epoch = Epoch.INITIAL
        elif packet_type == QuicPacketType.HANDSHAKE:
            epoch = Epoch.HANDSHAKE
        else:
            epoch = Epoch.ONE_RTT

        self._header_size = header_size
        self._packet = QuicSentPacket(
            epoch=epoch,
            in_flight=False,
            is_ack_eliciting=False,
            is_crypto_packet=False,
            packet_number=self._packet_number,
            packet_type=packet_type,
        )
        self._packet_crypto = crypto
        self._packet_start = packet_start
        self._packet_type = packet_type
        self.quic_logger_frames = self._packet.quic_logger_frames

        # Leave room for the header; it is written in _end_packet() once the
        # payload length is known.
        buf.seek(self._packet_start + self._header_size)

    def _end_packet(self) -> None:
        """
        Ends the current packet.
        """
        buf = self._buffer
        packet_size = buf.tell() - self._packet_start
        if packet_size > self._header_size:
            # padding to ensure sufficient sample size
            padding_size = (
                PACKET_NUMBER_MAX_SIZE
                - PACKET_NUMBER_SEND_SIZE
                + self._header_size
                - packet_size
            )

            # Padding for datagrams containing initial packets; see RFC 9000
            # section 14.1.
            if (
                self._is_client or self._packet.is_ack_eliciting
            ) and self._packet_type == QuicPacketType.INITIAL:
                self._datagram_needs_padding = True

            # For datagrams containing 1-RTT data, we *must* apply the padding
            # inside the packet, we cannot tack bytes onto the end of the
            # datagram.
            if (
                self._datagram_needs_padding
                and self._packet_type == QuicPacketType.ONE_RTT
            ):
                if self.remaining_flight_space > padding_size:
                    padding_size = self.remaining_flight_space
                self._datagram_needs_padding = False

            # write padding
            if padding_size > 0:
                buf.push_bytes(bytes(padding_size))
                packet_size += padding_size
                self._packet.in_flight = True

                # log frame
                if self._quic_logger is not None:
                    self._packet.quic_logger_frames.append(
                        self._quic_logger.encode_padding_frame()
                    )

            # write header
            if self._packet_type != QuicPacketType.ONE_RTT:
                # The length field covers the packet number and the AEAD tag.
                length = (
                    packet_size
                    - self._header_size
                    + PACKET_NUMBER_SEND_SIZE
                    + self._packet_crypto.aead_tag_size
                )

                buf.seek(self._packet_start)
                buf.push_uint8(
                    encode_long_header_first_byte(
                        self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
                    )
                )
                buf.push_uint32(self._version)
                buf.push_uint8(len(self._peer_cid))
                buf.push_bytes(self._peer_cid)
                buf.push_uint8(len(self._host_cid))
                buf.push_bytes(self._host_cid)
                if self._packet_type == QuicPacketType.INITIAL:
                    buf.push_uint_var(len(self._peer_token))
                    buf.push_bytes(self._peer_token)
                # 0x4000 marks the length as a 2-byte variable-length integer.
                buf.push_uint16(length | 0x4000)
                buf.push_uint16(self._packet_number & 0xFFFF)
            else:
                buf.seek(self._packet_start)
                buf.push_uint8(
                    PACKET_FIXED_BIT
                    | (self._spin_bit << 5)
                    | (self._packet_crypto.key_phase << 2)
                    | (PACKET_NUMBER_SEND_SIZE - 1)
                )
                buf.push_bytes(self._peer_cid)
                buf.push_uint16(self._packet_number & 0xFFFF)

            # encrypt in place
            plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
            buf.seek(self._packet_start)
            buf.push_bytes(
                self._packet_crypto.encrypt_packet(
                    plain[0 : self._header_size],
                    plain[self._header_size : packet_size],
                    self._packet_number,
                )
            )
            self._packet.sent_bytes = buf.tell() - self._packet_start
            self._packets.append(self._packet)
            if self._packet.in_flight:
                self._datagram_flight_bytes += self._packet.sent_bytes

            # Short header packets cannot be coalesced, we need a new datagram.
            if self._packet_type == QuicPacketType.ONE_RTT:
                self._flush_current_datagram()

            self._packet_number += 1
        else:
            # "cancel" the packet
            buf.seek(self._packet_start)

        self._packet = None
        self.quic_logger_frames = None

    def _flush_current_datagram(self) -> None:
        """Finalize the current datagram, padding it if required."""
        datagram_bytes = self._buffer.tell()
        if datagram_bytes:
            # Padding for datagrams containing initial packets; see RFC 9000
            # section 14.1.
            if self._datagram_needs_padding:
                extra_bytes = self._flight_capacity - self._buffer.tell()
                if extra_bytes > 0:
                    self._buffer.push_bytes(bytes(extra_bytes))
                    self._datagram_flight_bytes += extra_bytes
                    datagram_bytes += extra_bytes

            self._datagrams.append(self._buffer.data)
            self._flight_bytes += self._datagram_flight_bytes
            self._total_bytes += datagram_bytes
            self._datagram_init = True
            self._buffer.seek(0)

View File

@ -0,0 +1,98 @@
from collections.abc import Sequence
from typing import Any, Iterable, List, Optional
class RangeSet(Sequence):
    """
    An ordered set of disjoint, non-adjacent half-open integer ranges.

    Used to track sets of packet numbers (received, acknowledged, ...).
    """

    def __init__(self, ranges: Iterable[range] = ()):
        # NOTE: the default is an immutable empty tuple (not a list) to avoid
        # the shared mutable-default pitfall; `ranges` is only iterated.
        self.__ranges: List[range] = []
        for r in ranges:
            assert r.step == 1
            self.add(r.start, r.stop)

    def add(self, start: int, stop: Optional[int] = None) -> None:
        """
        Add the half-open range [start, stop) to the set, merging with any
        ranges it touches. If `stop` is omitted, add the single value `start`.
        """
        if stop is None:
            stop = start + 1
        assert stop > start

        for i, r in enumerate(self.__ranges):
            # the added range is entirely before current item, insert here
            if stop < r.start:
                self.__ranges.insert(i, range(start, stop))
                return

            # the added range is entirely after current item, keep looking
            if start > r.stop:
                continue

            # the added range touches the current item, merge it
            start = min(start, r.start)
            stop = max(stop, r.stop)
            # absorb any subsequent ranges which are now covered
            while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
                stop = max(self.__ranges[i + 1].stop, stop)
                self.__ranges.pop(i + 1)
            self.__ranges[i] = range(start, stop)
            return

        # the added range is entirely after all existing items, append it
        self.__ranges.append(range(start, stop))

    def bounds(self) -> range:
        """Return the range spanning the lowest to the highest value."""
        return range(self.__ranges[0].start, self.__ranges[-1].stop)

    def shift(self) -> range:
        """Remove and return the lowest range."""
        return self.__ranges.pop(0)

    def subtract(self, start: int, stop: int) -> None:
        """
        Remove the half-open range [start, stop) from the set, splitting
        existing ranges where necessary.
        """
        assert stop > start
        i = 0
        while i < len(self.__ranges):
            r = self.__ranges[i]

            # the removed range is entirely before current item, stop here
            if stop <= r.start:
                return

            # the removed range is entirely after current item, keep looking
            if start >= r.stop:
                i += 1
                continue

            # the removed range completely covers the current item, remove it
            if start <= r.start and stop >= r.stop:
                self.__ranges.pop(i)
                continue

            # the removed range touches the current item
            if start > r.start:
                self.__ranges[i] = range(r.start, start)
                if stop < r.stop:
                    self.__ranges.insert(i + 1, range(stop, r.stop))
            else:
                self.__ranges[i] = range(stop, r.stop)

            i += 1

    def __bool__(self) -> bool:
        # Deliberately unsupported: use len() to test for emptiness.
        raise NotImplementedError

    def __contains__(self, val: Any) -> bool:
        for r in self.__ranges:
            if val in r:
                return True
        return False

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, RangeSet):
            return NotImplemented
        return self.__ranges == other.__ranges

    def __getitem__(self, key: Any) -> range:
        return self.__ranges[key]

    def __len__(self) -> int:
        return len(self.__ranges)

    def __repr__(self) -> str:
        return "RangeSet({})".format(repr(self.__ranges))

View File

@ -0,0 +1,389 @@
import logging
import math
from typing import Any, Callable, Dict, Iterable, List, Optional
from .congestion import cubic, reno # noqa
from .congestion.base import K_GRANULARITY, create_congestion_control
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
from .rangeset import RangeSet
# loss detection
K_PACKET_THRESHOLD = 3  # packet reordering threshold before declaring loss
K_TIME_THRESHOLD = 9 / 8  # time threshold, as a multiple of RTT
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
class QuicPacketSpace:
    """
    Per-epoch state for received and sent packets.
    """

    def __init__(self) -> None:
        # acknowledgement of received packets
        self.ack_at: Optional[float] = None  # deadline for sending an ACK
        self.ack_queue = RangeSet()  # packet numbers awaiting acknowledgement
        self.discarded = False
        self.expected_packet_number = 0
        self.largest_received_packet = -1
        self.largest_received_time: Optional[float] = None

        # sent packets and loss
        self.ack_eliciting_in_flight = 0
        self.largest_acked_packet = 0
        self.loss_time: Optional[float] = None
        self.sent_packets: Dict[int, QuicSentPacket] = {}
class QuicPacketPacer:
    """
    Leaky-bucket packet pacer.

    The bucket refills with elapsed wall-clock time up to ``bucket_max`` and
    each datagram sent drains it by ``packet_time``, so short bursts are
    allowed while sustained sending converges to the pacing rate derived
    from the congestion window and the smoothed RTT.
    """

    def __init__(self, *, max_datagram_size: int) -> None:
        self._max_datagram_size = max_datagram_size
        self.bucket_max: float = 0.0  # burst allowance, in seconds
        self.bucket_time: float = 0.0  # remaining allowance, in seconds
        self.evaluation_time: float = 0.0  # when the bucket was last refilled
        self.packet_time: Optional[float] = None  # pacing interval per datagram

    def next_send_time(self, now: float) -> Optional[float]:
        """
        Return the earliest time at which the next datagram may be sent,
        or `None` if it may be sent immediately.

        `None` is returned both before the first call to :meth:`update_rate`
        and whenever the bucket still holds some allowance.
        """
        if self.packet_time is not None:
            self.update_bucket(now=now)
            if self.bucket_time <= 0:
                return now + self.packet_time
        return None

    def update_after_send(self, now: float) -> None:
        """Drain the bucket by one packet's worth of pacing time."""
        if self.packet_time is not None:
            self.update_bucket(now=now)
            if self.bucket_time < self.packet_time:
                self.bucket_time = 0.0
            else:
                self.bucket_time -= self.packet_time

    def update_bucket(self, now: float) -> None:
        """Refill the bucket with the time elapsed since the last evaluation."""
        if now > self.evaluation_time:
            self.bucket_time = min(
                self.bucket_time + (now - self.evaluation_time), self.bucket_max
            )
            self.evaluation_time = now

    def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
        """
        Derive the pacing interval and burst allowance from the congestion
        window and the smoothed RTT.
        """
        pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
        # clamp the per-packet interval to [1 microsecond, 1 second]
        self.packet_time = max(
            K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND)
        )

        # allow bursts between 2 and 16 datagrams, bounded by a quarter
        # of the congestion window
        self.bucket_max = (
            max(
                2 * self._max_datagram_size,
                min(congestion_window // 4, 16 * self._max_datagram_size),
            )
            / pacing_rate
        )
        if self.bucket_time > self.bucket_max:
            self.bucket_time = self.bucket_max
class QuicPacketRecovery:
    """
    Packet loss and congestion controller.

    Tracks sent packets across packet number spaces, estimates the RTT from
    acknowledgements, declares packets lost using packet- and time-based
    thresholds, and feeds the congestion controller and pacer.
    """

    def __init__(
        self,
        *,
        congestion_control_algorithm: str,
        initial_rtt: float,
        max_datagram_size: int,
        peer_completed_address_validation: bool,
        send_probe: Callable[[], None],
        logger: Optional[logging.LoggerAdapter] = None,
        quic_logger: Optional[QuicLoggerTrace] = None,
    ) -> None:
        self.max_ack_delay = 0.025
        self.peer_completed_address_validation = peer_completed_address_validation
        self.spaces: List[QuicPacketSpace] = []

        # callbacks
        self._logger = logger
        self._quic_logger = quic_logger
        self._send_probe = send_probe

        # loss detection
        self._pto_count = 0
        self._rtt_initial = initial_rtt
        self._rtt_initialized = False
        self._rtt_latest = 0.0
        self._rtt_min = math.inf
        self._rtt_smoothed = 0.0
        self._rtt_variance = 0.0
        self._time_of_last_sent_ack_eliciting_packet = 0.0

        # congestion control
        self._cc = create_congestion_control(
            congestion_control_algorithm, max_datagram_size=max_datagram_size
        )
        self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size)

    @property
    def bytes_in_flight(self) -> int:
        # number of bytes currently tracked by the congestion controller
        return self._cc.bytes_in_flight

    @property
    def congestion_window(self) -> int:
        return self._cc.congestion_window

    def discard_space(self, space: QuicPacketSpace) -> None:
        """
        Drop all state for a packet number space and tell the congestion
        controller its in-flight packets have expired.
        """
        assert space in self.spaces

        self._cc.on_packets_expired(
            packets=filter(lambda x: x.in_flight, space.sent_packets.values())
        )
        space.sent_packets.clear()

        space.ack_at = None
        space.ack_eliciting_in_flight = 0
        space.loss_time = None

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated()

    def get_loss_detection_time(self) -> Optional[float]:
        """
        Return when the loss detection timer should fire, or `None` if it
        can be disarmed.
        """
        # loss timer
        loss_space = self._get_loss_space()
        if loss_space is not None:
            return loss_space.loss_time

        # packet timer
        if (
            not self.peer_completed_address_validation
            or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
        ):
            # exponential backoff: double the timeout for every PTO fired
            timeout = self.get_probe_timeout() * (2**self._pto_count)
            return self._time_of_last_sent_ack_eliciting_packet + timeout

        return None

    def get_probe_timeout(self) -> float:
        """Return the probe timeout (PTO) based on the current RTT estimate."""
        if not self._rtt_initialized:
            return 2 * self._rtt_initial
        return (
            self._rtt_smoothed
            + max(4 * self._rtt_variance, K_GRANULARITY)
            + self.max_ack_delay
        )

    def on_ack_received(
        self,
        *,
        ack_rangeset: RangeSet,
        ack_delay: float,
        now: float,
        space: QuicPacketSpace,
    ) -> None:
        """
        Update metrics as the result of an ACK being received.
        """
        is_ack_eliciting = False
        largest_acked = ack_rangeset.bounds().stop - 1
        largest_newly_acked = None
        largest_sent_time = None

        if largest_acked > space.largest_acked_packet:
            space.largest_acked_packet = largest_acked

        # remove every newly acknowledged packet, in ascending order
        for packet_number in sorted(space.sent_packets.keys()):
            if packet_number > largest_acked:
                break
            if packet_number in ack_rangeset:
                # remove packet and update counters
                packet = space.sent_packets.pop(packet_number)
                if packet.is_ack_eliciting:
                    is_ack_eliciting = True
                    space.ack_eliciting_in_flight -= 1
                if packet.in_flight:
                    self._cc.on_packet_acked(packet=packet, now=now)
                largest_newly_acked = packet_number
                largest_sent_time = packet.sent_time

                # trigger callbacks
                for handler, args in packet.delivery_handlers:
                    handler(QuicDeliveryState.ACKED, *args)

        # nothing to do if there are no newly acked packets
        if largest_newly_acked is None:
            return

        # only sample the RTT when the largest acked packet is newly acked
        # and at least one newly acked packet was ack-eliciting
        if largest_acked == largest_newly_acked and is_ack_eliciting:
            latest_rtt = now - largest_sent_time
            log_rtt = True

            # limit ACK delay to max_ack_delay
            ack_delay = min(ack_delay, self.max_ack_delay)

            # update RTT estimate, which cannot be < 1 ms
            self._rtt_latest = max(latest_rtt, 0.001)
            if self._rtt_latest < self._rtt_min:
                self._rtt_min = self._rtt_latest
            # only deduct the peer's ACK delay if it does not push the
            # sample below the minimum observed RTT
            if self._rtt_latest > self._rtt_min + ack_delay:
                self._rtt_latest -= ack_delay

            if not self._rtt_initialized:
                self._rtt_initialized = True
                self._rtt_variance = latest_rtt / 2
                self._rtt_smoothed = latest_rtt
            else:
                # exponentially-weighted moving averages
                self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
                    self._rtt_min - self._rtt_latest
                )
                self._rtt_smoothed = (
                    7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
                )

            # inform congestion controller
            self._cc.on_rtt_measurement(now=now, rtt=latest_rtt)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )

        else:
            log_rtt = False

        self._detect_loss(now=now, space=space)

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated(log_rtt=log_rtt)

    def on_loss_detection_timeout(self, *, now: float) -> None:
        """
        Handle the loss detection timer firing: either declare losses or
        send a probe after incrementing the PTO count.
        """
        loss_space = self._get_loss_space()
        if loss_space is not None:
            self._detect_loss(now=now, space=loss_space)
        else:
            self._pto_count += 1
            self.reschedule_data(now=now)

    def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
        """Record a sent packet and update in-flight accounting."""
        space.sent_packets[packet.packet_number] = packet

        if packet.is_ack_eliciting:
            space.ack_eliciting_in_flight += 1
        if packet.in_flight:
            if packet.is_ack_eliciting:
                self._time_of_last_sent_ack_eliciting_packet = packet.sent_time

            # add packet to bytes in flight
            self._cc.on_packet_sent(packet=packet)

            if self._quic_logger is not None:
                self._log_metrics_updated()

    def reschedule_data(self, *, now: float) -> None:
        """
        Schedule some data for retransmission.
        """
        # if there is any outstanding CRYPTO, retransmit it
        crypto_scheduled = False
        for space in self.spaces:
            packets = tuple(
                filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
            )
            if packets:
                self._on_packets_lost(now=now, packets=packets, space=space)
                crypto_scheduled = True
        if crypto_scheduled and self._logger is not None:
            self._logger.debug("Scheduled CRYPTO data for retransmission")

        # ensure an ACK-eliciting packet is sent
        self._send_probe()

    def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None:
        """
        Check whether any packets should be declared lost.
        """
        # a packet is lost once it is K_TIME_THRESHOLD RTTs old, or
        # K_PACKET_THRESHOLD packets behind the largest acknowledged one
        loss_delay = K_TIME_THRESHOLD * (
            max(self._rtt_latest, self._rtt_smoothed)
            if self._rtt_initialized
            else self._rtt_initial
        )
        packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
        time_threshold = now - loss_delay

        lost_packets = []
        space.loss_time = None
        for packet_number, packet in space.sent_packets.items():
            if packet_number > space.largest_acked_packet:
                break

            if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
                lost_packets.append(packet)
            else:
                # packet is not lost yet; remember when it would become lost
                packet_loss_time = packet.sent_time + loss_delay
                if space.loss_time is None or space.loss_time > packet_loss_time:
                    space.loss_time = packet_loss_time

        self._on_packets_lost(now=now, packets=lost_packets, space=space)

    def _get_loss_space(self) -> Optional[QuicPacketSpace]:
        """Return the packet space with the earliest loss time, if any."""
        loss_space = None
        for space in self.spaces:
            if space.loss_time is not None and (
                loss_space is None or space.loss_time < loss_space.loss_time
            ):
                loss_space = space
        return loss_space

    def _log_metrics_updated(self, log_rtt: bool = False) -> None:
        """Emit a qlog ``metrics_updated`` event with current CC / RTT data."""
        data: Dict[str, Any] = self._cc.get_log_data()

        if log_rtt:
            data.update(
                {
                    "latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
                    "min_rtt": self._quic_logger.encode_time(self._rtt_min),
                    "smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
                    "rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
                }
            )

        self._quic_logger.log_event(
            category="recovery", event="metrics_updated", data=data
        )

    def _on_packets_lost(
        self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace
    ) -> None:
        """
        Remove lost packets from the space, notify their delivery handlers
        and the congestion controller.
        """
        lost_packets_cc = []
        for packet in packets:
            del space.sent_packets[packet.packet_number]

            if packet.in_flight:
                lost_packets_cc.append(packet)

            if packet.is_ack_eliciting:
                space.ack_eliciting_in_flight -= 1

            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="recovery",
                    event="packet_lost",
                    data={
                        "type": self._quic_logger.packet_type(packet.packet_type),
                        "packet_number": packet.packet_number,
                    },
                )
                self._log_metrics_updated()

            # trigger callbacks
            for handler, args in packet.delivery_handlers:
                handler(QuicDeliveryState.LOST, *args)

        # inform congestion controller
        if lost_packets_cc:
            self._cc.on_packets_lost(now=now, packets=lost_packets_cc)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )
            if self._quic_logger is not None:
                self._log_metrics_updated()

View File

@ -0,0 +1,53 @@
import ipaddress
from typing import Tuple
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from ..buffer import Buffer
from ..tls import pull_opaque, push_opaque
from .connection import NetworkAddress
def encode_address(addr: NetworkAddress) -> bytes:
    """
    Serialize a network address as the packed IP address followed by the
    port as a 16-bit big-endian integer.
    """
    ip_bytes = ipaddress.ip_address(addr[0]).packed
    port_bytes = bytes([addr[1] >> 8, addr[1] & 0xFF])
    return ip_bytes + port_bytes
class QuicRetryTokenHandler:
    """
    Issues and validates the address-validation tokens carried by QUIC
    Retry packets.

    A token binds the client's address to the connection IDs and is
    encrypted with an ephemeral RSA key, so only this handler instance
    can decrypt it.
    """

    def __init__(self) -> None:
        # Ephemeral key: outstanding tokens become invalid if the handler
        # is recreated.
        self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    @staticmethod
    def _oaep():
        # OAEP/SHA-256 padding, shared by token encryption and decryption.
        return padding.OAEP(
            mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
        )

    def create_token(
        self,
        addr: NetworkAddress,
        original_destination_connection_id: bytes,
        retry_source_connection_id: bytes,
    ) -> bytes:
        """Build an encrypted token for the given address and connection IDs."""
        plaintext = Buffer(capacity=512)
        push_opaque(plaintext, 1, encode_address(addr))
        push_opaque(plaintext, 1, original_destination_connection_id)
        push_opaque(plaintext, 1, retry_source_connection_id)
        return self._key.public_key().encrypt(plaintext.data, self._oaep())

    def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]:
        """
        Decrypt *token* and verify it was issued for *addr*.

        Returns the original destination and retry source connection IDs;
        raises :class:`ValueError` when the address does not match.
        """
        plaintext = Buffer(data=self._key.decrypt(token, self._oaep()))
        encoded_addr = pull_opaque(plaintext, 1)
        original_destination_connection_id = pull_opaque(plaintext, 1)
        retry_source_connection_id = pull_opaque(plaintext, 1)
        if encoded_addr != encode_address(addr):
            raise ValueError("Remote address does not match.")
        return original_destination_connection_id, retry_source_connection_id

View File

@ -0,0 +1,364 @@
from typing import Optional
from . import events
from .packet import (
QuicErrorCode,
QuicResetStreamFrame,
QuicStopSendingFrame,
QuicStreamFrame,
)
from .packet_builder import QuicDeliveryState
from .rangeset import RangeSet
class FinalSizeError(Exception):
    """Raised when received stream data violates the stream's final size."""
class StreamFinishedError(Exception):
    """Raised for operations on a stream that has already finished."""
class QuicStreamReceiver:
    """
    The receive part of a QUIC stream.

    It finishes:
    - immediately for a send-only stream
    - upon reception of a STREAM_RESET frame
    - upon reception of a data frame with the FIN bit set
    """

    def __init__(self, stream_id: Optional[int], readable: bool) -> None:
        self.highest_offset = 0  # the highest offset ever seen
        # A send-only stream (readable=False) has nothing to receive, so its
        # receive part is finished from the start; otherwise QuicStream's
        # is_finished (receiver AND sender) could never become true. This
        # mirrors QuicStreamSender's `is_finished = not writable`.
        self.is_finished = not readable
        self.stop_pending = False

        self._buffer = bytearray()
        self._buffer_start = 0  # the offset for the start of the buffer
        self._final_size: Optional[int] = None
        self._ranges = RangeSet()  # received-but-unread byte ranges
        self._stream_id = stream_id
        self._stop_error_code: Optional[int] = None

    def get_stop_frame(self) -> QuicStopSendingFrame:
        """Build the pending STOP_SENDING frame and clear the pending flag."""
        self.stop_pending = False
        return QuicStopSendingFrame(
            error_code=self._stop_error_code,
            stream_id=self._stream_id,
        )

    def starting_offset(self) -> int:
        """Return the stream offset of the first byte still buffered."""
        return self._buffer_start

    def handle_frame(
        self, frame: QuicStreamFrame
    ) -> Optional[events.StreamDataReceived]:
        """
        Handle a frame of received data.

        Returns a StreamDataReceived event when new in-order data (or the
        end of the stream) becomes available, otherwise `None`.
        Raises FinalSizeError on final-size violations.
        """
        pos = frame.offset - self._buffer_start
        count = len(frame.data)
        frame_end = frame.offset + count

        # we should receive no more data beyond FIN!
        if self._final_size is not None:
            if frame_end > self._final_size:
                raise FinalSizeError("Data received beyond final size")
            elif frame.fin and frame_end != self._final_size:
                raise FinalSizeError("Cannot change final size")
        if frame.fin:
            self._final_size = frame_end
        if frame_end > self.highest_offset:
            self.highest_offset = frame_end

        # fast path: new in-order chunk, nothing buffered
        if pos == 0 and count and not self._buffer:
            self._buffer_start += count
            if frame.fin:
                # all data up to the FIN has been received, we're done receiving
                self.is_finished = True
            return events.StreamDataReceived(
                data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
            )

        # discard duplicate data
        if pos < 0:
            frame.data = frame.data[-pos:]
            frame.offset -= pos
            pos = 0
            count = len(frame.data)

        # marked received range
        if frame_end > frame.offset:
            self._ranges.add(frame.offset, frame_end)

        # add new data, padding the buffer if the frame arrives beyond its end
        gap = pos - len(self._buffer)
        if gap > 0:
            self._buffer += bytearray(gap)
        self._buffer[pos : pos + count] = frame.data

        # return data from the front of the buffer
        data = self._pull_data()
        end_stream = self._buffer_start == self._final_size
        if end_stream:
            # all data up to the FIN has been received, we're done receiving
            self.is_finished = True
        if data or end_stream:
            return events.StreamDataReceived(
                data=data, end_stream=end_stream, stream_id=self._stream_id
            )
        else:
            return None

    def handle_reset(
        self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
    ) -> Optional[events.StreamReset]:
        """
        Handle an abrupt termination of the receiving part of the QUIC stream.

        Raises FinalSizeError if the reset contradicts a known final size.
        """
        if self._final_size is not None and final_size != self._final_size:
            raise FinalSizeError("Cannot change final size")

        # we are done receiving
        self._final_size = final_size
        self.is_finished = True
        return events.StreamReset(error_code=error_code, stream_id=self._stream_id)

    def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a STOP_SENDING is ACK'd.
        """
        if delivery != QuicDeliveryState.ACKED:
            # the frame was lost, schedule a retransmission
            self.stop_pending = True

    def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
        """
        Request the peer stop sending data on the QUIC stream.
        """
        self._stop_error_code = error_code
        self.stop_pending = True

    def _pull_data(self) -> bytes:
        """
        Remove data from the front of the buffer.

        Returns contiguous bytes starting at `_buffer_start`, or b"" when
        the first received range does not begin there.
        """
        try:
            has_data_to_read = self._ranges[0].start == self._buffer_start
        except IndexError:
            has_data_to_read = False
        if not has_data_to_read:
            return b""

        r = self._ranges.shift()
        pos = r.stop - r.start
        data = bytes(self._buffer[:pos])
        del self._buffer[:pos]
        self._buffer_start = r.stop
        return data
class QuicStreamSender:
    """
    The send part of a QUIC stream.

    It finishes:
    - immediately for a receive-only stream
    - upon acknowledgement of a STREAM_RESET frame
    - upon acknowledgement of a data frame with the FIN bit set
    """

    def __init__(self, stream_id: Optional[int], writable: bool) -> None:
        self.buffer_is_empty = True
        self.highest_offset = 0  # the highest offset ever sent
        # a receive-only stream has nothing to send, it is finished at once
        self.is_finished = not writable
        self.reset_pending = False

        self._acked = RangeSet()  # byte ranges acknowledged by the peer
        self._acked_fin = False
        self._buffer = bytearray()
        self._buffer_fin: Optional[int] = None  # offset of the FIN, once written
        self._buffer_start = 0  # the offset for the start of the buffer
        self._buffer_stop = 0  # the offset for the stop of the buffer
        self._pending = RangeSet()  # byte ranges awaiting (re)transmission
        self._pending_eof = False
        self._reset_error_code: Optional[int] = None
        self._stream_id = stream_id

    @property
    def next_offset(self) -> int:
        """
        The offset for the next frame to send.

        This is used to determine the space needed for the frame's `offset` field.
        """
        try:
            return self._pending[0].start
        except IndexError:
            return self._buffer_stop

    def get_frame(
        self, max_size: int, max_offset: Optional[int] = None
    ) -> Optional[QuicStreamFrame]:
        """
        Get a frame of data to send.

        `max_size` caps the payload length, `max_offset` applies flow
        control; returns `None` when nothing can be sent right now.
        """
        assert self._reset_error_code is None, "cannot call get_frame() after reset()"

        # get the first pending data range
        try:
            r = self._pending[0]
        except IndexError:
            if self._pending_eof:
                # FIN only
                self._pending_eof = False
                return QuicStreamFrame(fin=True, offset=self._buffer_fin)

            self.buffer_is_empty = True
            return None

        # apply flow control
        start = r.start
        stop = min(r.stop, start + max_size)
        if max_offset is not None and stop > max_offset:
            stop = max_offset
        if stop <= start:
            # flow control leaves no room to send anything
            return None

        # create frame
        frame = QuicStreamFrame(
            data=bytes(
                self._buffer[start - self._buffer_start : stop - self._buffer_start]
            ),
            offset=start,
        )
        self._pending.subtract(start, stop)

        # track the highest offset ever sent
        if stop > self.highest_offset:
            self.highest_offset = stop

        # if the buffer is empty and EOF was written, set the FIN bit
        if self._buffer_fin == stop:
            frame.fin = True
            self._pending_eof = False

        return frame

    def get_reset_frame(self) -> QuicResetStreamFrame:
        """Build the pending RESET_STREAM frame and clear the pending flag."""
        self.reset_pending = False
        return QuicResetStreamFrame(
            error_code=self._reset_error_code,
            final_size=self.highest_offset,
            stream_id=self._stream_id,
        )

    def on_data_delivery(
        self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
    ) -> None:
        """
        Callback when sent data is ACK'd.
        """
        # If the frame had the FIN bit set, its end MUST match otherwise
        # we have a programming error.
        assert (
            not fin or stop == self._buffer_fin
        ), "on_data_delivered() was called with inconsistent fin / stop"

        # If a reset has been requested, stop processing data delivery.
        # The transition to the finished state only depends on the reset
        # being acknowledged.
        if self._reset_error_code is not None:
            return

        if delivery == QuicDeliveryState.ACKED:
            if stop > start:
                # Some data has been ACK'd, discard it.
                self._acked.add(start, stop)
                first_range = self._acked[0]
                # only the contiguous prefix of ACK'd data can be released
                if first_range.start == self._buffer_start:
                    size = first_range.stop - first_range.start
                    self._acked.shift()
                    self._buffer_start += size
                    del self._buffer[:size]

            if fin:
                # The FIN has been ACK'd.
                self._acked_fin = True

            if self._buffer_start == self._buffer_fin and self._acked_fin:
                # All data and the FIN have been ACK'd, we're done sending.
                self.is_finished = True
        else:
            if stop > start:
                # Some data has been lost, reschedule it.
                self.buffer_is_empty = False
                self._pending.add(start, stop)

            if fin:
                # The FIN has been lost, reschedule it.
                self.buffer_is_empty = False
                self._pending_eof = True

    def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a reset is ACK'd.
        """
        if delivery == QuicDeliveryState.ACKED:
            # The reset has been ACK'd, we're done sending.
            self.is_finished = True
        else:
            # The reset has been lost, reschedule it.
            self.reset_pending = True

    def reset(self, error_code: int) -> None:
        """
        Abruptly terminate the sending part of the QUIC stream.
        """
        assert self._reset_error_code is None, "cannot call reset() more than once"
        self._reset_error_code = error_code
        self.reset_pending = True

        # Prevent any more data from being sent or re-sent.
        self.buffer_is_empty = True

    def write(self, data: bytes, end_stream: bool = False) -> None:
        """
        Write some data bytes to the QUIC stream.
        """
        assert self._buffer_fin is None, "cannot call write() after FIN"
        assert self._reset_error_code is None, "cannot call write() after reset()"

        size = len(data)
        if size:
            self.buffer_is_empty = False
            self._pending.add(self._buffer_stop, self._buffer_stop + size)
            self._buffer += data
            self._buffer_stop += size
        if end_stream:
            self.buffer_is_empty = False
            self._buffer_fin = self._buffer_stop
            self._pending_eof = True
class QuicStream:
    """
    A QUIC stream, combining a receive part and a send part.
    """

    def __init__(
        self,
        stream_id: Optional[int] = None,
        max_stream_data_local: int = 0,
        max_stream_data_remote: int = 0,
        readable: bool = True,
        writable: bool = True,
    ) -> None:
        # whether sending is currently blocked on flow control
        self.is_blocked = False
        # local flow-control limit and the highest value of it communicated
        # to the peer so far (presumably bytes the peer may send us — confirm
        # against the connection's flow-control handling)
        self.max_stream_data_local = max_stream_data_local
        self.max_stream_data_local_sent = max_stream_data_local
        # flow-control limit advertised by the peer
        self.max_stream_data_remote = max_stream_data_remote
        self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
        self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
        self.stream_id = stream_id

    @property
    def is_finished(self) -> bool:
        # the stream is done once both directions have finished
        return self.receiver.is_finished and self.sender.is_finished