Code
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,163 @@
|
||||
from dataclasses import dataclass, field
|
||||
from os import PathLike
|
||||
from re import split
|
||||
from typing import Any, List, Optional, TextIO, Union
|
||||
|
||||
from ..tls import (
|
||||
CipherSuite,
|
||||
SessionTicket,
|
||||
load_pem_private_key,
|
||||
load_pem_x509_certificates,
|
||||
)
|
||||
from .logger import QuicLogger
|
||||
from .packet import QuicProtocolVersion
|
||||
|
||||
SMALLEST_MAX_DATAGRAM_SIZE = 1200
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicConfiguration:
|
||||
"""
|
||||
A QUIC configuration.
|
||||
"""
|
||||
|
||||
alpn_protocols: Optional[List[str]] = None
|
||||
"""
|
||||
A list of supported ALPN protocols.
|
||||
"""
|
||||
|
||||
congestion_control_algorithm: str = "reno"
|
||||
"""
|
||||
The name of the congestion control algorithm to use.
|
||||
|
||||
Currently supported algorithms: `"reno", `"cubic"`.
|
||||
"""
|
||||
|
||||
connection_id_length: int = 8
|
||||
"""
|
||||
The length in bytes of local connection IDs.
|
||||
"""
|
||||
|
||||
idle_timeout: float = 60.0
|
||||
"""
|
||||
The idle timeout in seconds.
|
||||
|
||||
The connection is terminated if nothing is received for the given duration.
|
||||
"""
|
||||
|
||||
is_client: bool = True
|
||||
"""
|
||||
Whether this is the client side of the QUIC connection.
|
||||
"""
|
||||
|
||||
max_data: int = 1048576
|
||||
"""
|
||||
Connection-wide flow control limit.
|
||||
"""
|
||||
|
||||
max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE
|
||||
"""
|
||||
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
|
||||
"""
|
||||
|
||||
max_stream_data: int = 1048576
|
||||
"""
|
||||
Per-stream flow control limit.
|
||||
"""
|
||||
|
||||
quic_logger: Optional[QuicLogger] = None
|
||||
"""
|
||||
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
|
||||
"""
|
||||
|
||||
secrets_log_file: TextIO = None
|
||||
"""
|
||||
A file-like object in which to log traffic secrets.
|
||||
|
||||
This is useful to analyze traffic captures with Wireshark.
|
||||
"""
|
||||
|
||||
server_name: Optional[str] = None
|
||||
"""
|
||||
The server name to use when verifying the server's TLS certificate, which
|
||||
can either be a DNS name or an IP address.
|
||||
|
||||
If it is a DNS name, it is also sent during the TLS handshake in the
|
||||
Server Name Indication (SNI) extension.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
session_ticket: Optional[SessionTicket] = None
|
||||
"""
|
||||
The TLS session ticket which should be used for session resumption.
|
||||
"""
|
||||
|
||||
token: bytes = b""
|
||||
"""
|
||||
The address validation token that can be used to validate future connections.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
# For internal purposes, not guaranteed to be stable.
|
||||
cadata: Optional[bytes] = None
|
||||
cafile: Optional[str] = None
|
||||
capath: Optional[str] = None
|
||||
certificate: Any = None
|
||||
certificate_chain: List[Any] = field(default_factory=list)
|
||||
cipher_suites: Optional[List[CipherSuite]] = None
|
||||
initial_rtt: float = 0.1
|
||||
max_datagram_frame_size: Optional[int] = None
|
||||
original_version: Optional[int] = None
|
||||
private_key: Any = None
|
||||
quantum_readiness_test: bool = False
|
||||
supported_versions: List[int] = field(
|
||||
default_factory=lambda: [
|
||||
QuicProtocolVersion.VERSION_1,
|
||||
QuicProtocolVersion.VERSION_2,
|
||||
]
|
||||
)
|
||||
verify_mode: Optional[int] = None
|
||||
|
||||
def load_cert_chain(
|
||||
self,
|
||||
certfile: PathLike,
|
||||
keyfile: Optional[PathLike] = None,
|
||||
password: Optional[Union[bytes, str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load a private key and the corresponding certificate.
|
||||
"""
|
||||
with open(certfile, "rb") as fp:
|
||||
boundary = b"-----BEGIN PRIVATE KEY-----\n"
|
||||
chunks = split(b"\n" + boundary, fp.read())
|
||||
certificates = load_pem_x509_certificates(chunks[0])
|
||||
if len(chunks) == 2:
|
||||
private_key = boundary + chunks[1]
|
||||
self.private_key = load_pem_private_key(private_key)
|
||||
self.certificate = certificates[0]
|
||||
self.certificate_chain = certificates[1:]
|
||||
|
||||
if keyfile is not None:
|
||||
with open(keyfile, "rb") as fp:
|
||||
self.private_key = load_pem_private_key(
|
||||
fp.read(),
|
||||
password=password.encode("utf8")
|
||||
if isinstance(password, str)
|
||||
else password,
|
||||
)
|
||||
|
||||
def load_verify_locations(
|
||||
self,
|
||||
cafile: Optional[str] = None,
|
||||
capath: Optional[str] = None,
|
||||
cadata: Optional[bytes] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load a set of "certification authority" (CA) certificates used to
|
||||
validate other peers' certificates.
|
||||
"""
|
||||
self.cafile = cafile
|
||||
self.capath = capath
|
||||
self.cadata = cadata
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,128 @@
|
||||
import abc
|
||||
from typing import Any, Dict, Iterable, Optional, Protocol
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
|
||||
K_GRANULARITY = 0.001 # seconds
|
||||
K_INITIAL_WINDOW = 10
|
||||
K_MINIMUM_WINDOW = 2
|
||||
|
||||
|
||||
class QuicCongestionControl(abc.ABC):
|
||||
"""
|
||||
Base class for congestion control implementations.
|
||||
"""
|
||||
|
||||
bytes_in_flight: int = 0
|
||||
congestion_window: int = 0
|
||||
ssthresh: Optional[int] = None
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_lost(
|
||||
self, *, now: float, packets: Iterable[QuicSentPacket]
|
||||
) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ...
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
|
||||
if self.ssthresh is not None:
|
||||
data["ssthresh"] = self.ssthresh
|
||||
return data
|
||||
|
||||
|
||||
class QuicCongestionControlFactory(Protocol):
|
||||
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
|
||||
|
||||
|
||||
class QuicRttMonitor:
|
||||
"""
|
||||
Roundtrip time monitor for HyStart.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._increases = 0
|
||||
self._last_time = None
|
||||
self._ready = False
|
||||
self._size = 5
|
||||
|
||||
self._filtered_min: Optional[float] = None
|
||||
|
||||
self._sample_idx = 0
|
||||
self._sample_max: Optional[float] = None
|
||||
self._sample_min: Optional[float] = None
|
||||
self._sample_time = 0.0
|
||||
self._samples = [0.0 for i in range(self._size)]
|
||||
|
||||
def add_rtt(self, *, rtt: float) -> None:
|
||||
self._samples[self._sample_idx] = rtt
|
||||
self._sample_idx += 1
|
||||
|
||||
if self._sample_idx >= self._size:
|
||||
self._sample_idx = 0
|
||||
self._ready = True
|
||||
|
||||
if self._ready:
|
||||
self._sample_max = self._samples[0]
|
||||
self._sample_min = self._samples[0]
|
||||
for sample in self._samples[1:]:
|
||||
if sample < self._sample_min:
|
||||
self._sample_min = sample
|
||||
elif sample > self._sample_max:
|
||||
self._sample_max = sample
|
||||
|
||||
def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
|
||||
if now > self._sample_time + K_GRANULARITY:
|
||||
self.add_rtt(rtt=rtt)
|
||||
self._sample_time = now
|
||||
|
||||
if self._ready:
|
||||
if self._filtered_min is None or self._filtered_min > self._sample_max:
|
||||
self._filtered_min = self._sample_max
|
||||
|
||||
delta = self._sample_min - self._filtered_min
|
||||
if delta * 4 >= self._filtered_min:
|
||||
self._increases += 1
|
||||
if self._increases >= self._size:
|
||||
return True
|
||||
elif delta > 0:
|
||||
self._increases = 0
|
||||
return False
|
||||
|
||||
|
||||
_factories: Dict[str, QuicCongestionControlFactory] = {}
|
||||
|
||||
|
||||
def create_congestion_control(
|
||||
name: str, *, max_datagram_size: int
|
||||
) -> QuicCongestionControl:
|
||||
"""
|
||||
Create an instance of the `name` congestion control algorithm.
|
||||
"""
|
||||
try:
|
||||
factory = _factories[name]
|
||||
except KeyError:
|
||||
raise Exception(f"Unknown congestion control algorithm: {name}")
|
||||
return factory(max_datagram_size=max_datagram_size)
|
||||
|
||||
|
||||
def register_congestion_control(
|
||||
name: str, factory: QuicCongestionControlFactory
|
||||
) -> None:
|
||||
"""
|
||||
Register a congestion control algorithm named `name`.
|
||||
"""
|
||||
_factories[name] = factory
|
||||
@ -0,0 +1,212 @@
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_INITIAL_WINDOW,
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
|
||||
K_CUBIC_C = 0.4
|
||||
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
|
||||
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
|
||||
|
||||
|
||||
def better_cube_root(x: float) -> float:
|
||||
if x < 0:
|
||||
# avoid precision errors that make the cube root returns an imaginary number
|
||||
return -((-x) ** (1.0 / 3.0))
|
||||
else:
|
||||
return (x) ** (1.0 / 3.0)
|
||||
|
||||
|
||||
class CubicCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
Cubic congestion control implementation for aioquic
|
||||
"""
|
||||
|
||||
def __init__(self, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
# increase by one segment
|
||||
self.additive_increase_factor: int = max_datagram_size
|
||||
self._max_datagram_size: int = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
self.rtt = 0.02 # starting RTT is considered to be 20ms
|
||||
|
||||
self.reset()
|
||||
|
||||
self.last_ack = 0.0
|
||||
|
||||
def W_cubic(self, t) -> int:
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
|
||||
return int(target_segments * self._max_datagram_size)
|
||||
|
||||
def is_reno_friendly(self, t) -> bool:
|
||||
return self.W_cubic(t) < self._W_est
|
||||
|
||||
def is_concave(self) -> bool:
|
||||
return self.congestion_window < self._W_max
|
||||
|
||||
def reset(self) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
|
||||
self.ssthresh = None
|
||||
|
||||
self._first_slow_start = True
|
||||
self._starting_congestion_avoidance = False
|
||||
self.K: float = 0.0
|
||||
self._W_est = 0
|
||||
self._cwnd_epoch = 0
|
||||
self._t_epoch = 0.0
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
self.last_ack = packet.sent_time
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
if self._first_slow_start and not self._starting_congestion_avoidance:
|
||||
# exiting slow start without having a loss
|
||||
self._first_slow_start = False
|
||||
self._W_max = self.congestion_window
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
# initialize the variables used at start of congestion avoidance
|
||||
if self._starting_congestion_avoidance:
|
||||
self._starting_congestion_avoidance = False
|
||||
self._first_slow_start = False
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
self._W_est = int(
|
||||
self._W_est
|
||||
+ self.additive_increase_factor
|
||||
* (packet.sent_bytes / self.congestion_window)
|
||||
)
|
||||
|
||||
t = now - self._t_epoch
|
||||
|
||||
target: int = 0
|
||||
W_cubic = self.W_cubic(t + self.rtt)
|
||||
if W_cubic < self.congestion_window:
|
||||
target = self.congestion_window
|
||||
elif W_cubic > 1.5 * self.congestion_window:
|
||||
target = int(self.congestion_window * 1.5)
|
||||
else:
|
||||
target = W_cubic
|
||||
|
||||
if self.is_reno_friendly(t):
|
||||
# reno friendly region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
|
||||
self.congestion_window = self._W_est
|
||||
elif self.is_concave():
|
||||
# concave region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# convex region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
if self.last_ack == 0.0:
|
||||
return
|
||||
elapsed_idle = packet.sent_time - self.last_ack
|
||||
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
|
||||
self.reset()
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
|
||||
# Normal congestion handle, can't be used in same time as fast convergence
|
||||
# self._W_max = self.congestion_window
|
||||
|
||||
# fast convergence
|
||||
if self._W_max is not None and self.congestion_window < self._W_max:
|
||||
self._W_max = int(
|
||||
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
|
||||
)
|
||||
else:
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
# normal congestion MD
|
||||
flight_size = self.bytes_in_flight
|
||||
new_ssthresh = max(
|
||||
int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = new_ssthresh
|
||||
self.congestion_window = max(
|
||||
self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
|
||||
)
|
||||
|
||||
# restart a new congestion avoidance phase
|
||||
self._starting_congestion_avoidance = True
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
self.rtt = rtt
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
rtt=rtt, now=now
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = super().get_log_data()
|
||||
|
||||
data["cubic-wmax"] = int(self._W_max)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
register_congestion_control("cubic", CubicCongestionControl)
|
||||
@ -0,0 +1,77 @@
|
||||
from typing import Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
K_LOSS_REDUCTION_FACTOR = 0.5
|
||||
|
||||
|
||||
class RenoCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
New Reno congestion control.
|
||||
"""
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
self._max_datagram_size = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
self._congestion_stash = 0
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
# don't increase window in congestion recovery
|
||||
if packet.sent_time <= self._congestion_recovery_start_time:
|
||||
return
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
self._congestion_stash += packet.sent_bytes
|
||||
count = self._congestion_stash // self.congestion_window
|
||||
if count:
|
||||
self._congestion_stash -= count * self.congestion_window
|
||||
self.congestion_window += count * self._max_datagram_size
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
self.congestion_window = max(
|
||||
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
# TODO : collapse congestion window if persistent congestion
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
now=now, rtt=rtt
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
|
||||
register_congestion_control("reno", RenoCongestionControl)
|
||||
3623
Code/venv/lib/python3.13/site-packages/aioquic/quic/connection.py
Normal file
3623
Code/venv/lib/python3.13/site-packages/aioquic/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
246
Code/venv/lib/python3.13/site-packages/aioquic/quic/crypto.py
Normal file
246
Code/venv/lib/python3.13/site-packages/aioquic/quic/crypto.py
Normal file
@ -0,0 +1,246 @@
|
||||
import binascii
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from .._crypto import AEAD, CryptoError, HeaderProtection
|
||||
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
|
||||
from .packet import (
|
||||
QuicProtocolVersion,
|
||||
decode_packet_number,
|
||||
is_long_header,
|
||||
)
|
||||
|
||||
CIPHER_SUITES = {
|
||||
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
|
||||
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
|
||||
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
|
||||
}
|
||||
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
|
||||
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
|
||||
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
|
||||
SAMPLE_SIZE = 16
|
||||
|
||||
|
||||
Callback = Callable[[str], None]
|
||||
|
||||
|
||||
def NoCallback(trigger: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class KeyUnavailableError(CryptoError):
|
||||
pass
|
||||
|
||||
|
||||
def derive_key_iv_hp(
|
||||
*, cipher_suite: CipherSuite, secret: bytes, version: int
|
||||
) -> Tuple[bytes, bytes, bytes]:
|
||||
algorithm = cipher_suite_hash(cipher_suite)
|
||||
if cipher_suite in [
|
||||
CipherSuite.AES_256_GCM_SHA384,
|
||||
CipherSuite.CHACHA20_POLY1305_SHA256,
|
||||
]:
|
||||
key_size = 32
|
||||
else:
|
||||
key_size = 16
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
return (
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size),
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12),
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
|
||||
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
|
||||
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
|
||||
)
|
||||
|
||||
|
||||
class CryptoContext:
|
||||
def __init__(
|
||||
self,
|
||||
key_phase: int = 0,
|
||||
setup_cb: Callback = NoCallback,
|
||||
teardown_cb: Callback = NoCallback,
|
||||
) -> None:
|
||||
self.aead: Optional[AEAD] = None
|
||||
self.cipher_suite: Optional[CipherSuite] = None
|
||||
self.hp: Optional[HeaderProtection] = None
|
||||
self.key_phase = key_phase
|
||||
self.secret: Optional[bytes] = None
|
||||
self.version: Optional[int] = None
|
||||
self._setup_cb = setup_cb
|
||||
self._teardown_cb = teardown_cb
|
||||
|
||||
def decrypt_packet(
|
||||
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
|
||||
) -> Tuple[bytes, bytes, int, bool]:
|
||||
if self.aead is None:
|
||||
raise KeyUnavailableError("Decryption key is not available")
|
||||
|
||||
# header protection
|
||||
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
|
||||
first_byte = plain_header[0]
|
||||
|
||||
# packet number
|
||||
pn_length = (first_byte & 0x03) + 1
|
||||
packet_number = decode_packet_number(
|
||||
packet_number, pn_length * 8, expected_packet_number
|
||||
)
|
||||
|
||||
# detect key phase change
|
||||
crypto = self
|
||||
if not is_long_header(first_byte):
|
||||
key_phase = (first_byte & 4) >> 2
|
||||
if key_phase != self.key_phase:
|
||||
crypto = next_key_phase(self)
|
||||
|
||||
# payload protection
|
||||
payload = crypto.aead.decrypt(
|
||||
packet[len(plain_header) :], plain_header, packet_number
|
||||
)
|
||||
|
||||
return plain_header, payload, packet_number, crypto != self
|
||||
|
||||
def encrypt_packet(
|
||||
self, plain_header: bytes, plain_payload: bytes, packet_number: int
|
||||
) -> bytes:
|
||||
assert self.is_valid(), "Encryption key is not available"
|
||||
|
||||
# payload protection
|
||||
protected_payload = self.aead.encrypt(
|
||||
plain_payload, plain_header, packet_number
|
||||
)
|
||||
|
||||
# header protection
|
||||
return self.hp.apply(plain_header, protected_payload)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self.aead is not None
|
||||
|
||||
def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
|
||||
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
|
||||
|
||||
key, iv, hp = derive_key_iv_hp(
|
||||
cipher_suite=cipher_suite,
|
||||
secret=secret,
|
||||
version=version,
|
||||
)
|
||||
self.aead = AEAD(aead_cipher_name, key, iv)
|
||||
self.cipher_suite = cipher_suite
|
||||
self.hp = HeaderProtection(hp_cipher_name, hp)
|
||||
self.secret = secret
|
||||
self.version = version
|
||||
|
||||
# trigger callback
|
||||
self._setup_cb("tls")
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.aead = None
|
||||
self.cipher_suite = None
|
||||
self.hp = None
|
||||
self.secret = None
|
||||
|
||||
# trigger callback
|
||||
self._teardown_cb("tls")
|
||||
|
||||
|
||||
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
|
||||
self.aead = crypto.aead
|
||||
self.key_phase = crypto.key_phase
|
||||
self.secret = crypto.secret
|
||||
|
||||
# trigger callback
|
||||
self._setup_cb(trigger)
|
||||
|
||||
|
||||
def next_key_phase(self: CryptoContext) -> CryptoContext:
|
||||
algorithm = cipher_suite_hash(self.cipher_suite)
|
||||
|
||||
crypto = CryptoContext(key_phase=int(not self.key_phase))
|
||||
crypto.setup(
|
||||
cipher_suite=self.cipher_suite,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
|
||||
),
|
||||
version=self.version,
|
||||
)
|
||||
return crypto
|
||||
|
||||
|
||||
class CryptoPair:
|
||||
def __init__(
|
||||
self,
|
||||
recv_setup_cb: Callback = NoCallback,
|
||||
recv_teardown_cb: Callback = NoCallback,
|
||||
send_setup_cb: Callback = NoCallback,
|
||||
send_teardown_cb: Callback = NoCallback,
|
||||
) -> None:
|
||||
self.aead_tag_size = 16
|
||||
self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
|
||||
self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
|
||||
self._update_key_requested = False
|
||||
|
||||
def decrypt_packet(
|
||||
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
|
||||
) -> Tuple[bytes, bytes, int]:
|
||||
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
|
||||
packet, encrypted_offset, expected_packet_number
|
||||
)
|
||||
if update_key:
|
||||
self._update_key("remote_update")
|
||||
return plain_header, payload, packet_number
|
||||
|
||||
def encrypt_packet(
|
||||
self, plain_header: bytes, plain_payload: bytes, packet_number: int
|
||||
) -> bytes:
|
||||
if self._update_key_requested:
|
||||
self._update_key("local_update")
|
||||
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
|
||||
|
||||
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
|
||||
if is_client:
|
||||
recv_label, send_label = b"server in", b"client in"
|
||||
else:
|
||||
recv_label, send_label = b"client in", b"server in"
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
initial_salt = INITIAL_SALT_VERSION_2
|
||||
else:
|
||||
initial_salt = INITIAL_SALT_VERSION_1
|
||||
|
||||
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
|
||||
initial_secret = hkdf_extract(algorithm, initial_salt, cid)
|
||||
self.recv.setup(
|
||||
cipher_suite=INITIAL_CIPHER_SUITE,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, initial_secret, recv_label, b"", algorithm.digest_size
|
||||
),
|
||||
version=version,
|
||||
)
|
||||
self.send.setup(
|
||||
cipher_suite=INITIAL_CIPHER_SUITE,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, initial_secret, send_label, b"", algorithm.digest_size
|
||||
),
|
||||
version=version,
|
||||
)
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.recv.teardown()
|
||||
self.send.teardown()
|
||||
|
||||
def update_key(self) -> None:
|
||||
self._update_key_requested = True
|
||||
|
||||
@property
|
||||
def key_phase(self) -> int:
|
||||
if self._update_key_requested:
|
||||
return int(not self.recv.key_phase)
|
||||
else:
|
||||
return self.recv.key_phase
|
||||
|
||||
def _update_key(self, trigger: str) -> None:
|
||||
apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
|
||||
apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
|
||||
self._update_key_requested = False
|
||||
126
Code/venv/lib/python3.13/site-packages/aioquic/quic/events.py
Normal file
126
Code/venv/lib/python3.13/site-packages/aioquic/quic/events.py
Normal file
@ -0,0 +1,126 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuicEvent:
|
||||
"""
|
||||
Base class for QUIC events.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionIdIssued(QuicEvent):
|
||||
connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionIdRetired(QuicEvent):
|
||||
connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionTerminated(QuicEvent):
|
||||
"""
|
||||
The ConnectionTerminated event is fired when the QUIC connection is terminated.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code which was specified when closing the connection."
|
||||
|
||||
frame_type: Optional[int]
|
||||
"The frame type which caused the connection to be closed, or `None`."
|
||||
|
||||
reason_phrase: str
|
||||
"The human-readable reason for which the connection was closed."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatagramFrameReceived(QuicEvent):
|
||||
"""
|
||||
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandshakeCompleted(QuicEvent):
|
||||
"""
|
||||
The HandshakeCompleted event is fired when the TLS handshake completes.
|
||||
"""
|
||||
|
||||
alpn_protocol: Optional[str]
|
||||
"The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
early_data_accepted: bool
|
||||
"Whether early (0-RTT) data was accepted by the remote peer."
|
||||
|
||||
session_resumed: bool
|
||||
"Whether a TLS session was resumed."
|
||||
|
||||
|
||||
@dataclass
|
||||
class PingAcknowledged(QuicEvent):
|
||||
"""
|
||||
The PingAcknowledged event is fired when a PING frame is acknowledged.
|
||||
"""
|
||||
|
||||
uid: int
|
||||
"The unique ID of the PING."
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtocolNegotiated(QuicEvent):
|
||||
"""
|
||||
The ProtocolNegotiated event is fired when ALPN negotiation completes.
|
||||
"""
|
||||
|
||||
alpn_protocol: Optional[str]
|
||||
"The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopSendingReceived(QuicEvent):
|
||||
"""
|
||||
The StopSendingReceived event is fired when the remote peer requests
|
||||
stopping data transmission on a stream.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code that was sent from the peer."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that the peer requested stopping data transmission."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamDataReceived(QuicEvent):
|
||||
"""
|
||||
The StreamDataReceived event is fired whenever data is received on a
|
||||
stream.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
end_stream: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamReset(QuicEvent):
|
||||
"""
|
||||
The StreamReset event is fired when the remote peer resets a stream.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code that triggered the reset."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that was reset."
|
||||
329
Code/venv/lib/python3.13/site-packages/aioquic/quic/logger.py
Normal file
329
Code/venv/lib/python3.13/site-packages/aioquic/quic/logger.py
Normal file
@ -0,0 +1,329 @@
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
from ..h3.events import Headers
|
||||
from .packet import (
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
QuicStreamFrame,
|
||||
QuicTransportParameters,
|
||||
)
|
||||
from .rangeset import RangeSet
|
||||
|
||||
PACKET_TYPE_NAMES = {
|
||||
QuicPacketType.INITIAL: "initial",
|
||||
QuicPacketType.HANDSHAKE: "handshake",
|
||||
QuicPacketType.ZERO_RTT: "0RTT",
|
||||
QuicPacketType.ONE_RTT: "1RTT",
|
||||
QuicPacketType.RETRY: "retry",
|
||||
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
|
||||
}
|
||||
QLOG_VERSION = "0.3"
|
||||
|
||||
|
||||
def hexdump(data: bytes) -> str:
|
||||
return binascii.hexlify(data).decode("ascii")
|
||||
|
||||
|
||||
class QuicLoggerTrace:
|
||||
"""
|
||||
A QUIC event trace.
|
||||
|
||||
Events are logged in the format defined by qlog.
|
||||
|
||||
See:
|
||||
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
|
||||
"""
|
||||
|
||||
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
|
||||
self._odcid = odcid
|
||||
self._events: Deque[Dict[str, Any]] = deque()
|
||||
self._vantage_point = {
|
||||
"name": "aioquic",
|
||||
"type": "client" if is_client else "server",
|
||||
}
|
||||
|
||||
# QUIC
|
||||
|
||||
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
|
||||
return {
|
||||
"ack_delay": self.encode_time(delay),
|
||||
"acked_ranges": [[x.start, x.stop - 1] for x in ranges],
|
||||
"frame_type": "ack",
|
||||
}
|
||||
|
||||
def encode_connection_close_frame(
|
||||
self, error_code: int, frame_type: Optional[int], reason_phrase: str
|
||||
) -> Dict:
|
||||
attrs = {
|
||||
"error_code": error_code,
|
||||
"error_space": "application" if frame_type is None else "transport",
|
||||
"frame_type": "connection_close",
|
||||
"raw_error_code": error_code,
|
||||
"reason": reason_phrase,
|
||||
}
|
||||
if frame_type is not None:
|
||||
attrs["trigger_frame_type"] = frame_type
|
||||
|
||||
return attrs
|
||||
|
||||
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict:
|
||||
if frame_type == QuicFrameType.MAX_DATA:
|
||||
return {"frame_type": "max_data", "maximum": maximum}
|
||||
else:
|
||||
return {
|
||||
"frame_type": "max_streams",
|
||||
"maximum": maximum,
|
||||
"stream_type": "unidirectional"
|
||||
if frame_type == QuicFrameType.MAX_STREAMS_UNI
|
||||
else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
|
||||
return {
|
||||
"frame_type": "crypto",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
}
|
||||
|
||||
def encode_data_blocked_frame(self, limit: int) -> Dict:
|
||||
return {"frame_type": "data_blocked", "limit": limit}
|
||||
|
||||
def encode_datagram_frame(self, length: int) -> Dict:
|
||||
return {"frame_type": "datagram", "length": length}
|
||||
|
||||
def encode_handshake_done_frame(self) -> Dict:
|
||||
return {"frame_type": "handshake_done"}
|
||||
|
||||
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "max_stream_data",
|
||||
"maximum": maximum,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_new_connection_id_frame(
|
||||
self,
|
||||
connection_id: bytes,
|
||||
retire_prior_to: int,
|
||||
sequence_number: int,
|
||||
stateless_reset_token: bytes,
|
||||
) -> Dict:
|
||||
return {
|
||||
"connection_id": hexdump(connection_id),
|
||||
"frame_type": "new_connection_id",
|
||||
"length": len(connection_id),
|
||||
"reset_token": hexdump(stateless_reset_token),
|
||||
"retire_prior_to": retire_prior_to,
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_new_token_frame(self, token: bytes) -> Dict:
|
||||
return {
|
||||
"frame_type": "new_token",
|
||||
"length": len(token),
|
||||
"token": hexdump(token),
|
||||
}
|
||||
|
||||
def encode_padding_frame(self) -> Dict:
|
||||
return {"frame_type": "padding"}
|
||||
|
||||
def encode_path_challenge_frame(self, data: bytes) -> Dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_challenge"}
|
||||
|
||||
def encode_path_response_frame(self, data: bytes) -> Dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_response"}
|
||||
|
||||
def encode_ping_frame(self) -> Dict:
|
||||
return {"frame_type": "ping"}
|
||||
|
||||
def encode_reset_stream_frame(
|
||||
self, error_code: int, final_size: int, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"error_code": error_code,
|
||||
"final_size": final_size,
|
||||
"frame_type": "reset_stream",
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "retire_connection_id",
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "stream_data_blocked",
|
||||
"limit": limit,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "stop_sending",
|
||||
"error_code": error_code,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
|
||||
return {
|
||||
"fin": frame.fin,
|
||||
"frame_type": "stream",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "streams_blocked",
|
||||
"limit": limit,
|
||||
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_time(self, seconds: float) -> float:
|
||||
"""
|
||||
Convert a time to milliseconds.
|
||||
"""
|
||||
return seconds * 1000
|
||||
|
||||
def encode_transport_parameters(
|
||||
self, owner: str, parameters: QuicTransportParameters
|
||||
) -> Dict[str, Any]:
|
||||
data: Dict[str, Any] = {"owner": owner}
|
||||
for param_name, param_value in parameters.__dict__.items():
|
||||
if isinstance(param_value, bool):
|
||||
data[param_name] = param_value
|
||||
elif isinstance(param_value, bytes):
|
||||
data[param_name] = hexdump(param_value)
|
||||
elif isinstance(param_value, int):
|
||||
data[param_name] = param_value
|
||||
return data
|
||||
|
||||
def packet_type(self, packet_type: QuicPacketType) -> str:
|
||||
return PACKET_TYPE_NAMES[packet_type]
|
||||
|
||||
# HTTP/3
|
||||
|
||||
def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame": {"frame_type": "data"},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_headers_frame(
|
||||
self, length: int, headers: Headers, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "headers",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_push_promise_frame(
|
||||
self, length: int, headers: Headers, push_id: int, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "push_promise",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
"push_id": push_id,
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def _encode_http3_headers(self, headers: Headers) -> List[Dict]:
|
||||
return [
|
||||
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
|
||||
]
|
||||
|
||||
# CORE
|
||||
|
||||
def log_event(self, *, category: str, event: str, data: Dict) -> None:
|
||||
self._events.append(
|
||||
{
|
||||
"data": data,
|
||||
"name": category + ":" + event,
|
||||
"time": self.encode_time(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the trace as a dictionary which can be written as JSON.
|
||||
"""
|
||||
return {
|
||||
"common_fields": {
|
||||
"ODCID": hexdump(self._odcid),
|
||||
},
|
||||
"events": list(self._events),
|
||||
"vantage_point": self._vantage_point,
|
||||
}
|
||||
|
||||
|
||||
class QuicLogger:
|
||||
"""
|
||||
A QUIC event logger which stores traces in memory.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._traces: List[QuicLoggerTrace] = []
|
||||
|
||||
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
|
||||
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
|
||||
self._traces.append(trace)
|
||||
return trace
|
||||
|
||||
def end_trace(self, trace: QuicLoggerTrace) -> None:
|
||||
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the traces as a dictionary which can be written as JSON.
|
||||
"""
|
||||
return {
|
||||
"qlog_format": "JSON",
|
||||
"qlog_version": QLOG_VERSION,
|
||||
"traces": [trace.to_dict() for trace in self._traces],
|
||||
}
|
||||
|
||||
|
||||
class QuicFileLogger(QuicLogger):
|
||||
"""
|
||||
A QUIC event logger which writes one trace per file.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
if not os.path.isdir(path):
|
||||
raise ValueError("QUIC log output directory '%s' does not exist" % path)
|
||||
self.path = path
|
||||
super().__init__()
|
||||
|
||||
def end_trace(self, trace: QuicLoggerTrace) -> None:
|
||||
trace_dict = trace.to_dict()
|
||||
trace_path = os.path.join(
|
||||
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
|
||||
)
|
||||
with open(trace_path, "w") as logger_fp:
|
||||
json.dump(
|
||||
{
|
||||
"qlog_format": "JSON",
|
||||
"qlog_version": QLOG_VERSION,
|
||||
"traces": [trace_dict],
|
||||
},
|
||||
logger_fp,
|
||||
)
|
||||
self._traces.remove(trace)
|
||||
640
Code/venv/lib/python3.13/site-packages/aioquic/quic/packet.py
Normal file
640
Code/venv/lib/python3.13/site-packages/aioquic/quic/packet.py
Normal file
@ -0,0 +1,640 @@
|
||||
import binascii
|
||||
import ipaddress
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ..buffer import Buffer
|
||||
from .rangeset import RangeSet
|
||||
|
||||
PACKET_LONG_HEADER = 0x80
|
||||
PACKET_FIXED_BIT = 0x40
|
||||
PACKET_SPIN_BIT = 0x20
|
||||
|
||||
CONNECTION_ID_MAX_SIZE = 20
|
||||
PACKET_NUMBER_MAX_SIZE = 4
|
||||
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
|
||||
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
|
||||
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
|
||||
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
|
||||
RETRY_INTEGRITY_TAG_SIZE = 16
|
||||
STATELESS_RESET_TOKEN_SIZE = 16
|
||||
|
||||
|
||||
class QuicErrorCode(IntEnum):
|
||||
NO_ERROR = 0x0
|
||||
INTERNAL_ERROR = 0x1
|
||||
CONNECTION_REFUSED = 0x2
|
||||
FLOW_CONTROL_ERROR = 0x3
|
||||
STREAM_LIMIT_ERROR = 0x4
|
||||
STREAM_STATE_ERROR = 0x5
|
||||
FINAL_SIZE_ERROR = 0x6
|
||||
FRAME_ENCODING_ERROR = 0x7
|
||||
TRANSPORT_PARAMETER_ERROR = 0x8
|
||||
CONNECTION_ID_LIMIT_ERROR = 0x9
|
||||
PROTOCOL_VIOLATION = 0xA
|
||||
INVALID_TOKEN = 0xB
|
||||
APPLICATION_ERROR = 0xC
|
||||
CRYPTO_BUFFER_EXCEEDED = 0xD
|
||||
KEY_UPDATE_ERROR = 0xE
|
||||
AEAD_LIMIT_REACHED = 0xF
|
||||
VERSION_NEGOTIATION_ERROR = 0x11
|
||||
CRYPTO_ERROR = 0x100
|
||||
|
||||
|
||||
class QuicPacketType(Enum):
|
||||
INITIAL = 0
|
||||
ZERO_RTT = 1
|
||||
HANDSHAKE = 2
|
||||
RETRY = 3
|
||||
VERSION_NEGOTIATION = 4
|
||||
ONE_RTT = 5
|
||||
|
||||
|
||||
# For backwards compatibility only, use `QuicPacketType` in new code.
|
||||
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
|
||||
|
||||
# QUIC version 1
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
|
||||
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
|
||||
QuicPacketType.INITIAL: 0,
|
||||
QuicPacketType.ZERO_RTT: 1,
|
||||
QuicPacketType.HANDSHAKE: 2,
|
||||
QuicPacketType.RETRY: 3,
|
||||
}
|
||||
PACKET_LONG_TYPE_DECODE_VERSION_1 = dict(
|
||||
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
|
||||
)
|
||||
|
||||
# QUIC version 2
|
||||
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
|
||||
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
|
||||
QuicPacketType.INITIAL: 1,
|
||||
QuicPacketType.ZERO_RTT: 2,
|
||||
QuicPacketType.HANDSHAKE: 3,
|
||||
QuicPacketType.RETRY: 0,
|
||||
}
|
||||
PACKET_LONG_TYPE_DECODE_VERSION_2 = dict(
|
||||
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
|
||||
)
|
||||
|
||||
|
||||
class QuicProtocolVersion(IntEnum):
|
||||
NEGOTIATION = 0
|
||||
VERSION_1 = 0x00000001
|
||||
VERSION_2 = 0x6B3343CF
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicHeader:
|
||||
version: Optional[int]
|
||||
"The protocol version. Only present in long header packets."
|
||||
|
||||
packet_type: QuicPacketType
|
||||
"The type of the packet."
|
||||
|
||||
packet_length: int
|
||||
"The total length of the packet, in bytes."
|
||||
|
||||
destination_cid: bytes
|
||||
"The destination connection ID."
|
||||
|
||||
source_cid: bytes
|
||||
"The destination connection ID."
|
||||
|
||||
token: bytes
|
||||
"The address verification token. Only present in `INITIAL` and `RETRY` packets."
|
||||
|
||||
integrity_tag: bytes
|
||||
"The retry integrity tag. Only present in `RETRY` packets."
|
||||
|
||||
supported_versions: List[int]
|
||||
"Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
|
||||
|
||||
|
||||
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
|
||||
"""
|
||||
Recover a packet number from a truncated packet number.
|
||||
|
||||
See: Appendix A - Sample Packet Number Decoding Algorithm
|
||||
"""
|
||||
window = 1 << num_bits
|
||||
half_window = window // 2
|
||||
candidate = (expected & ~(window - 1)) | truncated
|
||||
if candidate <= expected - half_window and candidate < (1 << 62) - window:
|
||||
return candidate + window
|
||||
elif candidate > expected + half_window and candidate >= window:
|
||||
return candidate - window
|
||||
else:
|
||||
return candidate
|
||||
|
||||
|
||||
def get_retry_integrity_tag(
|
||||
packet_without_tag: bytes, original_destination_cid: bytes, version: int
|
||||
) -> bytes:
|
||||
"""
|
||||
Calculate the integrity tag for a RETRY packet.
|
||||
"""
|
||||
# build Retry pseudo packet
|
||||
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
|
||||
buf.push_uint8(len(original_destination_cid))
|
||||
buf.push_bytes(original_destination_cid)
|
||||
buf.push_bytes(packet_without_tag)
|
||||
assert buf.eof()
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
aead_key = RETRY_AEAD_KEY_VERSION_2
|
||||
aead_nonce = RETRY_AEAD_NONCE_VERSION_2
|
||||
else:
|
||||
aead_key = RETRY_AEAD_KEY_VERSION_1
|
||||
aead_nonce = RETRY_AEAD_NONCE_VERSION_1
|
||||
|
||||
# run AES-128-GCM
|
||||
aead = AESGCM(aead_key)
|
||||
integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
|
||||
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
|
||||
return integrity_tag
|
||||
|
||||
|
||||
def get_spin_bit(first_byte: int) -> bool:
|
||||
return bool(first_byte & PACKET_SPIN_BIT)
|
||||
|
||||
|
||||
def is_long_header(first_byte: int) -> bool:
|
||||
return bool(first_byte & PACKET_LONG_HEADER)
|
||||
|
||||
|
||||
def pretty_protocol_version(version: int) -> str:
|
||||
"""
|
||||
Return a user-friendly representation of a protocol version.
|
||||
"""
|
||||
try:
|
||||
version_name = QuicProtocolVersion(version).name
|
||||
except ValueError:
|
||||
version_name = "UNKNOWN"
|
||||
return f"0x{version:08x} ({version_name})"
|
||||
|
||||
|
||||
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
|
||||
packet_start = buf.tell()
|
||||
|
||||
version = None
|
||||
integrity_tag = b""
|
||||
supported_versions = []
|
||||
token = b""
|
||||
|
||||
first_byte = buf.pull_uint8()
|
||||
if is_long_header(first_byte):
|
||||
# Long Header Packets.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
|
||||
version = buf.pull_uint32()
|
||||
|
||||
destination_cid_length = buf.pull_uint8()
|
||||
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
|
||||
raise ValueError(
|
||||
"Destination CID is too long (%d bytes)" % destination_cid_length
|
||||
)
|
||||
destination_cid = buf.pull_bytes(destination_cid_length)
|
||||
|
||||
source_cid_length = buf.pull_uint8()
|
||||
if source_cid_length > CONNECTION_ID_MAX_SIZE:
|
||||
raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
|
||||
source_cid = buf.pull_bytes(source_cid_length)
|
||||
|
||||
if version == QuicProtocolVersion.NEGOTIATION:
|
||||
# Version Negotiation Packet.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
|
||||
packet_type = QuicPacketType.VERSION_NEGOTIATION
|
||||
while not buf.eof():
|
||||
supported_versions.append(buf.pull_uint32())
|
||||
packet_end = buf.tell()
|
||||
else:
|
||||
if not (first_byte & PACKET_FIXED_BIT):
|
||||
raise ValueError("Packet fixed bit is zero")
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
|
||||
(first_byte & 0x30) >> 4
|
||||
]
|
||||
else:
|
||||
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
|
||||
(first_byte & 0x30) >> 4
|
||||
]
|
||||
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
token_length = buf.pull_uint_var()
|
||||
token = buf.pull_bytes(token_length)
|
||||
rest_length = buf.pull_uint_var()
|
||||
elif packet_type == QuicPacketType.ZERO_RTT:
|
||||
rest_length = buf.pull_uint_var()
|
||||
elif packet_type == QuicPacketType.HANDSHAKE:
|
||||
rest_length = buf.pull_uint_var()
|
||||
else:
|
||||
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
|
||||
token = buf.pull_bytes(token_length)
|
||||
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
|
||||
rest_length = 0
|
||||
|
||||
# Check remainder length.
|
||||
packet_end = buf.tell() + rest_length
|
||||
if packet_end > buf.capacity:
|
||||
raise ValueError("Packet payload is truncated")
|
||||
|
||||
else:
|
||||
# Short Header Packets.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
|
||||
if not (first_byte & PACKET_FIXED_BIT):
|
||||
raise ValueError("Packet fixed bit is zero")
|
||||
|
||||
version = None
|
||||
packet_type = QuicPacketType.ONE_RTT
|
||||
destination_cid = buf.pull_bytes(host_cid_length)
|
||||
source_cid = b""
|
||||
packet_end = buf.capacity
|
||||
|
||||
return QuicHeader(
|
||||
version=version,
|
||||
packet_type=packet_type,
|
||||
packet_length=packet_end - packet_start,
|
||||
destination_cid=destination_cid,
|
||||
source_cid=source_cid,
|
||||
token=token,
|
||||
integrity_tag=integrity_tag,
|
||||
supported_versions=supported_versions,
|
||||
)
|
||||
|
||||
|
||||
def encode_long_header_first_byte(
|
||||
version: int, packet_type: QuicPacketType, bits: int
|
||||
) -> int:
|
||||
"""
|
||||
Encode the first byte of a long header packet.
|
||||
"""
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
|
||||
else:
|
||||
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
|
||||
return (
|
||||
PACKET_LONG_HEADER
|
||||
| PACKET_FIXED_BIT
|
||||
| long_type_encode[packet_type] << 4
|
||||
| bits
|
||||
)
|
||||
|
||||
|
||||
def encode_quic_retry(
|
||||
version: int,
|
||||
source_cid: bytes,
|
||||
destination_cid: bytes,
|
||||
original_destination_cid: bytes,
|
||||
retry_token: bytes,
|
||||
unused: int = 0,
|
||||
) -> bytes:
|
||||
buf = Buffer(
|
||||
capacity=7
|
||||
+ len(destination_cid)
|
||||
+ len(source_cid)
|
||||
+ len(retry_token)
|
||||
+ RETRY_INTEGRITY_TAG_SIZE
|
||||
)
|
||||
buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
|
||||
buf.push_uint32(version)
|
||||
buf.push_uint8(len(destination_cid))
|
||||
buf.push_bytes(destination_cid)
|
||||
buf.push_uint8(len(source_cid))
|
||||
buf.push_bytes(source_cid)
|
||||
buf.push_bytes(retry_token)
|
||||
buf.push_bytes(
|
||||
get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
|
||||
)
|
||||
assert buf.eof()
|
||||
return buf.data
|
||||
|
||||
|
||||
def encode_quic_version_negotiation(
|
||||
source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
|
||||
) -> bytes:
|
||||
buf = Buffer(
|
||||
capacity=7
|
||||
+ len(destination_cid)
|
||||
+ len(source_cid)
|
||||
+ 4 * len(supported_versions)
|
||||
)
|
||||
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
|
||||
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
|
||||
buf.push_uint8(len(destination_cid))
|
||||
buf.push_bytes(destination_cid)
|
||||
buf.push_uint8(len(source_cid))
|
||||
buf.push_bytes(source_cid)
|
||||
for version in supported_versions:
|
||||
buf.push_uint32(version)
|
||||
return buf.data
|
||||
|
||||
|
||||
# TLS EXTENSION
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicPreferredAddress:
|
||||
ipv4_address: Optional[Tuple[str, int]]
|
||||
ipv6_address: Optional[Tuple[str, int]]
|
||||
connection_id: bytes
|
||||
stateless_reset_token: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicVersionInformation:
|
||||
chosen_version: int
|
||||
available_versions: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicTransportParameters:
|
||||
original_destination_connection_id: Optional[bytes] = None
|
||||
max_idle_timeout: Optional[int] = None
|
||||
stateless_reset_token: Optional[bytes] = None
|
||||
max_udp_payload_size: Optional[int] = None
|
||||
initial_max_data: Optional[int] = None
|
||||
initial_max_stream_data_bidi_local: Optional[int] = None
|
||||
initial_max_stream_data_bidi_remote: Optional[int] = None
|
||||
initial_max_stream_data_uni: Optional[int] = None
|
||||
initial_max_streams_bidi: Optional[int] = None
|
||||
initial_max_streams_uni: Optional[int] = None
|
||||
ack_delay_exponent: Optional[int] = None
|
||||
max_ack_delay: Optional[int] = None
|
||||
disable_active_migration: Optional[bool] = False
|
||||
preferred_address: Optional[QuicPreferredAddress] = None
|
||||
active_connection_id_limit: Optional[int] = None
|
||||
initial_source_connection_id: Optional[bytes] = None
|
||||
retry_source_connection_id: Optional[bytes] = None
|
||||
version_information: Optional[QuicVersionInformation] = None
|
||||
max_datagram_frame_size: Optional[int] = None
|
||||
quantum_readiness: Optional[bytes] = None
|
||||
|
||||
|
||||
PARAMS = {
|
||||
0x00: ("original_destination_connection_id", bytes),
|
||||
0x01: ("max_idle_timeout", int),
|
||||
0x02: ("stateless_reset_token", bytes),
|
||||
0x03: ("max_udp_payload_size", int),
|
||||
0x04: ("initial_max_data", int),
|
||||
0x05: ("initial_max_stream_data_bidi_local", int),
|
||||
0x06: ("initial_max_stream_data_bidi_remote", int),
|
||||
0x07: ("initial_max_stream_data_uni", int),
|
||||
0x08: ("initial_max_streams_bidi", int),
|
||||
0x09: ("initial_max_streams_uni", int),
|
||||
0x0A: ("ack_delay_exponent", int),
|
||||
0x0B: ("max_ack_delay", int),
|
||||
0x0C: ("disable_active_migration", bool),
|
||||
0x0D: ("preferred_address", QuicPreferredAddress),
|
||||
0x0E: ("active_connection_id_limit", int),
|
||||
0x0F: ("initial_source_connection_id", bytes),
|
||||
0x10: ("retry_source_connection_id", bytes),
|
||||
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
|
||||
0x11: ("version_information", QuicVersionInformation),
|
||||
# extensions
|
||||
0x0020: ("max_datagram_frame_size", int),
|
||||
0x0C37: ("quantum_readiness", bytes),
|
||||
}
|
||||
|
||||
|
||||
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
|
||||
ipv4_address = None
|
||||
ipv4_host = buf.pull_bytes(4)
|
||||
ipv4_port = buf.pull_uint16()
|
||||
if ipv4_host != bytes(4):
|
||||
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
|
||||
|
||||
ipv6_address = None
|
||||
ipv6_host = buf.pull_bytes(16)
|
||||
ipv6_port = buf.pull_uint16()
|
||||
if ipv6_host != bytes(16):
|
||||
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
|
||||
|
||||
connection_id_length = buf.pull_uint8()
|
||||
connection_id = buf.pull_bytes(connection_id_length)
|
||||
stateless_reset_token = buf.pull_bytes(16)
|
||||
|
||||
return QuicPreferredAddress(
|
||||
ipv4_address=ipv4_address,
|
||||
ipv6_address=ipv6_address,
|
||||
connection_id=connection_id,
|
||||
stateless_reset_token=stateless_reset_token,
|
||||
)
|
||||
|
||||
|
||||
def push_quic_preferred_address(
|
||||
buf: Buffer, preferred_address: QuicPreferredAddress
|
||||
) -> None:
|
||||
if preferred_address.ipv4_address is not None:
|
||||
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
|
||||
buf.push_uint16(preferred_address.ipv4_address[1])
|
||||
else:
|
||||
buf.push_bytes(bytes(6))
|
||||
|
||||
if preferred_address.ipv6_address is not None:
|
||||
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
|
||||
buf.push_uint16(preferred_address.ipv6_address[1])
|
||||
else:
|
||||
buf.push_bytes(bytes(18))
|
||||
|
||||
buf.push_uint8(len(preferred_address.connection_id))
|
||||
buf.push_bytes(preferred_address.connection_id)
|
||||
buf.push_bytes(preferred_address.stateless_reset_token)
|
||||
|
||||
|
||||
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
|
||||
chosen_version = buf.pull_uint32()
|
||||
available_versions = []
|
||||
for i in range(length // 4 - 1):
|
||||
available_versions.append(buf.pull_uint32())
|
||||
|
||||
# If an endpoint receives a Chosen Version equal to zero, or any Available Version
|
||||
# equal to zero, it MUST treat it as a parsing failure.
|
||||
#
|
||||
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
|
||||
if chosen_version == 0 or 0 in available_versions:
|
||||
raise ValueError("Version Information must not contain version 0")
|
||||
|
||||
return QuicVersionInformation(
|
||||
chosen_version=chosen_version,
|
||||
available_versions=available_versions,
|
||||
)
|
||||
|
||||
|
||||
def push_quic_version_information(
|
||||
buf: Buffer, version_information: QuicVersionInformation
|
||||
) -> None:
|
||||
buf.push_uint32(version_information.chosen_version)
|
||||
for version in version_information.available_versions:
|
||||
buf.push_uint32(version)
|
||||
|
||||
|
||||
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
|
||||
params = QuicTransportParameters()
|
||||
while not buf.eof():
|
||||
param_id = buf.pull_uint_var()
|
||||
param_len = buf.pull_uint_var()
|
||||
param_start = buf.tell()
|
||||
if param_id in PARAMS:
|
||||
# Parse known parameter.
|
||||
param_name, param_type = PARAMS[param_id]
|
||||
if param_type is int:
|
||||
setattr(params, param_name, buf.pull_uint_var())
|
||||
elif param_type is bytes:
|
||||
setattr(params, param_name, buf.pull_bytes(param_len))
|
||||
elif param_type is QuicPreferredAddress:
|
||||
setattr(params, param_name, pull_quic_preferred_address(buf))
|
||||
elif param_type is QuicVersionInformation:
|
||||
setattr(
|
||||
params,
|
||||
param_name,
|
||||
pull_quic_version_information(buf, param_len),
|
||||
)
|
||||
else:
|
||||
setattr(params, param_name, True)
|
||||
else:
|
||||
# Skip unknown parameter.
|
||||
buf.pull_bytes(param_len)
|
||||
|
||||
if buf.tell() != param_start + param_len:
|
||||
raise ValueError("Transport parameter length does not match")
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def push_quic_transport_parameters(
|
||||
buf: Buffer, params: QuicTransportParameters
|
||||
) -> None:
|
||||
for param_id, (param_name, param_type) in PARAMS.items():
|
||||
param_value = getattr(params, param_name)
|
||||
if param_value is not None and param_value is not False:
|
||||
param_buf = Buffer(capacity=65536)
|
||||
if param_type is int:
|
||||
param_buf.push_uint_var(param_value)
|
||||
elif param_type is bytes:
|
||||
param_buf.push_bytes(param_value)
|
||||
elif param_type is QuicPreferredAddress:
|
||||
push_quic_preferred_address(param_buf, param_value)
|
||||
elif param_type is QuicVersionInformation:
|
||||
push_quic_version_information(param_buf, param_value)
|
||||
buf.push_uint_var(param_id)
|
||||
buf.push_uint_var(param_buf.tell())
|
||||
buf.push_bytes(param_buf.data)
|
||||
|
||||
|
||||
# FRAMES
|
||||
|
||||
|
||||
class QuicFrameType(IntEnum):
|
||||
PADDING = 0x00
|
||||
PING = 0x01
|
||||
ACK = 0x02
|
||||
ACK_ECN = 0x03
|
||||
RESET_STREAM = 0x04
|
||||
STOP_SENDING = 0x05
|
||||
CRYPTO = 0x06
|
||||
NEW_TOKEN = 0x07
|
||||
STREAM_BASE = 0x08
|
||||
MAX_DATA = 0x10
|
||||
MAX_STREAM_DATA = 0x11
|
||||
MAX_STREAMS_BIDI = 0x12
|
||||
MAX_STREAMS_UNI = 0x13
|
||||
DATA_BLOCKED = 0x14
|
||||
STREAM_DATA_BLOCKED = 0x15
|
||||
STREAMS_BLOCKED_BIDI = 0x16
|
||||
STREAMS_BLOCKED_UNI = 0x17
|
||||
NEW_CONNECTION_ID = 0x18
|
||||
RETIRE_CONNECTION_ID = 0x19
|
||||
PATH_CHALLENGE = 0x1A
|
||||
PATH_RESPONSE = 0x1B
|
||||
TRANSPORT_CLOSE = 0x1C
|
||||
APPLICATION_CLOSE = 0x1D
|
||||
HANDSHAKE_DONE = 0x1E
|
||||
DATAGRAM = 0x30
|
||||
DATAGRAM_WITH_LENGTH = 0x31
|
||||
|
||||
|
||||
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.ACK,
|
||||
QuicFrameType.ACK_ECN,
|
||||
QuicFrameType.PADDING,
|
||||
QuicFrameType.TRANSPORT_CLOSE,
|
||||
QuicFrameType.APPLICATION_CLOSE,
|
||||
]
|
||||
)
|
||||
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.ACK,
|
||||
QuicFrameType.ACK_ECN,
|
||||
QuicFrameType.TRANSPORT_CLOSE,
|
||||
QuicFrameType.APPLICATION_CLOSE,
|
||||
]
|
||||
)
|
||||
|
||||
PROBING_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.PATH_CHALLENGE,
|
||||
QuicFrameType.PATH_RESPONSE,
|
||||
QuicFrameType.PADDING,
|
||||
QuicFrameType.NEW_CONNECTION_ID,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicResetStreamFrame:
|
||||
error_code: int
|
||||
final_size: int
|
||||
stream_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStopSendingFrame:
|
||||
error_code: int
|
||||
stream_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamFrame:
|
||||
data: bytes = b""
|
||||
fin: bool = False
|
||||
offset: int = 0
|
||||
|
||||
|
||||
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
|
||||
rangeset = RangeSet()
|
||||
end = buf.pull_uint_var() # largest acknowledged
|
||||
delay = buf.pull_uint_var()
|
||||
ack_range_count = buf.pull_uint_var()
|
||||
ack_count = buf.pull_uint_var() # first ack range
|
||||
rangeset.add(end - ack_count, end + 1)
|
||||
end -= ack_count
|
||||
for _ in range(ack_range_count):
|
||||
end -= buf.pull_uint_var() + 2
|
||||
ack_count = buf.pull_uint_var()
|
||||
rangeset.add(end - ack_count, end + 1)
|
||||
end -= ack_count
|
||||
return rangeset, delay
|
||||
|
||||
|
||||
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
|
||||
ranges = len(rangeset)
|
||||
index = ranges - 1
|
||||
r = rangeset[index]
|
||||
buf.push_uint_var(r.stop - 1)
|
||||
buf.push_uint_var(delay)
|
||||
buf.push_uint_var(index)
|
||||
buf.push_uint_var(r.stop - 1 - r.start)
|
||||
start = r.start
|
||||
while index > 0:
|
||||
index -= 1
|
||||
r = rangeset[index]
|
||||
buf.push_uint_var(start - r.stop - 1)
|
||||
buf.push_uint_var(r.stop - r.start - 1)
|
||||
start = r.start
|
||||
return ranges
|
||||
@ -0,0 +1,384 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..buffer import Buffer, size_uint_var
|
||||
from ..tls import Epoch
|
||||
from .crypto import CryptoPair
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet import (
|
||||
NON_ACK_ELICITING_FRAME_TYPES,
|
||||
NON_IN_FLIGHT_FRAME_TYPES,
|
||||
PACKET_FIXED_BIT,
|
||||
PACKET_NUMBER_MAX_SIZE,
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
encode_long_header_first_byte,
|
||||
)
|
||||
|
||||
PACKET_LENGTH_SEND_SIZE = 2
|
||||
PACKET_NUMBER_SEND_SIZE = 2
|
||||
|
||||
|
||||
QuicDeliveryHandler = Callable[..., None]
|
||||
|
||||
|
||||
class QuicDeliveryState(Enum):
|
||||
ACKED = 0
|
||||
LOST = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicSentPacket:
|
||||
epoch: Epoch
|
||||
in_flight: bool
|
||||
is_ack_eliciting: bool
|
||||
is_crypto_packet: bool
|
||||
packet_number: int
|
||||
packet_type: QuicPacketType
|
||||
sent_time: Optional[float] = None
|
||||
sent_bytes: int = 0
|
||||
|
||||
delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
quic_logger_frames: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
class QuicPacketBuilderStop(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QuicPacketBuilder:
|
||||
"""
|
||||
Helper for building QUIC packets.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_cid: bytes,
|
||||
peer_cid: bytes,
|
||||
version: int,
|
||||
is_client: bool,
|
||||
max_datagram_size: int,
|
||||
packet_number: int = 0,
|
||||
peer_token: bytes = b"",
|
||||
quic_logger: Optional[QuicLoggerTrace] = None,
|
||||
spin_bit: bool = False,
|
||||
):
|
||||
self.max_flight_bytes: Optional[int] = None
|
||||
self.max_total_bytes: Optional[int] = None
|
||||
self.quic_logger_frames: Optional[List[Dict]] = None
|
||||
|
||||
self._host_cid = host_cid
|
||||
self._is_client = is_client
|
||||
self._peer_cid = peer_cid
|
||||
self._peer_token = peer_token
|
||||
self._quic_logger = quic_logger
|
||||
self._spin_bit = spin_bit
|
||||
self._version = version
|
||||
|
||||
# assembled datagrams and packets
|
||||
self._datagrams: List[bytes] = []
|
||||
self._datagram_flight_bytes = 0
|
||||
self._datagram_init = True
|
||||
self._datagram_needs_padding = False
|
||||
self._packets: List[QuicSentPacket] = []
|
||||
self._flight_bytes = 0
|
||||
self._total_bytes = 0
|
||||
|
||||
# current packet
|
||||
self._header_size = 0
|
||||
self._packet: Optional[QuicSentPacket] = None
|
||||
self._packet_crypto: Optional[CryptoPair] = None
|
||||
self._packet_number = packet_number
|
||||
self._packet_start = 0
|
||||
self._packet_type: Optional[QuicPacketType] = None
|
||||
|
||||
self._buffer = Buffer(max_datagram_size)
|
||||
self._buffer_capacity = max_datagram_size
|
||||
self._flight_capacity = max_datagram_size
|
||||
|
||||
@property
|
||||
def packet_is_empty(self) -> bool:
|
||||
"""
|
||||
Returns `True` if the current packet is empty.
|
||||
"""
|
||||
assert self._packet is not None
|
||||
packet_size = self._buffer.tell() - self._packet_start
|
||||
return packet_size <= self._header_size
|
||||
|
||||
@property
|
||||
def packet_number(self) -> int:
|
||||
"""
|
||||
Returns the packet number for the next packet.
|
||||
"""
|
||||
return self._packet_number
|
||||
|
||||
@property
|
||||
def remaining_buffer_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._buffer_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_flight_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._flight_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
|
||||
"""
|
||||
Returns the assembled datagrams.
|
||||
"""
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
self._flush_current_datagram()
|
||||
|
||||
datagrams = self._datagrams
|
||||
packets = self._packets
|
||||
self._datagrams = []
|
||||
self._packets = []
|
||||
return datagrams, packets
|
||||
|
||||
def start_frame(
|
||||
self,
|
||||
frame_type: int,
|
||||
capacity: int = 1,
|
||||
handler: Optional[QuicDeliveryHandler] = None,
|
||||
handler_args: Sequence[Any] = [],
|
||||
) -> Buffer:
|
||||
"""
|
||||
Starts a new frame.
|
||||
"""
|
||||
if self.remaining_buffer_space < capacity or (
|
||||
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
|
||||
and self.remaining_flight_space < capacity
|
||||
):
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
self._buffer.push_uint_var(frame_type)
|
||||
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
|
||||
self._packet.is_ack_eliciting = True
|
||||
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
|
||||
self._packet.in_flight = True
|
||||
if frame_type == QuicFrameType.CRYPTO:
|
||||
self._packet.is_crypto_packet = True
|
||||
if handler is not None:
|
||||
self._packet.delivery_handlers.append((handler, handler_args))
|
||||
return self._buffer
|
||||
|
||||
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
|
||||
"""
|
||||
Starts a new packet.
|
||||
"""
|
||||
assert packet_type in (
|
||||
QuicPacketType.INITIAL,
|
||||
QuicPacketType.HANDSHAKE,
|
||||
QuicPacketType.ZERO_RTT,
|
||||
QuicPacketType.ONE_RTT,
|
||||
), "Invalid packet type"
|
||||
buf = self._buffer
|
||||
|
||||
# finish previous datagram
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
|
||||
# if there is too little space remaining, start a new datagram
|
||||
# FIXME: the limit is arbitrary!
|
||||
packet_start = buf.tell()
|
||||
if self._buffer_capacity - packet_start < 128:
|
||||
self._flush_current_datagram()
|
||||
packet_start = 0
|
||||
|
||||
# initialize datagram if needed
|
||||
if self._datagram_init:
|
||||
if self.max_total_bytes is not None:
|
||||
remaining_total_bytes = self.max_total_bytes - self._total_bytes
|
||||
if remaining_total_bytes < self._buffer_capacity:
|
||||
self._buffer_capacity = remaining_total_bytes
|
||||
|
||||
self._flight_capacity = self._buffer_capacity
|
||||
if self.max_flight_bytes is not None:
|
||||
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
|
||||
if remaining_flight_bytes < self._flight_capacity:
|
||||
self._flight_capacity = remaining_flight_bytes
|
||||
self._datagram_flight_bytes = 0
|
||||
self._datagram_init = False
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# calculate header size
|
||||
if packet_type != QuicPacketType.ONE_RTT:
|
||||
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
token_length = len(self._peer_token)
|
||||
header_size += size_uint_var(token_length) + token_length
|
||||
else:
|
||||
header_size = 3 + len(self._peer_cid)
|
||||
|
||||
# check we have enough space
|
||||
if packet_start + header_size >= self._buffer_capacity:
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
# determine ack epoch
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
epoch = Epoch.INITIAL
|
||||
elif packet_type == QuicPacketType.HANDSHAKE:
|
||||
epoch = Epoch.HANDSHAKE
|
||||
else:
|
||||
epoch = Epoch.ONE_RTT
|
||||
|
||||
self._header_size = header_size
|
||||
self._packet = QuicSentPacket(
|
||||
epoch=epoch,
|
||||
in_flight=False,
|
||||
is_ack_eliciting=False,
|
||||
is_crypto_packet=False,
|
||||
packet_number=self._packet_number,
|
||||
packet_type=packet_type,
|
||||
)
|
||||
self._packet_crypto = crypto
|
||||
self._packet_start = packet_start
|
||||
self._packet_type = packet_type
|
||||
self.quic_logger_frames = self._packet.quic_logger_frames
|
||||
|
||||
buf.seek(self._packet_start + self._header_size)
|
||||
|
||||
def _end_packet(self) -> None:
|
||||
"""
|
||||
Ends the current packet.
|
||||
"""
|
||||
buf = self._buffer
|
||||
packet_size = buf.tell() - self._packet_start
|
||||
if packet_size > self._header_size:
|
||||
# padding to ensure sufficient sample size
|
||||
padding_size = (
|
||||
PACKET_NUMBER_MAX_SIZE
|
||||
- PACKET_NUMBER_SEND_SIZE
|
||||
+ self._header_size
|
||||
- packet_size
|
||||
)
|
||||
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if (
|
||||
self._is_client or self._packet.is_ack_eliciting
|
||||
) and self._packet_type == QuicPacketType.INITIAL:
|
||||
self._datagram_needs_padding = True
|
||||
|
||||
# For datagrams containing 1-RTT data, we *must* apply the padding
|
||||
# inside the packet, we cannot tack bytes onto the end of the
|
||||
# datagram.
|
||||
if (
|
||||
self._datagram_needs_padding
|
||||
and self._packet_type == QuicPacketType.ONE_RTT
|
||||
):
|
||||
if self.remaining_flight_space > padding_size:
|
||||
padding_size = self.remaining_flight_space
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# write padding
|
||||
if padding_size > 0:
|
||||
buf.push_bytes(bytes(padding_size))
|
||||
packet_size += padding_size
|
||||
self._packet.in_flight = True
|
||||
|
||||
# log frame
|
||||
if self._quic_logger is not None:
|
||||
self._packet.quic_logger_frames.append(
|
||||
self._quic_logger.encode_padding_frame()
|
||||
)
|
||||
|
||||
# write header
|
||||
if self._packet_type != QuicPacketType.ONE_RTT:
|
||||
length = (
|
||||
packet_size
|
||||
- self._header_size
|
||||
+ PACKET_NUMBER_SEND_SIZE
|
||||
+ self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
encode_long_header_first_byte(
|
||||
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
|
||||
)
|
||||
)
|
||||
buf.push_uint32(self._version)
|
||||
buf.push_uint8(len(self._peer_cid))
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint8(len(self._host_cid))
|
||||
buf.push_bytes(self._host_cid)
|
||||
if self._packet_type == QuicPacketType.INITIAL:
|
||||
buf.push_uint_var(len(self._peer_token))
|
||||
buf.push_bytes(self._peer_token)
|
||||
buf.push_uint16(length | 0x4000)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
else:
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
PACKET_FIXED_BIT
|
||||
| (self._spin_bit << 5)
|
||||
| (self._packet_crypto.key_phase << 2)
|
||||
| (PACKET_NUMBER_SEND_SIZE - 1)
|
||||
)
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
|
||||
# encrypt in place
|
||||
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_bytes(
|
||||
self._packet_crypto.encrypt_packet(
|
||||
plain[0 : self._header_size],
|
||||
plain[self._header_size : packet_size],
|
||||
self._packet_number,
|
||||
)
|
||||
)
|
||||
self._packet.sent_bytes = buf.tell() - self._packet_start
|
||||
self._packets.append(self._packet)
|
||||
if self._packet.in_flight:
|
||||
self._datagram_flight_bytes += self._packet.sent_bytes
|
||||
|
||||
# Short header packets cannot be coalesced, we need a new datagram.
|
||||
if self._packet_type == QuicPacketType.ONE_RTT:
|
||||
self._flush_current_datagram()
|
||||
|
||||
self._packet_number += 1
|
||||
else:
|
||||
# "cancel" the packet
|
||||
buf.seek(self._packet_start)
|
||||
|
||||
self._packet = None
|
||||
self.quic_logger_frames = None
|
||||
|
||||
def _flush_current_datagram(self) -> None:
|
||||
datagram_bytes = self._buffer.tell()
|
||||
if datagram_bytes:
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if self._datagram_needs_padding:
|
||||
extra_bytes = self._flight_capacity - self._buffer.tell()
|
||||
if extra_bytes > 0:
|
||||
self._buffer.push_bytes(bytes(extra_bytes))
|
||||
self._datagram_flight_bytes += extra_bytes
|
||||
datagram_bytes += extra_bytes
|
||||
|
||||
self._datagrams.append(self._buffer.data)
|
||||
self._flight_bytes += self._datagram_flight_bytes
|
||||
self._total_bytes += datagram_bytes
|
||||
self._datagram_init = True
|
||||
self._buffer.seek(0)
|
||||
@ -0,0 +1,98 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
class RangeSet(Sequence):
|
||||
def __init__(self, ranges: Iterable[range] = []):
|
||||
self.__ranges: List[range] = []
|
||||
for r in ranges:
|
||||
assert r.step == 1
|
||||
self.add(r.start, r.stop)
|
||||
|
||||
def add(self, start: int, stop: Optional[int] = None) -> None:
|
||||
if stop is None:
|
||||
stop = start + 1
|
||||
assert stop > start
|
||||
|
||||
for i, r in enumerate(self.__ranges):
|
||||
# the added range is entirely before current item, insert here
|
||||
if stop < r.start:
|
||||
self.__ranges.insert(i, range(start, stop))
|
||||
return
|
||||
|
||||
# the added range is entirely after current item, keep looking
|
||||
if start > r.stop:
|
||||
continue
|
||||
|
||||
# the added range touches the current item, merge it
|
||||
start = min(start, r.start)
|
||||
stop = max(stop, r.stop)
|
||||
while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
|
||||
stop = max(self.__ranges[i + 1].stop, stop)
|
||||
self.__ranges.pop(i + 1)
|
||||
self.__ranges[i] = range(start, stop)
|
||||
return
|
||||
|
||||
# the added range is entirely after all existing items, append it
|
||||
self.__ranges.append(range(start, stop))
|
||||
|
||||
def bounds(self) -> range:
|
||||
return range(self.__ranges[0].start, self.__ranges[-1].stop)
|
||||
|
||||
def shift(self) -> range:
|
||||
return self.__ranges.pop(0)
|
||||
|
||||
def subtract(self, start: int, stop: int) -> None:
|
||||
assert stop > start
|
||||
|
||||
i = 0
|
||||
while i < len(self.__ranges):
|
||||
r = self.__ranges[i]
|
||||
|
||||
# the removed range is entirely before current item, stop here
|
||||
if stop <= r.start:
|
||||
return
|
||||
|
||||
# the removed range is entirely after current item, keep looking
|
||||
if start >= r.stop:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# the removed range completely covers the current item, remove it
|
||||
if start <= r.start and stop >= r.stop:
|
||||
self.__ranges.pop(i)
|
||||
continue
|
||||
|
||||
# the removed range touches the current item
|
||||
if start > r.start:
|
||||
self.__ranges[i] = range(r.start, start)
|
||||
if stop < r.stop:
|
||||
self.__ranges.insert(i + 1, range(stop, r.stop))
|
||||
else:
|
||||
self.__ranges[i] = range(stop, r.stop)
|
||||
|
||||
i += 1
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, val: Any) -> bool:
|
||||
for r in self.__ranges:
|
||||
if val in r:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, RangeSet):
|
||||
return NotImplemented
|
||||
|
||||
return self.__ranges == other.__ranges
|
||||
|
||||
def __getitem__(self, key: Any) -> range:
|
||||
return self.__ranges[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__ranges)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RangeSet({})".format(repr(self.__ranges))
|
||||
389
Code/venv/lib/python3.13/site-packages/aioquic/quic/recovery.py
Normal file
389
Code/venv/lib/python3.13/site-packages/aioquic/quic/recovery.py
Normal file
@ -0,0 +1,389 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from .congestion import cubic, reno # noqa
|
||||
from .congestion.base import K_GRANULARITY, create_congestion_control
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet_builder import QuicDeliveryState, QuicSentPacket
|
||||
from .rangeset import RangeSet
|
||||
|
||||
# loss detection
|
||||
K_PACKET_THRESHOLD = 3
|
||||
K_TIME_THRESHOLD = 9 / 8
|
||||
K_MICRO_SECOND = 0.000001
|
||||
K_SECOND = 1.0
|
||||
|
||||
|
||||
class QuicPacketSpace:
|
||||
def __init__(self) -> None:
|
||||
self.ack_at: Optional[float] = None
|
||||
self.ack_queue = RangeSet()
|
||||
self.discarded = False
|
||||
self.expected_packet_number = 0
|
||||
self.largest_received_packet = -1
|
||||
self.largest_received_time: Optional[float] = None
|
||||
|
||||
# sent packets and loss
|
||||
self.ack_eliciting_in_flight = 0
|
||||
self.largest_acked_packet = 0
|
||||
self.loss_time: Optional[float] = None
|
||||
self.sent_packets: Dict[int, QuicSentPacket] = {}
|
||||
|
||||
|
||||
class QuicPacketPacer:
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
self._max_datagram_size = max_datagram_size
|
||||
self.bucket_max: float = 0.0
|
||||
self.bucket_time: float = 0.0
|
||||
self.evaluation_time: float = 0.0
|
||||
self.packet_time: Optional[float] = None
|
||||
|
||||
def next_send_time(self, now: float) -> float:
|
||||
if self.packet_time is not None:
|
||||
self.update_bucket(now=now)
|
||||
if self.bucket_time <= 0:
|
||||
return now + self.packet_time
|
||||
return None
|
||||
|
||||
def update_after_send(self, now: float) -> None:
|
||||
if self.packet_time is not None:
|
||||
self.update_bucket(now=now)
|
||||
if self.bucket_time < self.packet_time:
|
||||
self.bucket_time = 0.0
|
||||
else:
|
||||
self.bucket_time -= self.packet_time
|
||||
|
||||
def update_bucket(self, now: float) -> None:
|
||||
if now > self.evaluation_time:
|
||||
self.bucket_time = min(
|
||||
self.bucket_time + (now - self.evaluation_time), self.bucket_max
|
||||
)
|
||||
self.evaluation_time = now
|
||||
|
||||
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
|
||||
pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
|
||||
self.packet_time = max(
|
||||
K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND)
|
||||
)
|
||||
|
||||
self.bucket_max = (
|
||||
max(
|
||||
2 * self._max_datagram_size,
|
||||
min(congestion_window // 4, 16 * self._max_datagram_size),
|
||||
)
|
||||
/ pacing_rate
|
||||
)
|
||||
if self.bucket_time > self.bucket_max:
|
||||
self.bucket_time = self.bucket_max
|
||||
|
||||
|
||||
class QuicPacketRecovery:
|
||||
"""
|
||||
Packet loss and congestion controller.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
congestion_control_algorithm: str,
|
||||
initial_rtt: float,
|
||||
max_datagram_size: int,
|
||||
peer_completed_address_validation: bool,
|
||||
send_probe: Callable[[], None],
|
||||
logger: Optional[logging.LoggerAdapter] = None,
|
||||
quic_logger: Optional[QuicLoggerTrace] = None,
|
||||
) -> None:
|
||||
self.max_ack_delay = 0.025
|
||||
self.peer_completed_address_validation = peer_completed_address_validation
|
||||
self.spaces: List[QuicPacketSpace] = []
|
||||
|
||||
# callbacks
|
||||
self._logger = logger
|
||||
self._quic_logger = quic_logger
|
||||
self._send_probe = send_probe
|
||||
|
||||
# loss detection
|
||||
self._pto_count = 0
|
||||
self._rtt_initial = initial_rtt
|
||||
self._rtt_initialized = False
|
||||
self._rtt_latest = 0.0
|
||||
self._rtt_min = math.inf
|
||||
self._rtt_smoothed = 0.0
|
||||
self._rtt_variance = 0.0
|
||||
self._time_of_last_sent_ack_eliciting_packet = 0.0
|
||||
|
||||
# congestion control
|
||||
self._cc = create_congestion_control(
|
||||
congestion_control_algorithm, max_datagram_size=max_datagram_size
|
||||
)
|
||||
self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size)
|
||||
|
||||
@property
|
||||
def bytes_in_flight(self) -> int:
|
||||
return self._cc.bytes_in_flight
|
||||
|
||||
@property
|
||||
def congestion_window(self) -> int:
|
||||
return self._cc.congestion_window
|
||||
|
||||
def discard_space(self, space: QuicPacketSpace) -> None:
|
||||
assert space in self.spaces
|
||||
|
||||
self._cc.on_packets_expired(
|
||||
packets=filter(lambda x: x.in_flight, space.sent_packets.values())
|
||||
)
|
||||
space.sent_packets.clear()
|
||||
|
||||
space.ack_at = None
|
||||
space.ack_eliciting_in_flight = 0
|
||||
space.loss_time = None
|
||||
|
||||
# reset PTO count
|
||||
self._pto_count = 0
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
|
||||
def get_loss_detection_time(self) -> float:
|
||||
# loss timer
|
||||
loss_space = self._get_loss_space()
|
||||
if loss_space is not None:
|
||||
return loss_space.loss_time
|
||||
|
||||
# packet timer
|
||||
if (
|
||||
not self.peer_completed_address_validation
|
||||
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
|
||||
):
|
||||
timeout = self.get_probe_timeout() * (2**self._pto_count)
|
||||
return self._time_of_last_sent_ack_eliciting_packet + timeout
|
||||
|
||||
return None
|
||||
|
||||
def get_probe_timeout(self) -> float:
|
||||
if not self._rtt_initialized:
|
||||
return 2 * self._rtt_initial
|
||||
return (
|
||||
self._rtt_smoothed
|
||||
+ max(4 * self._rtt_variance, K_GRANULARITY)
|
||||
+ self.max_ack_delay
|
||||
)
|
||||
|
||||
def on_ack_received(
|
||||
self,
|
||||
*,
|
||||
ack_rangeset: RangeSet,
|
||||
ack_delay: float,
|
||||
now: float,
|
||||
space: QuicPacketSpace,
|
||||
) -> None:
|
||||
"""
|
||||
Update metrics as the result of an ACK being received.
|
||||
"""
|
||||
is_ack_eliciting = False
|
||||
largest_acked = ack_rangeset.bounds().stop - 1
|
||||
largest_newly_acked = None
|
||||
largest_sent_time = None
|
||||
|
||||
if largest_acked > space.largest_acked_packet:
|
||||
space.largest_acked_packet = largest_acked
|
||||
|
||||
for packet_number in sorted(space.sent_packets.keys()):
|
||||
if packet_number > largest_acked:
|
||||
break
|
||||
if packet_number in ack_rangeset:
|
||||
# remove packet and update counters
|
||||
packet = space.sent_packets.pop(packet_number)
|
||||
if packet.is_ack_eliciting:
|
||||
is_ack_eliciting = True
|
||||
space.ack_eliciting_in_flight -= 1
|
||||
if packet.in_flight:
|
||||
self._cc.on_packet_acked(packet=packet, now=now)
|
||||
largest_newly_acked = packet_number
|
||||
largest_sent_time = packet.sent_time
|
||||
|
||||
# trigger callbacks
|
||||
for handler, args in packet.delivery_handlers:
|
||||
handler(QuicDeliveryState.ACKED, *args)
|
||||
|
||||
# nothing to do if there are no newly acked packets
|
||||
if largest_newly_acked is None:
|
||||
return
|
||||
|
||||
if largest_acked == largest_newly_acked and is_ack_eliciting:
|
||||
latest_rtt = now - largest_sent_time
|
||||
log_rtt = True
|
||||
|
||||
# limit ACK delay to max_ack_delay
|
||||
ack_delay = min(ack_delay, self.max_ack_delay)
|
||||
|
||||
# update RTT estimate, which cannot be < 1 ms
|
||||
self._rtt_latest = max(latest_rtt, 0.001)
|
||||
if self._rtt_latest < self._rtt_min:
|
||||
self._rtt_min = self._rtt_latest
|
||||
if self._rtt_latest > self._rtt_min + ack_delay:
|
||||
self._rtt_latest -= ack_delay
|
||||
|
||||
if not self._rtt_initialized:
|
||||
self._rtt_initialized = True
|
||||
self._rtt_variance = latest_rtt / 2
|
||||
self._rtt_smoothed = latest_rtt
|
||||
else:
|
||||
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
|
||||
self._rtt_min - self._rtt_latest
|
||||
)
|
||||
self._rtt_smoothed = (
|
||||
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
|
||||
)
|
||||
|
||||
# inform congestion controller
|
||||
self._cc.on_rtt_measurement(now=now, rtt=latest_rtt)
|
||||
self._pacer.update_rate(
|
||||
congestion_window=self._cc.congestion_window,
|
||||
smoothed_rtt=self._rtt_smoothed,
|
||||
)
|
||||
|
||||
else:
|
||||
log_rtt = False
|
||||
|
||||
self._detect_loss(now=now, space=space)
|
||||
|
||||
# reset PTO count
|
||||
self._pto_count = 0
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated(log_rtt=log_rtt)
|
||||
|
||||
def on_loss_detection_timeout(self, *, now: float) -> None:
|
||||
loss_space = self._get_loss_space()
|
||||
if loss_space is not None:
|
||||
self._detect_loss(now=now, space=loss_space)
|
||||
else:
|
||||
self._pto_count += 1
|
||||
self.reschedule_data(now=now)
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
|
||||
space.sent_packets[packet.packet_number] = packet
|
||||
|
||||
if packet.is_ack_eliciting:
|
||||
space.ack_eliciting_in_flight += 1
|
||||
if packet.in_flight:
|
||||
if packet.is_ack_eliciting:
|
||||
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
|
||||
|
||||
# add packet to bytes in flight
|
||||
self._cc.on_packet_sent(packet=packet)
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
|
||||
def reschedule_data(self, *, now: float) -> None:
|
||||
"""
|
||||
Schedule some data for retransmission.
|
||||
"""
|
||||
# if there is any outstanding CRYPTO, retransmit it
|
||||
crypto_scheduled = False
|
||||
for space in self.spaces:
|
||||
packets = tuple(
|
||||
filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
|
||||
)
|
||||
if packets:
|
||||
self._on_packets_lost(now=now, packets=packets, space=space)
|
||||
crypto_scheduled = True
|
||||
if crypto_scheduled and self._logger is not None:
|
||||
self._logger.debug("Scheduled CRYPTO data for retransmission")
|
||||
|
||||
# ensure an ACK-elliciting packet is sent
|
||||
self._send_probe()
|
||||
|
||||
def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None:
|
||||
"""
|
||||
Check whether any packets should be declared lost.
|
||||
"""
|
||||
loss_delay = K_TIME_THRESHOLD * (
|
||||
max(self._rtt_latest, self._rtt_smoothed)
|
||||
if self._rtt_initialized
|
||||
else self._rtt_initial
|
||||
)
|
||||
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
|
||||
time_threshold = now - loss_delay
|
||||
|
||||
lost_packets = []
|
||||
space.loss_time = None
|
||||
for packet_number, packet in space.sent_packets.items():
|
||||
if packet_number > space.largest_acked_packet:
|
||||
break
|
||||
|
||||
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
|
||||
lost_packets.append(packet)
|
||||
else:
|
||||
packet_loss_time = packet.sent_time + loss_delay
|
||||
if space.loss_time is None or space.loss_time > packet_loss_time:
|
||||
space.loss_time = packet_loss_time
|
||||
|
||||
self._on_packets_lost(now=now, packets=lost_packets, space=space)
|
||||
|
||||
def _get_loss_space(self) -> Optional[QuicPacketSpace]:
|
||||
loss_space = None
|
||||
for space in self.spaces:
|
||||
if space.loss_time is not None and (
|
||||
loss_space is None or space.loss_time < loss_space.loss_time
|
||||
):
|
||||
loss_space = space
|
||||
return loss_space
|
||||
|
||||
def _log_metrics_updated(self, log_rtt=False) -> None:
|
||||
data: Dict[str, Any] = self._cc.get_log_data()
|
||||
|
||||
if log_rtt:
|
||||
data.update(
|
||||
{
|
||||
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
|
||||
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
|
||||
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
|
||||
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
|
||||
}
|
||||
)
|
||||
|
||||
self._quic_logger.log_event(
|
||||
category="recovery", event="metrics_updated", data=data
|
||||
)
|
||||
|
||||
def _on_packets_lost(
|
||||
self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace
|
||||
) -> None:
|
||||
lost_packets_cc = []
|
||||
for packet in packets:
|
||||
del space.sent_packets[packet.packet_number]
|
||||
|
||||
if packet.in_flight:
|
||||
lost_packets_cc.append(packet)
|
||||
|
||||
if packet.is_ack_eliciting:
|
||||
space.ack_eliciting_in_flight -= 1
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._quic_logger.log_event(
|
||||
category="recovery",
|
||||
event="packet_lost",
|
||||
data={
|
||||
"type": self._quic_logger.packet_type(packet.packet_type),
|
||||
"packet_number": packet.packet_number,
|
||||
},
|
||||
)
|
||||
self._log_metrics_updated()
|
||||
|
||||
# trigger callbacks
|
||||
for handler, args in packet.delivery_handlers:
|
||||
handler(QuicDeliveryState.LOST, *args)
|
||||
|
||||
# inform congestion controller
|
||||
if lost_packets_cc:
|
||||
self._cc.on_packets_lost(now=now, packets=lost_packets_cc)
|
||||
self._pacer.update_rate(
|
||||
congestion_window=self._cc.congestion_window,
|
||||
smoothed_rtt=self._rtt_smoothed,
|
||||
)
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
53
Code/venv/lib/python3.13/site-packages/aioquic/quic/retry.py
Normal file
53
Code/venv/lib/python3.13/site-packages/aioquic/quic/retry.py
Normal file
@ -0,0 +1,53 @@
|
||||
import ipaddress
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from ..buffer import Buffer
|
||||
from ..tls import pull_opaque, push_opaque
|
||||
from .connection import NetworkAddress
|
||||
|
||||
|
||||
def encode_address(addr: NetworkAddress) -> bytes:
|
||||
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
|
||||
|
||||
|
||||
class QuicRetryTokenHandler:
|
||||
def __init__(self) -> None:
|
||||
self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
def create_token(
|
||||
self,
|
||||
addr: NetworkAddress,
|
||||
original_destination_connection_id: bytes,
|
||||
retry_source_connection_id: bytes,
|
||||
) -> bytes:
|
||||
buf = Buffer(capacity=512)
|
||||
push_opaque(buf, 1, encode_address(addr))
|
||||
push_opaque(buf, 1, original_destination_connection_id)
|
||||
push_opaque(buf, 1, retry_source_connection_id)
|
||||
return self._key.public_key().encrypt(
|
||||
buf.data,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
|
||||
),
|
||||
)
|
||||
|
||||
def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]:
|
||||
buf = Buffer(
|
||||
data=self._key.decrypt(
|
||||
token,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
encoded_addr = pull_opaque(buf, 1)
|
||||
original_destination_connection_id = pull_opaque(buf, 1)
|
||||
retry_source_connection_id = pull_opaque(buf, 1)
|
||||
if encoded_addr != encode_address(addr):
|
||||
raise ValueError("Remote address does not match.")
|
||||
return original_destination_connection_id, retry_source_connection_id
|
||||
364
Code/venv/lib/python3.13/site-packages/aioquic/quic/stream.py
Normal file
364
Code/venv/lib/python3.13/site-packages/aioquic/quic/stream.py
Normal file
@ -0,0 +1,364 @@
|
||||
from typing import Optional
|
||||
|
||||
from . import events
|
||||
from .packet import (
|
||||
QuicErrorCode,
|
||||
QuicResetStreamFrame,
|
||||
QuicStopSendingFrame,
|
||||
QuicStreamFrame,
|
||||
)
|
||||
from .packet_builder import QuicDeliveryState
|
||||
from .rangeset import RangeSet
|
||||
|
||||
|
||||
class FinalSizeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StreamFinishedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QuicStreamReceiver:
|
||||
"""
|
||||
The receive part of a QUIC stream.
|
||||
|
||||
It finishes:
|
||||
- immediately for a send-only stream
|
||||
- upon reception of a STREAM_RESET frame
|
||||
- upon reception of a data frame with the FIN bit set
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: Optional[int], readable: bool) -> None:
|
||||
self.highest_offset = 0 # the highest offset ever seen
|
||||
self.is_finished = False
|
||||
self.stop_pending = False
|
||||
|
||||
self._buffer = bytearray()
|
||||
self._buffer_start = 0 # the offset for the start of the buffer
|
||||
self._final_size: Optional[int] = None
|
||||
self._ranges = RangeSet()
|
||||
self._stream_id = stream_id
|
||||
self._stop_error_code: Optional[int] = None
|
||||
|
||||
def get_stop_frame(self) -> QuicStopSendingFrame:
|
||||
self.stop_pending = False
|
||||
return QuicStopSendingFrame(
|
||||
error_code=self._stop_error_code,
|
||||
stream_id=self._stream_id,
|
||||
)
|
||||
|
||||
def starting_offset(self) -> int:
|
||||
return self._buffer_start
|
||||
|
||||
def handle_frame(
|
||||
self, frame: QuicStreamFrame
|
||||
) -> Optional[events.StreamDataReceived]:
|
||||
"""
|
||||
Handle a frame of received data.
|
||||
"""
|
||||
pos = frame.offset - self._buffer_start
|
||||
count = len(frame.data)
|
||||
frame_end = frame.offset + count
|
||||
|
||||
# we should receive no more data beyond FIN!
|
||||
if self._final_size is not None:
|
||||
if frame_end > self._final_size:
|
||||
raise FinalSizeError("Data received beyond final size")
|
||||
elif frame.fin and frame_end != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
if frame.fin:
|
||||
self._final_size = frame_end
|
||||
if frame_end > self.highest_offset:
|
||||
self.highest_offset = frame_end
|
||||
|
||||
# fast path: new in-order chunk
|
||||
if pos == 0 and count and not self._buffer:
|
||||
self._buffer_start += count
|
||||
if frame.fin:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
return events.StreamDataReceived(
|
||||
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
|
||||
)
|
||||
|
||||
# discard duplicate data
|
||||
if pos < 0:
|
||||
frame.data = frame.data[-pos:]
|
||||
frame.offset -= pos
|
||||
pos = 0
|
||||
count = len(frame.data)
|
||||
|
||||
# marked received range
|
||||
if frame_end > frame.offset:
|
||||
self._ranges.add(frame.offset, frame_end)
|
||||
|
||||
# add new data
|
||||
gap = pos - len(self._buffer)
|
||||
if gap > 0:
|
||||
self._buffer += bytearray(gap)
|
||||
self._buffer[pos : pos + count] = frame.data
|
||||
|
||||
# return data from the front of the buffer
|
||||
data = self._pull_data()
|
||||
end_stream = self._buffer_start == self._final_size
|
||||
if end_stream:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
if data or end_stream:
|
||||
return events.StreamDataReceived(
|
||||
data=data, end_stream=end_stream, stream_id=self._stream_id
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def handle_reset(
|
||||
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
|
||||
) -> Optional[events.StreamReset]:
|
||||
"""
|
||||
Handle an abrupt termination of the receiving part of the QUIC stream.
|
||||
"""
|
||||
if self._final_size is not None and final_size != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
|
||||
# we are done receiving
|
||||
self._final_size = final_size
|
||||
self.is_finished = True
|
||||
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
|
||||
|
||||
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
|
||||
"""
|
||||
Callback when a STOP_SENDING is ACK'd.
|
||||
"""
|
||||
if delivery != QuicDeliveryState.ACKED:
|
||||
self.stop_pending = True
|
||||
|
||||
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
|
||||
"""
|
||||
Request the peer stop sending data on the QUIC stream.
|
||||
"""
|
||||
self._stop_error_code = error_code
|
||||
self.stop_pending = True
|
||||
|
||||
def _pull_data(self) -> bytes:
|
||||
"""
|
||||
Remove data from the front of the buffer.
|
||||
"""
|
||||
try:
|
||||
has_data_to_read = self._ranges[0].start == self._buffer_start
|
||||
except IndexError:
|
||||
has_data_to_read = False
|
||||
if not has_data_to_read:
|
||||
return b""
|
||||
|
||||
r = self._ranges.shift()
|
||||
pos = r.stop - r.start
|
||||
data = bytes(self._buffer[:pos])
|
||||
del self._buffer[:pos]
|
||||
self._buffer_start = r.stop
|
||||
return data
|
||||
|
||||
|
||||
class QuicStreamSender:
|
||||
"""
|
||||
The send part of a QUIC stream.
|
||||
|
||||
It finishes:
|
||||
- immediately for a receive-only stream
|
||||
- upon acknowledgement of a STREAM_RESET frame
|
||||
- upon acknowledgement of a data frame with the FIN bit set
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: Optional[int], writable: bool) -> None:
|
||||
self.buffer_is_empty = True
|
||||
self.highest_offset = 0
|
||||
self.is_finished = not writable
|
||||
self.reset_pending = False
|
||||
|
||||
self._acked = RangeSet()
|
||||
self._acked_fin = False
|
||||
self._buffer = bytearray()
|
||||
self._buffer_fin: Optional[int] = None
|
||||
self._buffer_start = 0 # the offset for the start of the buffer
|
||||
self._buffer_stop = 0 # the offset for the stop of the buffer
|
||||
self._pending = RangeSet()
|
||||
self._pending_eof = False
|
||||
self._reset_error_code: Optional[int] = None
|
||||
self._stream_id = stream_id
|
||||
|
||||
@property
|
||||
def next_offset(self) -> int:
|
||||
"""
|
||||
The offset for the next frame to send.
|
||||
|
||||
This is used to determine the space needed for the frame's `offset` field.
|
||||
"""
|
||||
try:
|
||||
return self._pending[0].start
|
||||
except IndexError:
|
||||
return self._buffer_stop
|
||||
|
||||
def get_frame(
|
||||
self, max_size: int, max_offset: Optional[int] = None
|
||||
) -> Optional[QuicStreamFrame]:
|
||||
"""
|
||||
Get a frame of data to send.
|
||||
"""
|
||||
assert self._reset_error_code is None, "cannot call get_frame() after reset()"
|
||||
|
||||
# get the first pending data range
|
||||
try:
|
||||
r = self._pending[0]
|
||||
except IndexError:
|
||||
if self._pending_eof:
|
||||
# FIN only
|
||||
self._pending_eof = False
|
||||
return QuicStreamFrame(fin=True, offset=self._buffer_fin)
|
||||
|
||||
self.buffer_is_empty = True
|
||||
return None
|
||||
|
||||
# apply flow control
|
||||
start = r.start
|
||||
stop = min(r.stop, start + max_size)
|
||||
if max_offset is not None and stop > max_offset:
|
||||
stop = max_offset
|
||||
if stop <= start:
|
||||
return None
|
||||
|
||||
# create frame
|
||||
frame = QuicStreamFrame(
|
||||
data=bytes(
|
||||
self._buffer[start - self._buffer_start : stop - self._buffer_start]
|
||||
),
|
||||
offset=start,
|
||||
)
|
||||
self._pending.subtract(start, stop)
|
||||
|
||||
# track the highest offset ever sent
|
||||
if stop > self.highest_offset:
|
||||
self.highest_offset = stop
|
||||
|
||||
# if the buffer is empty and EOF was written, set the FIN bit
|
||||
if self._buffer_fin == stop:
|
||||
frame.fin = True
|
||||
self._pending_eof = False
|
||||
|
||||
return frame
|
||||
|
||||
def get_reset_frame(self) -> QuicResetStreamFrame:
|
||||
self.reset_pending = False
|
||||
return QuicResetStreamFrame(
|
||||
error_code=self._reset_error_code,
|
||||
final_size=self.highest_offset,
|
||||
stream_id=self._stream_id,
|
||||
)
|
||||
|
||||
def on_data_delivery(
|
||||
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
|
||||
) -> None:
|
||||
"""
|
||||
Callback when sent data is ACK'd.
|
||||
"""
|
||||
# If the frame had the FIN bit set, its end MUST match otherwise
|
||||
# we have a programming error.
|
||||
assert (
|
||||
not fin or stop == self._buffer_fin
|
||||
), "on_data_delivered() was called with inconsistent fin / stop"
|
||||
|
||||
# If a reset has been requested, stop processing data delivery.
|
||||
# The transition to the finished state only depends on the reset
|
||||
# being acknowledged.
|
||||
if self._reset_error_code is not None:
|
||||
return
|
||||
|
||||
if delivery == QuicDeliveryState.ACKED:
|
||||
if stop > start:
|
||||
# Some data has been ACK'd, discard it.
|
||||
self._acked.add(start, stop)
|
||||
first_range = self._acked[0]
|
||||
if first_range.start == self._buffer_start:
|
||||
size = first_range.stop - first_range.start
|
||||
self._acked.shift()
|
||||
self._buffer_start += size
|
||||
del self._buffer[:size]
|
||||
|
||||
if fin:
|
||||
# The FIN has been ACK'd.
|
||||
self._acked_fin = True
|
||||
|
||||
if self._buffer_start == self._buffer_fin and self._acked_fin:
|
||||
# All data and the FIN have been ACK'd, we're done sending.
|
||||
self.is_finished = True
|
||||
else:
|
||||
if stop > start:
|
||||
# Some data has been lost, reschedule it.
|
||||
self.buffer_is_empty = False
|
||||
self._pending.add(start, stop)
|
||||
|
||||
if fin:
|
||||
# The FIN has been lost, reschedule it.
|
||||
self.buffer_is_empty = False
|
||||
self._pending_eof = True
|
||||
|
||||
def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
|
||||
"""
|
||||
Callback when a reset is ACK'd.
|
||||
"""
|
||||
if delivery == QuicDeliveryState.ACKED:
|
||||
# The reset has been ACK'd, we're done sending.
|
||||
self.is_finished = True
|
||||
else:
|
||||
# The reset has been lost, reschedule it.
|
||||
self.reset_pending = True
|
||||
|
||||
def reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Abruptly terminate the sending part of the QUIC stream.
|
||||
"""
|
||||
assert self._reset_error_code is None, "cannot call reset() more than once"
|
||||
self._reset_error_code = error_code
|
||||
self.reset_pending = True
|
||||
|
||||
# Prevent any more data from being sent or re-sent.
|
||||
self.buffer_is_empty = True
|
||||
|
||||
def write(self, data: bytes, end_stream: bool = False) -> None:
|
||||
"""
|
||||
Write some data bytes to the QUIC stream.
|
||||
"""
|
||||
assert self._buffer_fin is None, "cannot call write() after FIN"
|
||||
assert self._reset_error_code is None, "cannot call write() after reset()"
|
||||
size = len(data)
|
||||
|
||||
if size:
|
||||
self.buffer_is_empty = False
|
||||
self._pending.add(self._buffer_stop, self._buffer_stop + size)
|
||||
self._buffer += data
|
||||
self._buffer_stop += size
|
||||
if end_stream:
|
||||
self.buffer_is_empty = False
|
||||
self._buffer_fin = self._buffer_stop
|
||||
self._pending_eof = True
|
||||
|
||||
|
||||
class QuicStream:
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: Optional[int] = None,
|
||||
max_stream_data_local: int = 0,
|
||||
max_stream_data_remote: int = 0,
|
||||
readable: bool = True,
|
||||
writable: bool = True,
|
||||
) -> None:
|
||||
self.is_blocked = False
|
||||
self.max_stream_data_local = max_stream_data_local
|
||||
self.max_stream_data_local_sent = max_stream_data_local
|
||||
self.max_stream_data_remote = max_stream_data_remote
|
||||
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
|
||||
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
|
||||
self.stream_id = stream_id
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
return self.receiver.is_finished and self.sender.is_finished
|
||||
Reference in New Issue
Block a user