Files
mofixx a5df3861fd Code
2025-08-08 10:41:30 +02:00

3624 lines
136 KiB
Python

import binascii
import logging
import os
from collections import deque
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Set,
Tuple,
)
from .. import tls
from ..buffer import (
UINT_VAR_MAX,
UINT_VAR_MAX_SIZE,
Buffer,
BufferReadError,
size_uint_var,
)
from . import events
from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from .congestion.base import K_GRANULARITY
from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback
from .logger import QuicLoggerTrace
from .packet import (
CONNECTION_ID_MAX_SIZE,
NON_ACK_ELICITING_FRAME_TYPES,
PROBING_FRAME_TYPES,
RETRY_INTEGRITY_TAG_SIZE,
STATELESS_RESET_TOKEN_SIZE,
QuicErrorCode,
QuicFrameType,
QuicHeader,
QuicPacketType,
QuicProtocolVersion,
QuicStreamFrame,
QuicTransportParameters,
QuicVersionInformation,
get_retry_integrity_tag,
get_spin_bit,
pretty_protocol_version,
pull_ack_frame,
pull_quic_header,
pull_quic_transport_parameters,
push_ack_frame,
push_quic_transport_parameters,
)
from .packet_builder import QuicDeliveryState, QuicPacketBuilder, QuicPacketBuilderStop
from .recovery import QuicPacketRecovery, QuicPacketSpace
from .stream import FinalSizeError, QuicStream, StreamFinishedError
logger = logging.getLogger("quic")

CRYPTO_BUFFER_SIZE = 16 * 1024

# Single-character aliases understood by EPOCHS(), e.g. EPOCHS("IH01").
EPOCH_SHORTCUTS = {
    "I": tls.Epoch.INITIAL,
    "H": tls.Epoch.HANDSHAKE,
    "0": tls.Epoch.ZERO_RTT,
    "1": tls.Epoch.ONE_RTT,
}

MAX_EARLY_DATA = 2**32 - 1  # 0xFFFFFFFF
MAX_REMOTE_CHALLENGES = 5
MAX_LOCAL_CHALLENGES = 5

# TLS secret labels: first row holds the client labels, second row the server
# labels, with None where no label applies.
# NOTE(review): columns appear to follow epoch order (INITIAL, 0-RTT,
# HANDSHAKE, 1-RTT) -- confirm against the code that indexes this table.
SECRETS_LABELS = [
    [
        None,
        "CLIENT_EARLY_TRAFFIC_SECRET",
        "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
        "CLIENT_TRAFFIC_SECRET_0",
    ],
    [
        None,
        None,
        "SERVER_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_TRAFFIC_SECRET_0",
    ],
]

STREAM_FLAGS = 0x07
STREAM_COUNT_MAX = 1 << 60  # 0x1000000000000000
UDP_HEADER_SIZE = 8
MAX_PENDING_RETIRES = 100
MAX_PENDING_CRYPTO = 512 * 1024  # in bytes

# Opaque type for a peer's network address.
NetworkAddress = Any

# frame sizes: capacity estimates, in bytes, for individual frame types
ACK_FRAME_CAPACITY = 64  # FIXME: this is arbitrary!
APPLICATION_CLOSE_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE  # + reason length
CONNECTION_LIMIT_FRAME_CAPACITY = 1 + UINT_VAR_MAX_SIZE
HANDSHAKE_DONE_FRAME_CAPACITY = 1
MAX_STREAM_DATA_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE
NEW_CONNECTION_ID_FRAME_CAPACITY = (
    1
    + 2 * UINT_VAR_MAX_SIZE
    + 1
    + CONNECTION_ID_MAX_SIZE
    + STATELESS_RESET_TOKEN_SIZE
)
PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8
PATH_RESPONSE_FRAME_CAPACITY = 1 + 8
PING_FRAME_CAPACITY = 1
RESET_STREAM_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE
RETIRE_CONNECTION_ID_CAPACITY = 1 + UINT_VAR_MAX_SIZE
STOP_SENDING_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE
STREAMS_BLOCKED_CAPACITY = 1 + UINT_VAR_MAX_SIZE
TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE  # + reason length
def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]:
    """
    Expand a string of epoch shortcut characters into a frozenset of epochs.

    Each character is looked up in EPOCH_SHORTCUTS, so for instance "IH"
    yields {INITIAL, HANDSHAKE}. An unknown character raises KeyError.
    """
    epochs = map(EPOCH_SHORTCUTS.__getitem__, shortcut)
    return frozenset(epochs)
def is_version_compatible(from_version: int, to_version: int) -> bool:
    """
    Return whether it is possible to perform compatible version negotiation
    from `from_version` to `to_version`.

    Only the pair {VERSION_1, VERSION_2} qualifies, in either direction.
    A version is not "compatible" with itself.
    """
    # Version 1 and version 2 are mutually compatible; no other pair is (yet).
    compatible_pair = {QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2}
    return {from_version, to_version} == compatible_pair
def dump_cid(cid: bytes) -> str:
    """Return the connection ID *cid* as a lowercase hexadecimal string."""
    return str(binascii.hexlify(cid), "ascii")
def get_epoch(packet_type: QuicPacketType) -> tls.Epoch:
    """
    Return the TLS epoch associated with packets of type *packet_type*.

    INITIAL, ZERO_RTT and HANDSHAKE packets map to the epoch of the same
    name. Every other packet type maps to ONE_RTT.
    """
    for long_header_type, epoch in (
        (QuicPacketType.INITIAL, tls.Epoch.INITIAL),
        (QuicPacketType.ZERO_RTT, tls.Epoch.ZERO_RTT),
        (QuicPacketType.HANDSHAKE, tls.Epoch.HANDSHAKE),
    ):
        if packet_type == long_header_type:
            return epoch
    return tls.Epoch.ONE_RTT
def stream_is_client_initiated(stream_id: int) -> bool:
    """
    Returns True if the stream is client initiated.

    The lowest bit of the stream ID is clear for client-initiated streams.
    """
    return (stream_id & 1) == 0
def stream_is_unidirectional(stream_id: int) -> bool:
    """
    Returns True if the stream is unidirectional.

    Bit 1 (value 2) of the stream ID is set for unidirectional streams.
    """
    return (stream_id & 2) != 0
class Limit:
    """
    Bookkeeping for a limit this endpoint sets (for example ``max_data`` or
    ``max_streams_bidi``), tied to the frame type that carries it.

    NOTE(review): ``sent`` presumably records the value last advertised to
    the peer and ``used`` how much of the limit has been consumed; confirm
    against the code that updates them.
    """

    def __init__(self, frame_type: int, name: str, value: int):
        self.frame_type = frame_type
        self.name = name
        # the advertised value starts out equal to the configured value
        self.sent = value
        self.used = 0
        self.value = value
class QuicConnectionError(Exception):
    """
    A protocol error which requires closing the connection.

    :meth:`QuicConnection.receive_datagram` catches it and calls ``close()``
    with the carried ``error_code``, ``frame_type`` and ``reason_phrase``.
    """

    def __init__(self, error_code: int, frame_type: int, reason_phrase: str):
        self.error_code = error_code
        self.frame_type = frame_type
        self.reason_phrase = reason_phrase

    def __str__(self) -> str:
        message = "Error: %d, reason: %s" % (self.error_code, self.reason_phrase)
        if self.frame_type is None:
            return message
        # only mention the frame type when one was recorded
        return message + ", frame_type: %s" % self.frame_type
class QuicConnectionAdapter(logging.LoggerAdapter):
    """
    Logger adapter which prefixes every message with ``[<id>]``, where
    ``<id>`` is taken from the adapter's ``extra`` mapping.
    """

    def process(self, msg: str, kwargs: Any) -> Tuple[str, Any]:
        prefix = "[%s] " % (self.extra["id"],)
        return prefix + str(msg), kwargs
@dataclass
class QuicConnectionId:
    """
    A connection ID together with its bookkeeping metadata.

    ``sequence_number`` is ``None`` for the placeholder peer connection ID
    created before the peer's real source connection ID is known.
    ``stateless_reset_token`` is ``None`` for host connection IDs created by
    a client.
    """

    cid: bytes
    sequence_number: Optional[int]
    stateless_reset_token: Optional[bytes] = b""
    # whether this connection ID has already been communicated to the peer
    was_sent: bool = False
class QuicConnectionState(Enum):
    """
    Lifecycle states of a QUIC connection.

    CLOSING, DRAINING and TERMINATED are end states: once one of them is
    reached, the connection neither processes received datagrams nor
    produces new ones.
    """

    # no packet from the peer has been processed yet
    FIRSTFLIGHT = 0
    CONNECTED = 1
    CLOSING = 2
    DRAINING = 3
    TERMINATED = 4
class QuicNetworkPath:
    """
    State for one network path (peer address) of a connection.

    Until the path is validated, the total amount sent on it may not exceed
    three times the amount received on it (anti-amplification limit).
    """

    def __init__(self, addr: NetworkAddress, is_validated: bool = False):
        self.addr: NetworkAddress = addr
        # byte counters feeding the anti-amplification limit
        self.bytes_received: int = 0
        self.bytes_sent: int = 0
        self.is_validated: bool = is_validated
        self.local_challenge_sent: bool = False
        self.remote_challenges: Deque[bytes] = deque()

    def can_send(self, size: int) -> bool:
        """Return whether *size* more bytes may be sent on this path."""
        amplification_limit = 3 * self.bytes_received
        return self.is_validated or self.bytes_sent + size <= amplification_limit
@dataclass
class QuicReceiveContext:
    """
    Context describing the packet whose payload is being processed; it is
    built per packet by ``receive_datagram`` and passed to
    ``_payload_received``.
    """

    # epoch the packet was received in
    epoch: tls.Epoch
    # destination connection ID of the packet, i.e. one of our own CIDs
    host_cid: bytes
    network_path: QuicNetworkPath
    # list collecting logged frames, or None when no QUIC logger is set
    quic_logger_frames: Optional[List[Any]]
    # time at which the packet was received
    time: float
    # version from the packet header (presumably None for short headers)
    version: Optional[int]
# Signature of the optional ``token_handler`` callback; only clients may
# provide one (see the assertions in QuicConnection.__init__).
QuicTokenHandler = Callable[[bytes], None]

# States in which the connection no longer handles or produces datagrams.
END_STATES = frozenset(
    {
        QuicConnectionState.CLOSING,
        QuicConnectionState.DRAINING,
        QuicConnectionState.TERMINATED,
    }
)
class QuicConnection:
"""
A QUIC connection.
The state machine is driven by three kinds of sources:
- the API user requesting data to be sent out (see :meth:`connect`,
:meth:`reset_stream`, :meth:`send_ping`, :meth:`send_datagram_frame`
and :meth:`send_stream_data`)
- data being received from the network (see :meth:`receive_datagram`)
- a timer firing (see :meth:`handle_timer`)
:param configuration: The QUIC configuration to use.
"""
def __init__(
    self,
    *,
    configuration: QuicConfiguration,
    original_destination_connection_id: Optional[bytes] = None,
    retry_source_connection_id: Optional[bytes] = None,
    session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None,
    session_ticket_handler: Optional[tls.SessionTicketHandler] = None,
    token_handler: Optional[QuicTokenHandler] = None,
) -> None:
    """
    Set up the connection state. No datagram is produced here.

    :param configuration: The QUIC configuration to use.
    :param original_destination_connection_id: Required for servers and
        forbidden for clients (a client uses its own random initial peer
        connection ID instead).
    :param retry_source_connection_id: Forbidden for clients.
    :param session_ticket_fetcher: Optional TLS session ticket fetcher.
    :param session_ticket_handler: Optional TLS session ticket handler.
    :param token_handler: Optional callback; forbidden for servers.
    """
    assert configuration.max_datagram_size >= SMALLEST_MAX_DATAGRAM_SIZE, (
        "The smallest allowed maximum datagram size is "
        f"{SMALLEST_MAX_DATAGRAM_SIZE} bytes"
    )
    if configuration.is_client:
        assert (
            original_destination_connection_id is None
        ), "Cannot set original_destination_connection_id for a client"
        assert (
            retry_source_connection_id is None
        ), "Cannot set retry_source_connection_id for a client"
    else:
        assert token_handler is None, "Cannot set `token_handler` for a server"
        assert (
            configuration.token == b""
        ), "Cannot set `configuration.token` for a server"
        assert (
            configuration.certificate is not None
        ), "SSL certificate is required for a server"
        assert (
            configuration.private_key is not None
        ), "SSL private key is required for a server"
        assert (
            original_destination_connection_id is not None
        ), "original_destination_connection_id is required for a server"

    # configuration
    self._configuration = configuration
    self._is_client = is_client = configuration.is_client
    cid_length = configuration.connection_id_length

    self._ack_delay = K_GRANULARITY
    self._close_at: Optional[float] = None
    self._close_event: Optional[events.ConnectionTerminated] = None
    self._connect_called = False
    self._cryptos: Dict[tls.Epoch, CryptoPair] = {}
    self._cryptos_initial: Dict[int, CryptoPair] = {}
    self._crypto_buffers: Dict[tls.Epoch, Buffer] = {}
    self._crypto_frame_type: Optional[int] = None
    self._crypto_packet_version: Optional[int] = None
    self._crypto_retransmitted = False
    self._crypto_streams: Dict[tls.Epoch, QuicStream] = {}
    self._events: Deque[events.QuicEvent] = deque()
    self._handshake_complete = False
    self._handshake_confirmed = False

    # our own connection IDs; only servers generate a stateless reset token
    self._host_cids = [
        QuicConnectionId(
            cid=os.urandom(cid_length),
            sequence_number=0,
            stateless_reset_token=None if is_client else os.urandom(16),
            was_sent=True,
        )
    ]
    self.host_cid = self._host_cids[0].cid
    self._host_cid_seq = 1

    # local limits and parameters
    self._local_ack_delay_exponent = 3
    self._local_active_connection_id_limit = 8
    self._local_challenges: Dict[bytes, QuicNetworkPath] = {}
    self._local_initial_source_connection_id = self._host_cids[0].cid
    self._local_max_data = Limit(
        frame_type=QuicFrameType.MAX_DATA,
        name="max_data",
        value=configuration.max_data,
    )
    self._local_max_stream_data_bidi_local = configuration.max_stream_data
    self._local_max_stream_data_bidi_remote = configuration.max_stream_data
    self._local_max_stream_data_uni = configuration.max_stream_data
    self._local_max_streams_bidi = Limit(
        frame_type=QuicFrameType.MAX_STREAMS_BIDI,
        name="max_streams_bidi",
        value=128,
    )
    self._local_max_streams_uni = Limit(
        frame_type=QuicFrameType.MAX_STREAMS_UNI, name="max_streams_uni", value=128
    )
    # clients open even-numbered streams, servers odd-numbered ones
    self._local_next_stream_id_bidi = 0 if is_client else 1
    self._local_next_stream_id_uni = 2 if is_client else 3

    self._loss_at: Optional[float] = None
    self._max_datagram_size = configuration.max_datagram_size
    self._network_paths: List[QuicNetworkPath] = []
    self._pacing_at: Optional[float] = None
    self._packet_number = 0

    # placeholder peer CID; its real value is learned from the first
    # packet received (sequence_number stays None until then)
    self._peer_cid = QuicConnectionId(cid=os.urandom(cid_length), sequence_number=None)
    self._peer_cid_available: List[QuicConnectionId] = []
    self._peer_cid_sequence_numbers: Set[int] = set([0])
    self._peer_retire_prior_to = 0
    self._peer_token = configuration.token
    self._quic_logger: Optional[QuicLoggerTrace] = None

    # remote limits and parameters, filled in from the peer's parameters
    self._remote_ack_delay_exponent = 3
    self._remote_active_connection_id_limit = 2
    self._remote_initial_source_connection_id: Optional[bytes] = None
    self._remote_max_idle_timeout: Optional[float] = None  # seconds
    self._remote_max_data = 0
    self._remote_max_data_used = 0
    self._remote_max_datagram_frame_size: Optional[int] = None
    self._remote_max_stream_data_bidi_local = 0
    self._remote_max_stream_data_bidi_remote = 0
    self._remote_max_stream_data_uni = 0
    self._remote_max_streams_bidi = 0
    self._remote_max_streams_uni = 0
    self._remote_version_information: Optional[QuicVersionInformation] = None

    self._retry_count = 0
    self._retry_source_connection_id = retry_source_connection_id
    self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {}
    self._spin_bit = False
    self._spin_highest_pn = 0
    self._state = QuicConnectionState.FIRSTFLIGHT

    # streams
    self._streams: Dict[int, QuicStream] = {}
    self._streams_queue: List[QuicStream] = []
    self._streams_blocked_bidi: List[QuicStream] = []
    self._streams_blocked_uni: List[QuicStream] = []
    self._streams_finished: Set[int] = set()

    # versions
    self._version: Optional[int] = None
    self._version_negotiated_compatible = False
    self._version_negotiated_incompatible = False

    if is_client:
        self._original_destination_connection_id = self._peer_cid.cid
    else:
        self._original_destination_connection_id = (
            original_destination_connection_id
        )

    # logging
    self._logger = QuicConnectionAdapter(
        logger, {"id": dump_cid(self._original_destination_connection_id)}
    )
    if configuration.quic_logger:
        self._quic_logger = configuration.quic_logger.start_trace(
            is_client=configuration.is_client,
            odcid=self._original_destination_connection_id,
        )

    # loss recovery
    self._loss = QuicPacketRecovery(
        congestion_control_algorithm=configuration.congestion_control_algorithm,
        initial_rtt=configuration.initial_rtt,
        max_datagram_size=self._max_datagram_size,
        peer_completed_address_validation=not is_client,
        quic_logger=self._quic_logger,
        send_probe=self._send_probe,
        logger=self._logger,
    )

    # things to send
    self._close_pending = False
    self._datagrams_pending: Deque[bytes] = deque()
    self._handshake_done_pending = False
    self._ping_pending: List[int] = []
    self._probe_pending = False
    self._retire_connection_ids: List[int] = []
    self._streams_blocked_pending = False

    # callbacks
    self._session_ticket_fetcher = session_ticket_fetcher
    self._session_ticket_handler = session_ticket_handler
    self._token_handler = token_handler

    # frame handlers: frame type -> (handler, epochs in which it is allowed)
    handlers = {
        0x00: (self._handle_padding_frame, EPOCHS("IH01")),
        0x01: (self._handle_ping_frame, EPOCHS("IH01")),
        0x02: (self._handle_ack_frame, EPOCHS("IH1")),
        0x03: (self._handle_ack_frame, EPOCHS("IH1")),
        0x04: (self._handle_reset_stream_frame, EPOCHS("01")),
        0x05: (self._handle_stop_sending_frame, EPOCHS("01")),
        0x06: (self._handle_crypto_frame, EPOCHS("IH1")),
        0x07: (self._handle_new_token_frame, EPOCHS("1")),
    }
    # the eight STREAM frame types 0x08-0x0F all share one handler
    for stream_frame_type in range(0x08, 0x10):
        handlers[stream_frame_type] = (self._handle_stream_frame, EPOCHS("01"))
    handlers.update(
        {
            0x10: (self._handle_max_data_frame, EPOCHS("01")),
            0x11: (self._handle_max_stream_data_frame, EPOCHS("01")),
            0x12: (self._handle_max_streams_bidi_frame, EPOCHS("01")),
            0x13: (self._handle_max_streams_uni_frame, EPOCHS("01")),
            0x14: (self._handle_data_blocked_frame, EPOCHS("01")),
            0x15: (self._handle_stream_data_blocked_frame, EPOCHS("01")),
            0x16: (self._handle_streams_blocked_frame, EPOCHS("01")),
            0x17: (self._handle_streams_blocked_frame, EPOCHS("01")),
            0x18: (self._handle_new_connection_id_frame, EPOCHS("01")),
            0x19: (self._handle_retire_connection_id_frame, EPOCHS("01")),
            0x1A: (self._handle_path_challenge_frame, EPOCHS("01")),
            0x1B: (self._handle_path_response_frame, EPOCHS("01")),
            0x1C: (self._handle_connection_close_frame, EPOCHS("IH01")),
            0x1D: (self._handle_connection_close_frame, EPOCHS("01")),
            0x1E: (self._handle_handshake_done_frame, EPOCHS("1")),
            0x30: (self._handle_datagram_frame, EPOCHS("01")),
            0x31: (self._handle_datagram_frame, EPOCHS("01")),
        }
    )
    self.__frame_handlers = handlers
@property
def configuration(self) -> QuicConfiguration:
    """The :class:`QuicConfiguration` this connection was created with."""
    return self._configuration
@property
def original_destination_connection_id(self) -> bytes:
    """
    The original destination connection ID of the connection.

    A client uses the random ID it generated for its first peer connection
    ID. A server uses the value passed to its constructor.
    """
    return self._original_destination_connection_id
def change_connection_id(self) -> None:
    """
    Switch to the next available connection ID and retire
    the previous one.

    Nothing happens when no spare peer connection ID is available.

    .. aioquic_transmit::
    """
    if not self._peer_cid_available:
        return
    # retire the CID being abandoned, then adopt the next spare one
    self._retire_peer_cid(self._peer_cid)
    self._consume_peer_cid()
def close(
    self,
    error_code: int = QuicErrorCode.NO_ERROR,
    frame_type: Optional[int] = None,
    reason_phrase: str = "",
) -> None:
    """
    Close the connection.

    Only the first call has an effect, and only while the connection is not
    already in an end state. The CONNECTION_CLOSE frame itself is written
    by the next call to :meth:`datagrams_to_send`.

    .. aioquic_transmit::

    :param error_code: An error code indicating why the connection is
                       being closed.
    :param frame_type: The frame type associated with the error, if any.
    :param reason_phrase: A human-readable explanation of why the
                          connection is being closed.
    """
    if self._close_event is not None or self._state in END_STATES:
        return
    self._close_event = events.ConnectionTerminated(
        error_code=error_code,
        frame_type=frame_type,
        reason_phrase=reason_phrase,
    )
    self._close_pending = True
def connect(self, addr: NetworkAddress, now: float) -> None:
    """
    Initiate the TLS handshake.

    This method can only be called for clients and a single time.
    The peer's address becomes the (already validated) primary network path,
    and the configured original version is used if set, otherwise the
    first supported version.

    .. aioquic_transmit::

    :param addr: The network address of the remote peer.
    :param now: The current time.
    """
    assert (
        self._is_client and not self._connect_called
    ), "connect() can only be called for clients and a single time"
    self._connect_called = True

    self._network_paths = [QuicNetworkPath(addr, is_validated=True)]
    original_version = self._configuration.original_version
    self._version = (
        original_version
        if original_version is not None
        else self._configuration.supported_versions[0]
    )
    self._connect(now=now)
def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]:
    """
    Return a list of `(data, addr)` tuples of datagrams which need to be
    sent, and the network address to which they need to be sent.

    After calling this method call :meth:`get_timer` to know when the next
    timer needs to be set.

    When a close is pending, only CONNECTION_CLOSE frames are built and the
    closing period starts. Otherwise handshake and application data are
    written, bounded by the congestion window and, on an unvalidated path,
    by the anti-amplification limit. Nothing is returned once the
    connection has reached an end state.

    :param now: The current time.
    """
    path = self._network_paths[0]

    if self._state in END_STATES:
        return []

    # build datagrams
    builder = QuicPacketBuilder(
        host_cid=self.host_cid,
        is_client=self._is_client,
        max_datagram_size=self._max_datagram_size,
        packet_number=self._packet_number,
        peer_cid=self._peer_cid.cid,
        peer_token=self._peer_token,
        quic_logger=self._quic_logger,
        spin_bit=self._spin_bit,
        version=self._version,
    )

    if self._close_pending:
        # Until the handshake is confirmed, CONNECTION_CLOSE is also
        # attempted in the INITIAL and HANDSHAKE epochs; each epoch is only
        # used if its send keys are valid.
        if self._handshake_confirmed:
            close_epochs = [(tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT)]
        else:
            close_epochs = [
                (tls.Epoch.INITIAL, QuicPacketType.INITIAL),
                (tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE),
                (tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT),
            ]
        for epoch, packet_type in close_epochs:
            crypto = self._cryptos[epoch]
            if not crypto.send.is_valid():
                continue
            builder.start_packet(packet_type, crypto)
            self._write_connection_close_frame(
                builder=builder,
                epoch=epoch,
                error_code=self._close_event.error_code,
                frame_type=self._close_event.frame_type,
                reason_phrase=self._close_event.reason_phrase,
            )
        self._logger.info(
            "Connection close sent (code 0x%X, reason %s)",
            self._close_event.error_code,
            self._close_event.reason_phrase,
        )
        self._close_pending = False
        self._close_begin(is_initiator=True, now=now)
    else:
        # congestion control: stay within the window, but let a pending
        # probe use at least one full datagram
        builder.max_flight_bytes = (
            self._loss.congestion_window - self._loss.bytes_in_flight
        )
        if self._probe_pending and builder.max_flight_bytes < self._max_datagram_size:
            builder.max_flight_bytes = self._max_datagram_size

        # limit data on un-validated network paths
        if not path.is_validated:
            builder.max_total_bytes = path.bytes_received * 3 - path.bytes_sent

        try:
            if not self._handshake_confirmed:
                for epoch in (tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE):
                    self._write_handshake(builder, epoch, now)
            self._write_application(builder, path, now)
        except QuicPacketBuilderStop:
            # the builder asked to stop; whatever was built is flushed below
            pass

    datagrams, packets = builder.flush()

    if datagrams:
        self._packet_number = builder.packet_number

        # register packets with loss recovery
        handshake_sent = False
        for packet in packets:
            packet.sent_time = now
            self._loss.on_packet_sent(packet=packet, space=self._spaces[packet.epoch])
            if packet.epoch == tls.Epoch.HANDSHAKE:
                handshake_sent = True

            # log packet
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_sent",
                    data={
                        "frames": packet.quic_logger_frames,
                        "header": {
                            "packet_number": packet.packet_number,
                            "packet_type": self._quic_logger.packet_type(
                                packet.packet_type
                            ),
                            "scid": (
                                ""
                                if packet.packet_type == QuicPacketType.ONE_RTT
                                else dump_cid(self.host_cid)
                            ),
                            "dcid": dump_cid(self._peer_cid.cid),
                        },
                        "raw": {"length": packet.sent_bytes},
                    },
                )

        # a client drops its INITIAL keys once it has sent a HANDSHAKE packet
        if handshake_sent and self._is_client:
            self._discard_epoch(tls.Epoch.INITIAL)

    # return datagrams to send and the destination network address
    outgoing = []
    for datagram in datagrams:
        size = len(datagram)
        path.bytes_sent += size
        outgoing.append((datagram, path.addr))
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="transport",
                event="datagrams_sent",
                data={
                    "count": 1,
                    "raw": [
                        {
                            "length": UDP_HEADER_SIZE + size,
                            "payload_length": size,
                        }
                    ],
                },
            )
    return outgoing
def get_next_available_stream_id(self, is_unidirectional=False) -> int:
    """
    Return the stream ID for the next stream created by this endpoint.

    :param is_unidirectional: If true, return the next unidirectional
        stream ID instead of the next bidirectional one.
    """
    return (
        self._local_next_stream_id_uni
        if is_unidirectional
        else self._local_next_stream_id_bidi
    )
def get_timer(self) -> Optional[float]:
    """
    Return the time at which the timer should fire or None if no timer is needed.

    The result is the close / idle deadline. Unless the connection is in an
    end state, that deadline is lowered to the earliest pending ACK,
    loss-detection or pacing deadline. The loss-detection deadline is also
    cached in ``_loss_at`` for :meth:`handle_timer`.
    """
    deadline = self._close_at
    if self._state in END_STATES:
        return deadline

    # ack timers of all packet spaces
    for space in self._loss.spaces:
        if space.ack_at is not None and space.ack_at < deadline:
            deadline = space.ack_at

    # loss detection timer
    self._loss_at = self._loss.get_loss_detection_time()
    if self._loss_at is not None and self._loss_at < deadline:
        deadline = self._loss_at

    # pacing timer
    if self._pacing_at is not None and self._pacing_at < deadline:
        deadline = self._pacing_at

    return deadline
def handle_timer(self, now: float) -> None:
    """
    Handle the timer.

    Once the close deadline (end of the closing period or idle timeout) has
    passed, the connection is terminated. The close reason is "Idle timeout"
    unless another reason was already recorded. Otherwise, if the
    loss-detection deadline cached by :meth:`get_timer` has passed, loss
    detection is run.

    .. aioquic_transmit::

    :param now: The current time.
    """
    close_deadline_reached = now >= self._close_at
    if close_deadline_reached:
        if self._close_event is None:
            # nothing else closed the connection: it simply went idle
            self._close_event = events.ConnectionTerminated(
                error_code=QuicErrorCode.INTERNAL_ERROR,
                frame_type=QuicFrameType.PADDING,
                reason_phrase="Idle timeout",
            )
        self._close_end()
        return

    loss_timer_due = self._loss_at is not None and now >= self._loss_at
    if loss_timer_due:
        self._logger.debug("Loss detection triggered")
        self._loss.on_loss_detection_timeout(now=now)
def next_event(self) -> Optional[events.QuicEvent]:
    """
    Retrieve the next event from the event buffer.

    Returns `None` if there are no buffered events. Events are returned in
    the order they were queued.
    """
    if not self._events:
        return None
    return self._events.popleft()
def _idle_timeout(self) -> float:
    """
    Return the effective idle timeout in seconds (RFC 9000 section 10.1).

    This is the smaller of the local and the peer's maximum idle timeout,
    or just the local one if the peer expressed no preference. It is never
    less than three probe timeouts.
    """
    preferences = [self._configuration.idle_timeout]
    if self._remote_max_idle_timeout is not None:
        preferences.append(self._remote_max_idle_timeout)
    # the smaller preference wins, but not below 3 * PTO
    return max(min(preferences), 3 * self._loss.get_probe_timeout())
def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None:
    """
    Handle an incoming datagram.

    A datagram may carry several coalesced QUIC packets, which are parsed,
    decrypted and processed one after the other. Packets that fail
    decryption are skipped individually. The following stop processing of
    the rest of the datagram:

    - a malformed header;
    - an undersized datagram carrying an INITIAL at a server;
    - an unknown destination CID (for clients, or on HANDSHAKE packets);
    - an unsupported version;
    - a reserved-bit violation;
    - a version negotiation or retry packet.

    .. aioquic_transmit::

    :param data: The datagram which was received.
    :param addr: The network address from which the datagram was received.
    :param now: The current time.
    """
    payload_length = len(data)

    # stop handling packets when closing
    if self._state in END_STATES:
        return

    # log datagram
    if self._quic_logger is not None:
        self._quic_logger.log_event(
            category="transport",
            event="datagrams_received",
            data={
                "count": 1,
                "raw": [
                    {
                        "length": UDP_HEADER_SIZE + payload_length,
                        "payload_length": payload_length,
                    }
                ],
            },
        )

    # For anti-amplification purposes, servers need to keep track of the
    # amount of data received on unvalidated network paths. We must count the
    # entire datagram size regardless of whether packets are processed or
    # dropped.
    #
    # This is particularly important when talking to clients who pad
    # datagrams containing INITIAL packets by appending bytes after the
    # long-header packets, which is legitimate behaviour.
    #
    # https://datatracker.ietf.org/doc/html/rfc9000#section-8.1
    network_path = self._find_network_path(addr)
    if not network_path.is_validated:
        network_path.bytes_received += payload_length

    # for servers, arm the idle timeout on the first datagram
    if self._close_at is None:
        self._close_at = now + self._idle_timeout()

    buf = Buffer(data=data)
    while not buf.eof():
        start_off = buf.tell()
        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length
            )
        except ValueError:
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "header_parse_error",
                        "raw": {"length": buf.capacity - start_off},
                    },
                )
            return

        # RFC 9000 section 14.1 requires servers to drop all initial packets
        # contained in a datagram smaller than 1200 bytes.
        if (
            not self._is_client
            and header.packet_type == QuicPacketType.INITIAL
            and payload_length < SMALLEST_MAX_DATAGRAM_SIZE
        ):
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "initial_packet_datagram_too_small",
                        "raw": {"length": header.packet_length},
                    },
                )
            return

        # Check destination CID matches.
        destination_cid_seq: Optional[int] = None
        for connection_id in self._host_cids:
            if header.destination_cid == connection_id.cid:
                destination_cid_seq = connection_id.sequence_number
                break
        if (
            self._is_client or header.packet_type == QuicPacketType.HANDSHAKE
        ) and destination_cid_seq is None:
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "unknown_connection_id",
                        "raw": {"length": header.packet_length},
                    },
                )
            return

        # Handle version negotiation packet.
        if header.packet_type == QuicPacketType.VERSION_NEGOTIATION:
            self._receive_version_negotiation_packet(header=header, now=now)
            return

        # Check long header packet protocol version.
        if (
            header.version is not None
            and header.version not in self._configuration.supported_versions
        ):
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "unsupported_version",
                        "raw": {"length": header.packet_length},
                    },
                )
            return

        # Handle retry packet.
        if header.packet_type == QuicPacketType.RETRY:
            self._receive_retry_packet(
                header=header,
                packet_without_tag=buf.data_slice(
                    start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                ),
                now=now,
            )
            return

        crypto_frame_required = False

        # Server initialization.
        if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT:
            assert (
                header.packet_type == QuicPacketType.INITIAL
            ), "first packet must be INITIAL"
            crypto_frame_required = True
            self._network_paths = [network_path]
            self._version = header.version
            self._initialize(header.destination_cid)

        # Determine crypto and packet space.
        epoch = get_epoch(header.packet_type)
        if epoch == tls.Epoch.INITIAL:
            crypto = self._cryptos_initial[header.version]
        else:
            crypto = self._cryptos[epoch]
        if epoch == tls.Epoch.ZERO_RTT:
            space = self._spaces[tls.Epoch.ONE_RTT]
        else:
            space = self._spaces[epoch]

        # decrypt packet
        encrypted_off = buf.tell() - start_off
        end_off = start_off + header.packet_length
        buf.seek(end_off)

        try:
            plain_header, plain_payload, packet_number = crypto.decrypt_packet(
                data[start_off:end_off], encrypted_off, space.expected_packet_number
            )
        except KeyUnavailableError as exc:
            self._logger.debug(exc)
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "key_unavailable",
                        "raw": {"length": header.packet_length},
                    },
                )

            # If a client receives HANDSHAKE or 1-RTT packets before it has
            # handshake keys, it can assume that the server's INITIAL was lost.
            if (
                self._is_client
                and epoch in (tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT)
                and not self._crypto_retransmitted
            ):
                self._loss.reschedule_data(now=now)
                self._crypto_retransmitted = True
            continue
        except CryptoError as exc:
            self._logger.debug(exc)
            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_dropped",
                    data={
                        "trigger": "payload_decrypt_error",
                        "raw": {"length": header.packet_length},
                    },
                )
            continue

        # check reserved bits
        if header.packet_type == QuicPacketType.ONE_RTT:
            reserved_mask = 0x18
        else:
            reserved_mask = 0x0C
        if plain_header[0] & reserved_mask:
            self.close(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=QuicFrameType.PADDING,
                reason_phrase="Reserved bits must be zero",
            )
            return

        # log packet
        quic_logger_frames: Optional[List[Dict]] = None
        if self._quic_logger is not None:
            quic_logger_frames = []
            self._quic_logger.log_event(
                category="transport",
                event="packet_received",
                data={
                    "frames": quic_logger_frames,
                    "header": {
                        "packet_number": packet_number,
                        "packet_type": self._quic_logger.packet_type(
                            header.packet_type
                        ),
                        "dcid": dump_cid(header.destination_cid),
                        "scid": dump_cid(header.source_cid),
                    },
                    "raw": {"length": header.packet_length},
                },
            )

        # Raise the expected packet number. It must always be the largest
        # packet number received so far plus one (RFC 9000 appendix A.3), so
        # it also has to advance when the packet carries exactly the expected
        # number -- hence ">=" rather than ">".
        if packet_number >= space.expected_packet_number:
            space.expected_packet_number = packet_number + 1

        # discard initial keys and packet space
        if not self._is_client and epoch == tls.Epoch.HANDSHAKE:
            self._discard_epoch(tls.Epoch.INITIAL)

        # update state
        if self._peer_cid.sequence_number is None:
            self._peer_cid.cid = header.source_cid
            self._peer_cid.sequence_number = 0

        if self._state == QuicConnectionState.FIRSTFLIGHT:
            self._remote_initial_source_connection_id = header.source_cid
            self._set_state(QuicConnectionState.CONNECTED)

        # update spin bit
        if (
            header.packet_type == QuicPacketType.ONE_RTT
            and packet_number > self._spin_highest_pn
        ):
            spin_bit = get_spin_bit(plain_header[0])
            if self._is_client:
                self._spin_bit = not spin_bit
            else:
                self._spin_bit = spin_bit
            self._spin_highest_pn = packet_number

            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="connectivity",
                    event="spin_bit_updated",
                    data={"state": self._spin_bit},
                )

        # handle payload
        context = QuicReceiveContext(
            epoch=epoch,
            host_cid=header.destination_cid,
            network_path=network_path,
            quic_logger_frames=quic_logger_frames,
            time=now,
            version=header.version,
        )
        try:
            is_ack_eliciting, is_probing = self._payload_received(
                context, plain_payload, crypto_frame_required=crypto_frame_required
            )
        except QuicConnectionError as exc:
            self._logger.warning(exc)
            self.close(
                error_code=exc.error_code,
                frame_type=exc.frame_type,
                reason_phrase=exc.reason_phrase,
            )
        if self._state in END_STATES or self._close_pending:
            return

        # update idle timeout
        self._close_at = now + self._idle_timeout()

        # handle migration
        if (
            not self._is_client
            and context.host_cid != self.host_cid
            and epoch == tls.Epoch.ONE_RTT
        ):
            # On the server side destination_cid_seq may be None when the
            # DCID is not one of ours, so format it with %s (not %d).
            self._logger.debug(
                "Peer switching to CID %s (%s)",
                dump_cid(context.host_cid),
                destination_cid_seq,
            )
            self.host_cid = context.host_cid
            self.change_connection_id()

        # update network path
        if not network_path.is_validated and epoch == tls.Epoch.HANDSHAKE:
            self._logger.debug(
                "Network path %s validated by handshake", network_path.addr
            )
            network_path.is_validated = True
        if network_path not in self._network_paths:
            self._network_paths.append(network_path)
        idx = self._network_paths.index(network_path)
        if idx and not is_probing and packet_number > space.largest_received_packet:
            self._logger.debug("Network path %s promoted", network_path.addr)
            self._network_paths.pop(idx)
            self._network_paths.insert(0, network_path)

        # record packet as received
        if not space.discarded:
            if packet_number > space.largest_received_packet:
                space.largest_received_packet = packet_number
                space.largest_received_time = now
            space.ack_queue.add(packet_number)
            if is_ack_eliciting and space.ack_at is None:
                space.ack_at = now + self._ack_delay
def request_key_update(self) -> None:
    """
    Request an update of the encryption keys.

    Only permitted once the handshake has completed. The 1-RTT key pair is
    asked to update its keys.

    .. aioquic_transmit::
    """
    assert self._handshake_complete, "cannot change key before handshake completes"
    one_rtt_crypto = self._cryptos[tls.Epoch.ONE_RTT]
    one_rtt_crypto.update_key()
def reset_stream(self, stream_id: int, error_code: int) -> None:
    """
    Abruptly terminate the sending part of a stream.

    The stream is created first if it does not exist yet.

    .. aioquic_transmit::

    :param stream_id: The stream's ID.
    :param error_code: An error code indicating why the stream is being reset.
    """
    sender = self._get_or_create_stream_for_send(stream_id).sender
    sender.reset(error_code)
def send_ping(self, uid: int) -> None:
    """
    Send a PING frame to the peer.

    The ping's ID is queued in ``_ping_pending``; the frame itself is
    written when data is next sent.

    .. aioquic_transmit::

    :param uid: A unique ID for this PING.
    """
    pending_pings = self._ping_pending
    pending_pings.append(uid)
def send_datagram_frame(self, data: bytes) -> None:
    """
    Send a DATAGRAM frame.

    The payload is queued in ``_datagrams_pending`` until data is next sent.

    .. aioquic_transmit::

    :param data: The data to be sent.
    """
    pending_datagrams = self._datagrams_pending
    pending_datagrams.append(data)
def send_stream_data(
    self, stream_id: int, data: bytes, end_stream: bool = False
) -> None:
    """
    Send data on the specific stream.

    The stream is created first if it does not exist yet; the data is
    handed to the stream's sender.

    .. aioquic_transmit::

    :param stream_id: The stream's ID.
    :param data: The data to be sent.
    :param end_stream: If set to `True`, the FIN bit will be set.
    """
    sender = self._get_or_create_stream_for_send(stream_id).sender
    sender.write(data, end_stream=end_stream)
def stop_stream(self, stream_id: int, error_code: int) -> None:
    """
    Request termination of the receiving part of a stream.

    .. aioquic_transmit::

    :param stream_id: The stream's ID.
    :param error_code: An error code indicating why the stream is being stopped.
    :raises ValueError: If the stream cannot receive data (a locally
        initiated unidirectional stream) or is not known.
    """
    can_receive = self._stream_can_receive(stream_id)
    if not can_receive:
        raise ValueError(
            "Cannot stop receiving on a local-initiated unidirectional stream"
        )

    stream = self._streams.get(stream_id)
    if stream is None:
        raise ValueError("Cannot stop receiving on an unknown stream")

    stream.receiver.stop(error_code)
# Private
    def _alpn_handler(self, alpn_protocol: str) -> None:
        """
        Callback which is invoked by the TLS engine at most once, when the
        ALPN negotiation completes.

        At this point, TLS extensions have been received so we can parse the
        transport parameters.

        For servers, this is also where compatible version negotiation happens:
        the first client-advertised version that we support and that is
        compatible with the current one is adopted, the Initial crypto pair is
        switched accordingly and our transport parameters are re-serialized.
        Finally a ProtocolNegotiated event is queued for the application.

        :param alpn_protocol: The negotiated ALPN protocol.
        :raises QuicConnectionError: CRYPTO_ERROR (missing_extension) if the
            peer did not send a QUIC transport parameters extension.
        """
        # Parse the remote transport parameters.
        for ext_type, ext_data in self.tls.received_extensions:
            if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS:
                self._parse_transport_parameters(ext_data)
                break
        else:
            # for/else: the loop finished without finding the extension.
            raise QuicConnectionError(
                error_code=QuicErrorCode.CRYPTO_ERROR
                + tls.AlertDescription.missing_extension,
                frame_type=self._crypto_frame_type,
                reason_phrase="No QUIC transport parameters received",
            )
        # For servers, determine the Negotiated Version.
        if not self._is_client and not self._version_negotiated_compatible:
            # Without version information from the client, the current version
            # is kept as is.
            if self._remote_version_information is not None:
                # Pick the first version we support in the client's available versions,
                # which is compatible with the current version.
                for version in self._remote_version_information.available_versions:
                    if version == self._version:
                        # Stay with the current version.
                        break
                    elif (
                        version in self._configuration.supported_versions
                        and is_version_compatible(self._version, version)
                    ):
                        # Change version.
                        self._version = version
                        # Initial keys were pre-computed for every supported
                        # version, so switching is a dictionary lookup.
                        self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[
                            version
                        ]
                        # Update our transport parameters to reflect the chosen version.
                        self.tls.handshake_extensions = [
                            (
                                tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                                self._serialize_transport_parameters(),
                            )
                        ]
                        break
            # Negotiation is done only once per connection.
            self._version_negotiated_compatible = True
            self._logger.info(
                "Negotiated protocol version %s", pretty_protocol_version(self._version)
            )
        # Notify the application.
        self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol))
def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None:
"""
Check the specified stream can receive data or raises a QuicConnectionError.
"""
if not self._stream_can_receive(stream_id):
raise QuicConnectionError(
error_code=QuicErrorCode.STREAM_STATE_ERROR,
frame_type=frame_type,
reason_phrase="Stream is send-only",
)
def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None:
"""
Check the specified stream can send data or raises a QuicConnectionError.
"""
if not self._stream_can_send(stream_id):
raise QuicConnectionError(
error_code=QuicErrorCode.STREAM_STATE_ERROR,
frame_type=frame_type,
reason_phrase="Stream is receive-only",
)
def _consume_peer_cid(self) -> None:
"""
Update the destination connection ID by taking the next
available connection ID provided by the peer.
"""
self._peer_cid = self._peer_cid_available.pop(0)
self._logger.debug(
"Switching to CID %s (%d)",
dump_cid(self._peer_cid.cid),
self._peer_cid.sequence_number,
)
def _close_begin(self, is_initiator: bool, now: float) -> None:
"""
Begin the close procedure.
"""
self._close_at = now + 3 * self._loss.get_probe_timeout()
if is_initiator:
self._set_state(QuicConnectionState.CLOSING)
else:
self._set_state(QuicConnectionState.DRAINING)
def _close_end(self) -> None:
"""
End the close procedure.
"""
self._close_at = None
for epoch in self._spaces.keys():
self._discard_epoch(epoch)
self._events.append(self._close_event)
self._set_state(QuicConnectionState.TERMINATED)
# signal log end
if self._quic_logger is not None:
self._configuration.quic_logger.end_trace(self._quic_logger)
self._quic_logger = None
def _connect(self, now: float) -> None:
"""
Start the client handshake.
"""
assert self._is_client
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="version_information",
data={
"client_versions": self._configuration.supported_versions,
"chosen_version": self._version,
},
)
self._quic_logger.log_event(
category="transport",
event="alpn_information",
data={"client_alpns": self._configuration.alpn_protocols},
)
self._close_at = now + self._idle_timeout()
self._initialize(self._peer_cid.cid)
self.tls.handle_message(b"", self._crypto_buffers)
self._push_crypto_data()
def _discard_epoch(self, epoch: tls.Epoch) -> None:
if not self._spaces[epoch].discarded:
self._logger.debug("Discarding epoch %s", epoch)
self._cryptos[epoch].teardown()
if epoch == tls.Epoch.INITIAL:
# Tear the crypto pairs, but do not log the event,
# to avoid duplicate log entries.
for crypto in self._cryptos_initial.values():
crypto.recv._teardown_cb = NoCallback
crypto.send._teardown_cb = NoCallback
crypto.teardown()
self._loss.discard_space(self._spaces[epoch])
self._spaces[epoch].discarded = True
def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath:
# check existing network paths
for idx, network_path in enumerate(self._network_paths):
if network_path.addr == addr:
return network_path
# new network path
network_path = QuicNetworkPath(addr)
self._logger.debug("Network path %s discovered", network_path.addr)
return network_path
def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream:
"""
Get or create a stream in response to a received frame.
"""
if stream_id in self._streams_finished:
# the stream was created, but its state was since discarded
raise StreamFinishedError
stream = self._streams.get(stream_id, None)
if stream is None:
# check initiator
if stream_is_client_initiated(stream_id) == self._is_client:
raise QuicConnectionError(
error_code=QuicErrorCode.STREAM_STATE_ERROR,
frame_type=frame_type,
reason_phrase="Wrong stream initiator",
)
# determine limits
if stream_is_unidirectional(stream_id):
max_stream_data_local = self._local_max_stream_data_uni
max_stream_data_remote = 0
max_streams = self._local_max_streams_uni
else:
max_stream_data_local = self._local_max_stream_data_bidi_remote
max_stream_data_remote = self._remote_max_stream_data_bidi_local
max_streams = self._local_max_streams_bidi
# check max streams
stream_count = (stream_id // 4) + 1
if stream_count > max_streams.value:
raise QuicConnectionError(
error_code=QuicErrorCode.STREAM_LIMIT_ERROR,
frame_type=frame_type,
reason_phrase="Too many streams open",
)
elif stream_count > max_streams.used:
max_streams.used = stream_count
# create stream
self._logger.debug("Stream %d created by peer" % stream_id)
stream = self._streams[stream_id] = QuicStream(
stream_id=stream_id,
max_stream_data_local=max_stream_data_local,
max_stream_data_remote=max_stream_data_remote,
writable=not stream_is_unidirectional(stream_id),
)
self._streams_queue.append(stream)
return stream
def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream:
"""
Get or create a QUIC stream in order to send data to the peer.
This always occurs as a result of an API call.
"""
if not self._stream_can_send(stream_id):
raise ValueError("Cannot send data on peer-initiated unidirectional stream")
stream = self._streams.get(stream_id, None)
if stream is None:
# check initiator
if stream_is_client_initiated(stream_id) != self._is_client:
raise ValueError("Cannot send data on unknown peer-initiated stream")
# determine limits
if stream_is_unidirectional(stream_id):
max_stream_data_local = 0
max_stream_data_remote = self._remote_max_stream_data_uni
max_streams = self._remote_max_streams_uni
streams_blocked = self._streams_blocked_uni
else:
max_stream_data_local = self._local_max_stream_data_bidi_local
max_stream_data_remote = self._remote_max_stream_data_bidi_remote
max_streams = self._remote_max_streams_bidi
streams_blocked = self._streams_blocked_bidi
# create stream
is_unidirectional = stream_is_unidirectional(stream_id)
stream = self._streams[stream_id] = QuicStream(
stream_id=stream_id,
max_stream_data_local=max_stream_data_local,
max_stream_data_remote=max_stream_data_remote,
readable=not is_unidirectional,
)
self._streams_queue.append(stream)
if is_unidirectional:
self._local_next_stream_id_uni = stream_id + 4
else:
self._local_next_stream_id_bidi = stream_id + 4
# mark stream as blocked if needed
if stream_id // 4 >= max_streams:
stream.is_blocked = True
streams_blocked.append(stream)
self._streams_blocked_pending = True
return stream
def _handle_session_ticket(self, session_ticket: tls.SessionTicket) -> None:
if (
session_ticket.max_early_data_size is not None
and session_ticket.max_early_data_size != MAX_EARLY_DATA
):
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=QuicFrameType.CRYPTO,
reason_phrase="Invalid max_early_data value %s"
% session_ticket.max_early_data_size,
)
self._session_ticket_handler(session_ticket)
    def _initialize(self, peer_cid: bytes) -> None:
        """
        Create the TLS context, crypto pairs, crypto streams and packet spaces.

        :param peer_cid: Connection ID used to set up the Initial keys for every
            supported version.
        """
        # TLS
        self.tls = tls.Context(
            alpn_protocols=self._configuration.alpn_protocols,
            cadata=self._configuration.cadata,
            cafile=self._configuration.cafile,
            capath=self._configuration.capath,
            # NOTE(review): this is the only setting read via `self.configuration`
            # rather than `self._configuration`; presumably a property returning
            # the same object — confirm.
            cipher_suites=self.configuration.cipher_suites,
            is_client=self._is_client,
            logger=self._logger,
            max_early_data=None if self._is_client else MAX_EARLY_DATA,
            server_name=self._configuration.server_name,
            verify_mode=self._configuration.verify_mode,
        )
        self.tls.certificate = self._configuration.certificate
        self.tls.certificate_chain = self._configuration.certificate_chain
        self.tls.certificate_private_key = self._configuration.private_key
        # Our transport parameters travel as a TLS extension.
        self.tls.handshake_extensions = [
            (
                tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                self._serialize_transport_parameters(),
            )
        ]
        # TLS session resumption (clients only, and only for a valid ticket
        # issued for the same server name)
        session_ticket = self._configuration.session_ticket
        if (
            self._is_client
            and session_ticket is not None
            and session_ticket.is_valid
            and session_ticket.server_name == self._configuration.server_name
        ):
            self.tls.session_ticket = self._configuration.session_ticket
            # parse saved QUIC transport parameters - for 0-RTT
            if session_ticket.max_early_data_size == MAX_EARLY_DATA:
                for ext_type, ext_data in session_ticket.other_extensions:
                    if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS:
                        self._parse_transport_parameters(
                            ext_data, from_session_ticket=True
                        )
                        break
        # TLS callbacks
        self.tls.alpn_cb = self._alpn_handler
        if self._session_ticket_fetcher is not None:
            self.tls.get_session_ticket_cb = self._session_ticket_fetcher
        if self._session_ticket_handler is not None:
            self.tls.new_session_ticket_cb = self._handle_session_ticket
        self.tls.update_traffic_key_cb = self._update_traffic_key
        # packet spaces
        def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair:
            # Build a CryptoPair whose key setup/teardown is logged to qlog.
            epoch_name = ["initial", "0rtt", "handshake", "1rtt"][epoch.value]
            secret_names = [
                "server_%s_secret" % epoch_name,
                "client_%s_secret" % epoch_name,
            ]
            # A client receives with the server's secret and sends with its
            # own; a server does the opposite.
            recv_secret_name = secret_names[not self._is_client]
            send_secret_name = secret_names[self._is_client]
            return CryptoPair(
                recv_setup_cb=partial(self._log_key_updated, recv_secret_name),
                recv_teardown_cb=partial(self._log_key_retired, recv_secret_name),
                send_setup_cb=partial(self._log_key_updated, send_secret_name),
                send_teardown_cb=partial(self._log_key_retired, send_secret_name),
            )
        # To enable version negotiation, setup encryption keys for all
        # our supported versions.
        self._cryptos_initial = {}
        for version in self._configuration.supported_versions:
            pair = CryptoPair()
            pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version)
            self._cryptos_initial[version] = pair
        self._cryptos = dict(
            (epoch, create_crypto_pair(epoch))
            for epoch in (
                tls.Epoch.ZERO_RTT,
                tls.Epoch.HANDSHAKE,
                tls.Epoch.ONE_RTT,
            )
        )
        # The INITIAL epoch uses the pre-computed pair for the current version.
        self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version]
        # 0-RTT has no crypto buffer, crypto stream or packet space of its own
        # in these tables.
        self._crypto_buffers = {
            tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE),
            tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE),
            tls.Epoch.ONE_RTT: Buffer(capacity=CRYPTO_BUFFER_SIZE),
        }
        self._crypto_streams = {
            tls.Epoch.INITIAL: QuicStream(),
            tls.Epoch.HANDSHAKE: QuicStream(),
            tls.Epoch.ONE_RTT: QuicStream(),
        }
        self._spaces = {
            tls.Epoch.INITIAL: QuicPacketSpace(),
            tls.Epoch.HANDSHAKE: QuicPacketSpace(),
            tls.Epoch.ONE_RTT: QuicPacketSpace(),
        }
        self._loss.spaces = list(self._spaces.values())
def _handle_ack_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle an ACK frame.
"""
ack_rangeset, ack_delay_encoded = pull_ack_frame(buf)
if frame_type == QuicFrameType.ACK_ECN:
buf.pull_uint_var()
buf.pull_uint_var()
buf.pull_uint_var()
ack_delay = (ack_delay_encoded << self._remote_ack_delay_exponent) / 1000000
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_ack_frame(ack_rangeset, ack_delay)
)
# check whether peer completed address validation
if not self._loss.peer_completed_address_validation and context.epoch in (
tls.Epoch.HANDSHAKE,
tls.Epoch.ONE_RTT,
):
self._loss.peer_completed_address_validation = True
self._loss.on_ack_received(
ack_rangeset=ack_rangeset,
ack_delay=ack_delay,
now=context.time,
space=self._spaces[context.epoch],
)
def _handle_connection_close_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a CONNECTION_CLOSE frame.
"""
error_code = buf.pull_uint_var()
if frame_type == QuicFrameType.TRANSPORT_CLOSE:
frame_type = buf.pull_uint_var()
else:
frame_type = None
reason_length = buf.pull_uint_var()
try:
reason_phrase = buf.pull_bytes(reason_length).decode("utf8")
except UnicodeDecodeError:
reason_phrase = ""
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_connection_close_frame(
error_code=error_code,
frame_type=frame_type,
reason_phrase=reason_phrase,
)
)
self._logger.info(
"Connection close received (code 0x%X, reason %s)",
error_code,
reason_phrase,
)
if self._close_event is None:
self._close_event = events.ConnectionTerminated(
error_code=error_code,
frame_type=frame_type,
reason_phrase=reason_phrase,
)
self._close_begin(is_initiator=False, now=context.time)
def _handle_crypto_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a CRYPTO frame.
"""
offset = buf.pull_uint_var()
length = buf.pull_uint_var()
if offset + length > UINT_VAR_MAX:
raise QuicConnectionError(
error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
frame_type=frame_type,
reason_phrase="offset + length cannot exceed 2^62 - 1",
)
frame = QuicStreamFrame(offset=offset, data=buf.pull_bytes(length))
# Log the frame.
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_crypto_frame(frame)
)
stream = self._crypto_streams[context.epoch]
pending = offset + length - stream.receiver.starting_offset()
if pending > MAX_PENDING_CRYPTO:
raise QuicConnectionError(
error_code=QuicErrorCode.CRYPTO_BUFFER_EXCEEDED,
frame_type=frame_type,
reason_phrase="too much crypto buffering",
)
event = stream.receiver.handle_frame(frame)
if event is not None:
# Pass data to TLS layer, which may cause calls to:
# - _alpn_handler
# - _update_traffic_key
self._crypto_frame_type = frame_type
self._crypto_packet_version = context.version
try:
self.tls.handle_message(event.data, self._crypto_buffers)
self._push_crypto_data()
except tls.Alert as exc:
raise QuicConnectionError(
error_code=QuicErrorCode.CRYPTO_ERROR + int(exc.description),
frame_type=frame_type,
reason_phrase=str(exc),
)
# Update the current epoch.
if not self._handshake_complete and self.tls.state in [
tls.State.CLIENT_POST_HANDSHAKE,
tls.State.SERVER_POST_HANDSHAKE,
]:
self._handshake_complete = True
# for servers, the handshake is now confirmed
if not self._is_client:
self._discard_epoch(tls.Epoch.HANDSHAKE)
self._handshake_confirmed = True
self._handshake_done_pending = True
self._replenish_connection_ids()
self._events.append(
events.HandshakeCompleted(
alpn_protocol=self.tls.alpn_negotiated,
early_data_accepted=self.tls.early_data_accepted,
session_resumed=self.tls.session_resumed,
)
)
self._unblock_streams(is_unidirectional=False)
self._unblock_streams(is_unidirectional=True)
self._logger.info(
"ALPN negotiated protocol %s", self.tls.alpn_negotiated
)
else:
self._logger.info(
"Duplicate CRYPTO data received for epoch %s", context.epoch
)
# If a server receives duplicate CRYPTO in an INITIAL packet,
# it can assume the client did not receive the server's CRYPTO.
if (
not self._is_client
and context.epoch == tls.Epoch.INITIAL
and not self._crypto_retransmitted
):
self._loss.reschedule_data(now=context.time)
self._crypto_retransmitted = True
def _handle_data_blocked_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a DATA_BLOCKED frame.
"""
limit = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_data_blocked_frame(limit=limit)
)
def _handle_datagram_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a DATAGRAM frame.
"""
start = buf.tell()
if frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH:
length = buf.pull_uint_var()
else:
length = buf.capacity - start
data = buf.pull_bytes(length)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_datagram_frame(length=length)
)
# check frame is allowed
if (
self._configuration.max_datagram_frame_size is None
or buf.tell() - start >= self._configuration.max_datagram_frame_size
):
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Unexpected DATAGRAM frame",
)
self._events.append(events.DatagramFrameReceived(data=data))
def _handle_handshake_done_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a HANDSHAKE_DONE frame.
"""
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_handshake_done_frame()
)
if not self._is_client:
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Clients must not send HANDSHAKE_DONE frames",
)
# for clients, the handshake is now confirmed
if not self._handshake_confirmed:
self._discard_epoch(tls.Epoch.HANDSHAKE)
self._handshake_confirmed = True
self._loss.peer_completed_address_validation = True
def _handle_max_data_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a MAX_DATA frame.
This adjusts the total amount of we can send to the peer.
"""
max_data = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_connection_limit_frame(
frame_type=frame_type, maximum=max_data
)
)
if max_data > self._remote_max_data:
self._logger.debug("Remote max_data raised to %d", max_data)
self._remote_max_data = max_data
def _handle_max_stream_data_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a MAX_STREAM_DATA frame.
This adjusts the amount of data we can send on a specific stream.
"""
stream_id = buf.pull_uint_var()
max_stream_data = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_max_stream_data_frame(
maximum=max_stream_data, stream_id=stream_id
)
)
# check stream direction
self._assert_stream_can_send(frame_type, stream_id)
stream = self._get_or_create_stream(frame_type, stream_id)
if max_stream_data > stream.max_stream_data_remote:
self._logger.debug(
"Stream %d remote max_stream_data raised to %d",
stream_id,
max_stream_data,
)
stream.max_stream_data_remote = max_stream_data
def _handle_max_streams_bidi_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a MAX_STREAMS_BIDI frame.
This raises number of bidirectional streams we can initiate to the peer.
"""
max_streams = buf.pull_uint_var()
if max_streams > STREAM_COUNT_MAX:
raise QuicConnectionError(
error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
frame_type=frame_type,
reason_phrase="Maximum Streams cannot exceed 2^60",
)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_connection_limit_frame(
frame_type=frame_type, maximum=max_streams
)
)
if max_streams > self._remote_max_streams_bidi:
self._logger.debug("Remote max_streams_bidi raised to %d", max_streams)
self._remote_max_streams_bidi = max_streams
self._unblock_streams(is_unidirectional=False)
def _handle_max_streams_uni_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a MAX_STREAMS_UNI frame.
This raises number of unidirectional streams we can initiate to the peer.
"""
max_streams = buf.pull_uint_var()
if max_streams > STREAM_COUNT_MAX:
raise QuicConnectionError(
error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
frame_type=frame_type,
reason_phrase="Maximum Streams cannot exceed 2^60",
)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_connection_limit_frame(
frame_type=frame_type, maximum=max_streams
)
)
if max_streams > self._remote_max_streams_uni:
self._logger.debug("Remote max_streams_uni raised to %d", max_streams)
self._remote_max_streams_uni = max_streams
self._unblock_streams(is_unidirectional=True)
    def _handle_new_connection_id_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a NEW_CONNECTION_ID frame.

        The new peer CID is recorded unless it is already known or already
        below the retire threshold. Every CID below the (monotonic)
        retire-prior-to value is retired, including the active one, which is
        then replaced by the next available CID.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR for an empty or
            oversized CID, PROTOCOL_VIOLATION if retire_prior_to exceeds the
            sequence number, CONNECTION_ID_LIMIT_ERROR if too many CIDs are
            active or too many retirements are pending.
        """
        sequence_number = buf.pull_uint_var()
        retire_prior_to = buf.pull_uint_var()
        length = buf.pull_uint8()
        connection_id = buf.pull_bytes(length)
        stateless_reset_token = buf.pull_bytes(STATELESS_RESET_TOKEN_SIZE)
        if not connection_id or len(connection_id) > CONNECTION_ID_MAX_SIZE:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Length must be greater than 0 and less than 20",
            )
        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_new_connection_id_frame(
                    connection_id=connection_id,
                    retire_prior_to=retire_prior_to,
                    sequence_number=sequence_number,
                    stateless_reset_token=stateless_reset_token,
                )
            )
        # sanity check
        if retire_prior_to > sequence_number:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Retire Prior To is greater than Sequence Number",
            )
        # only accept retire_prior_to if it is bigger than the one we know
        self._peer_retire_prior_to = max(retire_prior_to, self._peer_retire_prior_to)
        # determine which CIDs to retire
        change_cid = False
        retire = [
            cid
            for cid in self._peer_cid_available
            if cid.sequence_number < self._peer_retire_prior_to
        ]
        # The active CID is retired too (put first in the list) when it falls
        # below the threshold; a replacement is picked further down.
        if self._peer_cid.sequence_number < self._peer_retire_prior_to:
            change_cid = True
            retire.insert(0, self._peer_cid)
        # update available CIDs
        self._peer_cid_available = [
            cid
            for cid in self._peer_cid_available
            if cid.sequence_number >= self._peer_retire_prior_to
        ]
        # Record the new CID only if it is not already retired and its sequence
        # number has not been seen before (so duplicate frames are ignored).
        if (
            sequence_number >= self._peer_retire_prior_to
            and sequence_number not in self._peer_cid_sequence_numbers
        ):
            self._peer_cid_available.append(
                QuicConnectionId(
                    cid=connection_id,
                    sequence_number=sequence_number,
                    stateless_reset_token=stateless_reset_token,
                )
            )
            self._peer_cid_sequence_numbers.add(sequence_number)
        # retire previous CIDs
        for quic_connection_id in retire:
            self._retire_peer_cid(quic_connection_id)
        # assign new CID if we retired the active one
        if change_cid:
            self._consume_peer_cid()
        # check number of active connection IDs, including the selected one
        if 1 + len(self._peer_cid_available) > self._local_active_connection_id_limit:
            raise QuicConnectionError(
                error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR,
                frame_type=frame_type,
                reason_phrase="Too many active connection IDs",
            )
        # Check the number of retired connection IDs pending, though with a safer limit
        # than the 2x recommended in section 5.1.2 of the RFC. Note that we are doing
        # the check here and not in _retire_peer_cid() because we know the frame type to
        # use here, and because it is the new connection id path that is potentially
        # dangerous. We may transiently go a bit over the limit due to unacked frames
        # getting added back to the list, but that's ok as it is bounded.
        if len(self._retire_connection_ids) > min(
            self._local_active_connection_id_limit * 4, MAX_PENDING_RETIRES
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR,
                frame_type=frame_type,
                reason_phrase="Too many pending retired connection IDs",
            )
def _handle_new_token_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a NEW_TOKEN frame.
"""
length = buf.pull_uint_var()
token = buf.pull_bytes(length)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_new_token_frame(token=token)
)
if not self._is_client:
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Clients must not send NEW_TOKEN frames",
)
if self._token_handler is not None:
self._token_handler(token)
def _handle_padding_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a PADDING frame.
"""
# consume padding
pos = buf.tell()
for byte in buf.data_slice(pos, buf.capacity):
if byte:
break
pos += 1
buf.seek(pos)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(self._quic_logger.encode_padding_frame())
def _handle_path_challenge_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a PATH_CHALLENGE frame.
"""
data = buf.pull_bytes(8)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_path_challenge_frame(data=data)
)
context.network_path.remote_challenges.append(data)
def _handle_path_response_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a PATH_RESPONSE frame.
"""
data = buf.pull_bytes(8)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_path_response_frame(data=data)
)
try:
network_path = self._local_challenges.pop(data)
except KeyError:
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Response does not match challenge",
)
self._logger.debug("Network path %s validated by challenge", network_path.addr)
network_path.is_validated = True
def _handle_ping_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a PING frame.
"""
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(self._quic_logger.encode_ping_frame())
def _handle_reset_stream_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a RESET_STREAM frame.
"""
stream_id = buf.pull_uint_var()
error_code = buf.pull_uint_var()
final_size = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_reset_stream_frame(
error_code=error_code, final_size=final_size, stream_id=stream_id
)
)
# check stream direction
self._assert_stream_can_receive(frame_type, stream_id)
# check flow-control limits
stream = self._get_or_create_stream(frame_type, stream_id)
if final_size > stream.max_stream_data_local:
raise QuicConnectionError(
error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
frame_type=frame_type,
reason_phrase="Over stream data limit",
)
newly_received = max(0, final_size - stream.receiver.highest_offset)
if self._local_max_data.used + newly_received > self._local_max_data.value:
raise QuicConnectionError(
error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
frame_type=frame_type,
reason_phrase="Over connection data limit",
)
# process reset
self._logger.info(
"Stream %d reset by peer (error code %d, final size %d)",
stream_id,
error_code,
final_size,
)
try:
event = stream.receiver.handle_reset(
error_code=error_code, final_size=final_size
)
except FinalSizeError as exc:
raise QuicConnectionError(
error_code=QuicErrorCode.FINAL_SIZE_ERROR,
frame_type=frame_type,
reason_phrase=str(exc),
)
if event is not None:
self._events.append(event)
self._local_max_data.used += newly_received
def _handle_retire_connection_id_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a RETIRE_CONNECTION_ID frame.
"""
sequence_number = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_retire_connection_id_frame(sequence_number)
)
if sequence_number >= self._host_cid_seq:
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Cannot retire unknown connection ID",
)
# find the connection ID by sequence number
for index, connection_id in enumerate(self._host_cids):
if connection_id.sequence_number == sequence_number:
if connection_id.cid == context.host_cid:
raise QuicConnectionError(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
frame_type=frame_type,
reason_phrase="Cannot retire current connection ID",
)
self._logger.debug(
"Peer retiring CID %s (%d)",
dump_cid(connection_id.cid),
connection_id.sequence_number,
)
del self._host_cids[index]
self._events.append(
events.ConnectionIdRetired(connection_id=connection_id.cid)
)
break
# issue a new connection ID
self._replenish_connection_ids()
def _handle_stop_sending_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a STOP_SENDING frame.
"""
stream_id = buf.pull_uint_var()
error_code = buf.pull_uint_var() # application error code
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_stop_sending_frame(
error_code=error_code, stream_id=stream_id
)
)
# check stream direction
self._assert_stream_can_send(frame_type, stream_id)
# reset the stream
stream = self._get_or_create_stream(frame_type, stream_id)
stream.sender.reset(error_code=QuicErrorCode.NO_ERROR)
self._events.append(
events.StopSendingReceived(error_code=error_code, stream_id=stream_id)
)
def _handle_stream_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a STREAM frame.
"""
stream_id = buf.pull_uint_var()
if frame_type & 4:
offset = buf.pull_uint_var()
else:
offset = 0
if frame_type & 2:
length = buf.pull_uint_var()
else:
length = buf.capacity - buf.tell()
if offset + length > UINT_VAR_MAX:
raise QuicConnectionError(
error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
frame_type=frame_type,
reason_phrase="offset + length cannot exceed 2^62 - 1",
)
frame = QuicStreamFrame(
offset=offset, data=buf.pull_bytes(length), fin=bool(frame_type & 1)
)
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_stream_frame(frame, stream_id=stream_id)
)
# check stream direction
self._assert_stream_can_receive(frame_type, stream_id)
# check flow-control limits
stream = self._get_or_create_stream(frame_type, stream_id)
if offset + length > stream.max_stream_data_local:
raise QuicConnectionError(
error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
frame_type=frame_type,
reason_phrase="Over stream data limit",
)
newly_received = max(0, offset + length - stream.receiver.highest_offset)
if self._local_max_data.used + newly_received > self._local_max_data.value:
raise QuicConnectionError(
error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
frame_type=frame_type,
reason_phrase="Over connection data limit",
)
# process data
try:
event = stream.receiver.handle_frame(frame)
except FinalSizeError as exc:
raise QuicConnectionError(
error_code=QuicErrorCode.FINAL_SIZE_ERROR,
frame_type=frame_type,
reason_phrase=str(exc),
)
if event is not None:
self._events.append(event)
self._local_max_data.used += newly_received
def _handle_stream_data_blocked_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
"""
Handle a STREAM_DATA_BLOCKED frame.
"""
stream_id = buf.pull_uint_var()
limit = buf.pull_uint_var()
# log frame
if self._quic_logger is not None:
context.quic_logger_frames.append(
self._quic_logger.encode_stream_data_blocked_frame(
limit=limit, stream_id=stream_id
)
)
# check stream direction
self._assert_stream_can_receive(frame_type, stream_id)
self._get_or_create_stream(frame_type, stream_id)
def _handle_streams_blocked_frame(
    self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
    """
    Process an incoming STREAMS_BLOCKED frame (bidirectional or unidirectional).

    The advertised stream limit is validated and logged; no other state is
    changed.

    :raises QuicConnectionError: FRAME_ENCODING_ERROR if the limit exceeds
        STREAM_COUNT_MAX.
    """
    blocked_limit = buf.pull_uint_var()
    if blocked_limit > STREAM_COUNT_MAX:
        raise QuicConnectionError(
            error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
            frame_type=frame_type,
            reason_phrase="Maximum Streams cannot exceed 2^60",
        )

    if self._quic_logger is None:
        return

    unidirectional = frame_type == QuicFrameType.STREAMS_BLOCKED_UNI
    context.quic_logger_frames.append(
        self._quic_logger.encode_streams_blocked_frame(
            is_unidirectional=unidirectional,
            limit=blocked_limit,
        )
    )
def _log_key_retired(self, key_type: str, trigger: str) -> None:
"""
Log a key retirement.
"""
if self._quic_logger is not None:
self._quic_logger.log_event(
category="security",
event="key_retired",
data={"key_type": key_type, "trigger": trigger},
)
def _log_key_updated(self, key_type: str, trigger: str) -> None:
"""
Log a key update.
"""
if self._quic_logger is not None:
self._quic_logger.log_event(
category="security",
event="key_updated",
data={"key_type": key_type, "trigger": trigger},
)
def _on_ack_delivery(
    self, delivery: QuicDeliveryState, space: QuicPacketSpace, highest_acked: int
) -> None:
    """
    Delivery callback for a sent ACK frame.

    When the peer acknowledges our ACK, the packet number range
    ``[0, highest_acked + 1)`` is removed from the space's ACK queue, so it
    is not reported again. Nothing is done if the ACK was lost.
    """
    if delivery != QuicDeliveryState.ACKED:
        return
    space.ack_queue.subtract(0, highest_acked + 1)
def _on_connection_limit_delivery(
    self, delivery: QuicDeliveryState, limit: Limit
) -> None:
    """
    Delivery callback for a MAX_DATA or MAX_STREAMS frame.

    On loss, the recorded ``sent`` value is reset to 0. The value then
    differs from ``limit.value``, so _write_connection_limits emits the
    frame again.
    """
    if delivery == QuicDeliveryState.ACKED:
        return
    limit.sent = 0
def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None:
    """
    Delivery callback for a HANDSHAKE_DONE frame.

    A lost frame is scheduled again by re-arming ``_handshake_done_pending``.
    """
    was_lost = delivery != QuicDeliveryState.ACKED
    if was_lost:
        self._handshake_done_pending = True
def _on_max_stream_data_delivery(
    self, delivery: QuicDeliveryState, stream: QuicStream
) -> None:
    """
    Delivery callback for a MAX_STREAM_DATA frame.

    On loss, the stream's ``max_stream_data_local_sent`` is reset to 0.
    Presumably this makes the stream-limit writer send the frame again;
    the writer is not visible here.
    """
    if delivery == QuicDeliveryState.ACKED:
        return
    stream.max_stream_data_local_sent = 0
def _on_new_connection_id_delivery(
    self, delivery: QuicDeliveryState, connection_id: QuicConnectionId
) -> None:
    """
    Delivery callback for a NEW_CONNECTION_ID frame.

    On loss, the connection ID is marked as not sent, so the application
    packet writer announces it again.
    """
    if delivery == QuicDeliveryState.ACKED:
        return
    connection_id.was_sent = False
def _on_ping_delivery(
    self, delivery: QuicDeliveryState, uids: Sequence[int]
) -> None:
    """
    Delivery callback for a PING frame.

    :param uids: identifiers of the user-requested pings carried by the
        frame; empty for internally generated (probe) pings.

    When the frame is acknowledged, a PingAcknowledged event is queued for
    every uid. When it is lost, the uids are queued again for sending.
    """
    if delivery != QuicDeliveryState.ACKED:
        self._ping_pending.extend(uids)
        return

    self._logger.debug("Received PING%s response", " (probe)" if not uids else "")
    for acknowledged_uid in uids:
        self._events.append(events.PingAcknowledged(uid=acknowledged_uid))
def _on_retire_connection_id_delivery(
    self, delivery: QuicDeliveryState, sequence_number: int
) -> None:
    """
    Delivery callback for a RETIRE_CONNECTION_ID frame.

    On loss, the sequence number is queued again in
    ``_retire_connection_ids`` so the frame is re-sent.
    """
    if delivery == QuicDeliveryState.ACKED:
        return
    self._retire_connection_ids.append(sequence_number)
def _payload_received(
    self,
    context: QuicReceiveContext,
    plain: bytes,
    crypto_frame_required: bool = False,
) -> Tuple[bool, bool]:
    """
    Parse a decrypted packet payload and dispatch each frame to its handler.

    :param context: receive context forwarded to every frame handler.
    :param plain: the decrypted payload.
    :param crypto_frame_required: when True, the payload must contain at
        least one CRYPTO frame. This applies to a client's first Initial
        packet (RFC 9000 section 17.2.2).
    :return: ``(is_ack_eliciting, is_probing)``. The first is True if any
        frame is ack-eliciting; the second is True only if every frame is
        a probing frame.
    :raises QuicConnectionError: if a frame type is malformed, unknown or
        not allowed in this epoch, if a frame fails to parse, if the payload
        holds no frames, or if a required CRYPTO frame is missing.
    """
    buf = Buffer(data=plain)
    seen_frame = False
    seen_crypto = False
    ack_eliciting = False
    # None until the first frame; afterwards True only while every frame
    # seen so far is a probing frame.
    probing = None

    while not buf.eof():
        try:
            frame_type = buf.pull_uint_var()
        except BufferReadError:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=None,
                reason_phrase="Malformed frame type",
            )

        try:
            frame_handler, frame_epochs = self.__frame_handlers[frame_type]
        except KeyError:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Unknown frame type",
            )

        if context.epoch not in frame_epochs:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Unexpected frame type",
            )

        try:
            frame_handler(context, frame_type, buf)
        except BufferReadError:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Failed to parse frame",
            )
        except StreamFinishedError:
            # The stream's state has already been discarded: ignore the frame.
            pass

        seen_frame = True
        seen_crypto = seen_crypto or frame_type == QuicFrameType.CRYPTO
        ack_eliciting = (
            ack_eliciting or frame_type not in NON_ACK_ELICITING_FRAME_TYPES
        )
        if frame_type not in PROBING_FRAME_TYPES:
            probing = False
        elif probing is None:
            probing = True

    if not seen_frame:
        raise QuicConnectionError(
            error_code=QuicErrorCode.PROTOCOL_VIOLATION,
            frame_type=QuicFrameType.PADDING,
            reason_phrase="Packet contains no frames",
        )

    # RFC 9000 - 17.2.2. Initial Packet
    # The first packet sent by a client always includes a CRYPTO frame.
    if crypto_frame_required and not seen_crypto:
        raise QuicConnectionError(
            error_code=QuicErrorCode.PROTOCOL_VIOLATION,
            frame_type=QuicFrameType.PADDING,
            reason_phrase="Packet contains no CRYPTO frame",
        )

    return ack_eliciting, bool(probing)
def _receive_retry_packet(
    self, header: QuicHeader, packet_without_tag: bytes, now: float
) -> None:
    """
    Handle a Retry packet.

    A Retry is accepted only when all of the following hold:

    - we are a client and have not retried yet;
    - the packet is addressed to our host CID;
    - its integrity tag matches the one computed over the packet with the
      current peer CID.

    An accepted Retry adopts the server's new source CID and stores the
    token, then restarts the connection attempt. Any other Retry is logged
    as dropped.
    """
    is_acceptable = (
        self._is_client
        and not self._retry_count
        and header.destination_cid == self.host_cid
        and header.integrity_tag
        == get_retry_integrity_tag(
            packet_without_tag,
            self._peer_cid.cid,
            version=header.version,
        )
    )

    if not is_acceptable:
        # Unexpected or invalid retry packet.
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="transport",
                event="packet_dropped",
                data={
                    "trigger": "unexpected_packet",
                    "raw": {"length": header.packet_length},
                },
            )
        return

    if self._quic_logger is not None:
        self._quic_logger.log_event(
            category="transport",
            event="packet_received",
            data={
                "frames": [],
                "header": {
                    "packet_type": "retry",
                    "scid": dump_cid(header.source_cid),
                    "dcid": dump_cid(header.destination_cid),
                },
                "raw": {"length": header.packet_length},
            },
        )

    self._peer_cid.cid = header.source_cid
    self._peer_token = header.token
    self._retry_count += 1
    self._retry_source_connection_id = header.source_cid
    self._logger.info("Retrying with token (%d bytes)" % len(header.token))
    self._connect(now=now)
def _receive_version_negotiation_packet(
    self, header: QuicHeader, now: float
) -> None:
    """
    Handle a Version Negotiation packet ("Incompatible Version Negotiation").

    Only a client acts on it, only while still in its first flight, and
    only if it has not already switched version this way. Any other Version
    Negotiation packet is logged as dropped. See
    https://datatracker.ietf.org/doc/html/rfc9368#section-4

    A packet that lists our current version is ignored. Otherwise the first
    version from our configured list that the server also offers is adopted:
    the packet number is reset to 0 and the connection attempt restarts. If
    there is no common version, the connection is terminated.

    See https://datatracker.ietf.org/doc/html/rfc9368#section-2.2
    """
    may_negotiate = (
        self._is_client
        and self._state == QuicConnectionState.FIRSTFLIGHT
        and not self._version_negotiated_incompatible
    )
    if not may_negotiate:
        # Unexpected version negotiation packet.
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="transport",
                event="packet_dropped",
                data={
                    "trigger": "unexpected_packet",
                    "raw": {"length": header.packet_length},
                },
            )
        return

    if self._quic_logger is not None:
        self._quic_logger.log_event(
            category="transport",
            event="packet_received",
            data={
                "frames": [],
                "header": {
                    "packet_type": self._quic_logger.packet_type(
                        header.packet_type
                    ),
                    "scid": dump_cid(header.source_cid),
                    "dcid": dump_cid(header.destination_cid),
                },
                "raw": {"length": header.packet_length},
            },
        )

    # A Version Negotiation packet listing the version we are already using
    # must be ignored.
    if self._version in header.supported_versions:
        self._logger.warning(
            "Version negotiation packet contains protocol version %s",
            pretty_protocol_version(self._version),
        )
        return

    # Our most preferred configured version that the server also supports.
    chosen_version = next(
        (
            candidate
            for candidate in self._configuration.supported_versions
            if candidate in header.supported_versions
        ),
        None,
    )

    if self._quic_logger is not None:
        self._quic_logger.log_event(
            category="transport",
            event="version_information",
            data={
                "server_versions": header.supported_versions,
                "client_versions": self._configuration.supported_versions,
                "chosen_version": chosen_version,
            },
        )

    if chosen_version is None:
        self._logger.error("Could not find a common protocol version")
        self._close_event = events.ConnectionTerminated(
            error_code=QuicErrorCode.INTERNAL_ERROR,
            frame_type=QuicFrameType.PADDING,
            reason_phrase="Could not find a common protocol version",
        )
        self._close_end()
        return

    self._packet_number = 0
    self._version = chosen_version
    self._version_negotiated_incompatible = True
    self._logger.info(
        "Retrying with protocol version %s",
        pretty_protocol_version(self._version),
    )
    self._connect(now=now)
def _replenish_connection_ids(self) -> None:
    """
    Top up the pool of locally issued connection IDs.

    New IDs are generated until the pool holds
    ``min(8, peer's active_connection_id_limit)`` entries. Each ID gets:

    - a random CID of the configured length;
    - the next sequence number;
    - a random 16-byte stateless reset token.
    """
    target_count = min(8, self._remote_active_connection_id_limit)
    while len(self._host_cids) < target_count:
        fresh_cid = QuicConnectionId(
            cid=os.urandom(self._configuration.connection_id_length),
            sequence_number=self._host_cid_seq,
            stateless_reset_token=os.urandom(16),
        )
        self._host_cids.append(fresh_cid)
        self._host_cid_seq += 1
def _retire_peer_cid(self, connection_id: QuicConnectionId) -> None:
    """
    Queue a peer-issued (destination) connection ID for retirement.

    The sequence number is appended to ``_retire_connection_ids``. The
    application packet writer then emits RETIRE_CONNECTION_ID frames for
    queued entries.
    """
    sequence_number = connection_id.sequence_number
    queued_after = len(self._retire_connection_ids) + 1
    self._logger.debug(
        "Retiring CID %s (%d) [%d]",
        dump_cid(connection_id.cid),
        sequence_number,
        queued_after,
    )
    self._retire_connection_ids.append(sequence_number)
def _push_crypto_data(self) -> None:
for epoch, buf in self._crypto_buffers.items():
self._crypto_streams[epoch].sender.write(buf.data)
buf.seek(0)
def _send_probe(self) -> None:
self._probe_pending = True
def _parse_transport_parameters(
    self, data: bytes, from_session_ticket: bool = False
) -> None:
    """
    Parse, validate and apply the peer's QUIC transport parameters.

    `from_session_ticket` is `True` when restoring transport parameters
    saved from a previous connection, and `False` when handling transport
    parameters received during the handshake. The following are only
    applied to received parameters:

    - connection ID checks;
    - ACK delay settings;
    - the stateless reset token;
    - version information.

    :raises QuicConnectionError: TRANSPORT_PARAMETER_ERROR when the
        parameters cannot be parsed or carry invalid values, or
        VERSION_NEGOTIATION_ERROR when the Version Information's chosen
        version does not match the version in use.
    """
    try:
        quic_transport_parameters = pull_quic_transport_parameters(
            Buffer(data=data)
        )
    except ValueError:
        raise QuicConnectionError(
            error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
            frame_type=QuicFrameType.CRYPTO,
            reason_phrase="Could not parse QUIC transport parameters",
        )

    # log event
    if self._quic_logger is not None and not from_session_ticket:
        self._quic_logger.log_event(
            category="transport",
            event="parameters_set",
            data=self._quic_logger.encode_transport_parameters(
                owner="remote", parameters=quic_transport_parameters
            ),
        )

    # Validate remote parameters.
    # A server rejects parameters that only a server is allowed to send.
    if not self._is_client:
        for attr in [
            "original_destination_connection_id",
            "preferred_address",
            "retry_source_connection_id",
            "stateless_reset_token",
        ]:
            if getattr(quic_transport_parameters, attr) is not None:
                raise QuicConnectionError(
                    error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
                    frame_type=QuicFrameType.CRYPTO,
                    reason_phrase="%s is not allowed for clients" % attr,
                )

    # The connection IDs echoed in the parameters must match the ones
    # actually used on this connection.
    if not from_session_ticket:
        if (
            quic_transport_parameters.initial_source_connection_id
            != self._remote_initial_source_connection_id
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase="initial_source_connection_id does not match",
            )
        if self._is_client and (
            quic_transport_parameters.original_destination_connection_id
            != self._original_destination_connection_id
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase="original_destination_connection_id does not match",
            )
        if self._is_client and (
            quic_transport_parameters.retry_source_connection_id
            != self._retry_source_connection_id
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase="retry_source_connection_id does not match",
            )

    # Value constraints from RFC 9000 section 18.2.
    if (
        quic_transport_parameters.active_connection_id_limit is not None
        and quic_transport_parameters.active_connection_id_limit < 2
    ):
        raise QuicConnectionError(
            error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
            frame_type=QuicFrameType.CRYPTO,
            reason_phrase="active_connection_id_limit must be no less than 2",
        )
    if (
        quic_transport_parameters.ack_delay_exponent is not None
        and quic_transport_parameters.ack_delay_exponent > 20
    ):
        raise QuicConnectionError(
            error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
            frame_type=QuicFrameType.CRYPTO,
            reason_phrase="ack_delay_exponent must be <= 20",
        )
    if (
        quic_transport_parameters.max_ack_delay is not None
        and quic_transport_parameters.max_ack_delay >= 2**14
    ):
        raise QuicConnectionError(
            error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
            frame_type=QuicFrameType.CRYPTO,
            reason_phrase="max_ack_delay must be < 2^14",
        )
    if quic_transport_parameters.max_udp_payload_size is not None and (
        quic_transport_parameters.max_udp_payload_size
        < SMALLEST_MAX_DATAGRAM_SIZE
    ):
        raise QuicConnectionError(
            error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
            frame_type=QuicFrameType.CRYPTO,
            reason_phrase=(
                f"max_udp_payload_size must be >= {SMALLEST_MAX_DATAGRAM_SIZE}"
            ),
        )

    # Validate Version Information extension.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if quic_transport_parameters.version_information is not None:
        version_information = quic_transport_parameters.version_information

        # If a server receives Version Information where the Chosen Version
        # is not included in Available Versions, it MUST treat it as a
        # parsing failure.
        if (
            not self._is_client
            and version_information.chosen_version
            not in version_information.available_versions
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase=(
                    "version_information's chosen_version is not included "
                    "in available_versions"
                ),
            )

        # Validate that the Chosen Version matches the version in use for the
        # connection.
        if version_information.chosen_version != self._crypto_packet_version:
            raise QuicConnectionError(
                error_code=QuicErrorCode.VERSION_NEGOTIATION_ERROR,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase=(
                    "version_information's chosen_version does not match "
                    "the version in use"
                ),
            )

    # Store remote parameters.
    if not from_session_ticket:
        if quic_transport_parameters.ack_delay_exponent is not None:
            # Previously this was a self-assignment, so the peer's exponent
            # was silently ignored. Presumably it is consumed when decoding
            # the ACK Delay field of the peer's ACK frames.
            self._remote_ack_delay_exponent = (
                quic_transport_parameters.ack_delay_exponent
            )
        if quic_transport_parameters.max_ack_delay is not None:
            # Transmitted in milliseconds; stored in seconds.
            self._loss.max_ack_delay = (
                quic_transport_parameters.max_ack_delay / 1000.0
            )
        if (
            self._is_client
            and self._peer_cid.sequence_number == 0
            and quic_transport_parameters.stateless_reset_token is not None
        ):
            self._peer_cid.stateless_reset_token = (
                quic_transport_parameters.stateless_reset_token
            )
        self._remote_version_information = (
            quic_transport_parameters.version_information
        )

    if quic_transport_parameters.active_connection_id_limit is not None:
        self._remote_active_connection_id_limit = (
            quic_transport_parameters.active_connection_id_limit
        )
    if quic_transport_parameters.max_idle_timeout is not None:
        # Transmitted in milliseconds; stored in seconds.
        self._remote_max_idle_timeout = (
            quic_transport_parameters.max_idle_timeout / 1000.0
        )
    self._remote_max_datagram_frame_size = (
        quic_transport_parameters.max_datagram_frame_size
    )
    # Copy each present initial_<param> to self._remote_<param>.
    for param in [
        "max_data",
        "max_stream_data_bidi_local",
        "max_stream_data_bidi_remote",
        "max_stream_data_uni",
        "max_streams_bidi",
        "max_streams_uni",
    ]:
        value = getattr(quic_transport_parameters, "initial_" + param)
        if value is not None:
            setattr(self, "_remote_" + param, value)
def _serialize_transport_parameters(self) -> bytes:
    """
    Build and encode our local QUIC transport parameters.

    The parameters are logged to qlog (owner "local") when qlog is enabled.
    The following are only included when acting as a server:

    - original_destination_connection_id;
    - retry_source_connection_id;
    - stateless_reset_token.

    RFC 9000 section 18.2 forbids a client from sending them, and the
    server-side validation in _parse_transport_parameters rejects a client
    that does.

    :return: the encoded transport parameters.
    """
    quic_transport_parameters = QuicTransportParameters(
        ack_delay_exponent=self._local_ack_delay_exponent,
        active_connection_id_limit=self._local_active_connection_id_limit,
        max_idle_timeout=int(self._configuration.idle_timeout * 1000),
        initial_max_data=self._local_max_data.value,
        initial_max_stream_data_bidi_local=self._local_max_stream_data_bidi_local,
        initial_max_stream_data_bidi_remote=self._local_max_stream_data_bidi_remote,
        initial_max_stream_data_uni=self._local_max_stream_data_uni,
        initial_max_streams_bidi=self._local_max_streams_bidi.value,
        initial_max_streams_uni=self._local_max_streams_uni.value,
        initial_source_connection_id=self._local_initial_source_connection_id,
        max_ack_delay=25,
        max_datagram_frame_size=self._configuration.max_datagram_frame_size,
        quantum_readiness=(
            b"Q" * SMALLEST_MAX_DATAGRAM_SIZE
            if self._configuration.quantum_readiness_test
            else None
        ),
        version_information=QuicVersionInformation(
            chosen_version=self._version,
            available_versions=self._configuration.supported_versions,
        ),
    )
    if not self._is_client:
        # Server-only transport parameters (RFC 9000 section 18.2).
        quic_transport_parameters.original_destination_connection_id = (
            self._original_destination_connection_id
        )
        quic_transport_parameters.retry_source_connection_id = (
            self._retry_source_connection_id
        )
        quic_transport_parameters.stateless_reset_token = self._host_cids[
            0
        ].stateless_reset_token

    # log event
    if self._quic_logger is not None:
        self._quic_logger.log_event(
            category="transport",
            event="parameters_set",
            data=self._quic_logger.encode_transport_parameters(
                owner="local", parameters=quic_transport_parameters
            ),
        )

    buf = Buffer(capacity=3 * self._max_datagram_size)
    push_quic_transport_parameters(buf, quic_transport_parameters)
    return buf.data
def _set_state(self, state: QuicConnectionState) -> None:
    """
    Move the connection to ``state``, logging the transition at debug level.
    """
    previous_state = self._state
    self._logger.debug("%s -> %s", previous_state, state)
    self._state = state
def _stream_can_receive(self, stream_id: int) -> bool:
    """
    Return whether this endpoint may receive data on ``stream_id``.

    Bidirectional streams always allow receiving. A unidirectional stream
    allows it only when it was opened by the other side.
    """
    if not stream_is_unidirectional(stream_id):
        return True
    return stream_is_client_initiated(stream_id) != self._is_client
def _stream_can_send(self, stream_id: int) -> bool:
    """
    Return whether this endpoint may send data on ``stream_id``.

    Bidirectional streams always allow sending. A unidirectional stream
    allows it only when this endpoint opened it.
    """
    if not stream_is_unidirectional(stream_id):
        return True
    return stream_is_client_initiated(stream_id) == self._is_client
def _unblock_streams(self, is_unidirectional: bool) -> None:
    """
    Release blocked locally-opened streams that fit the peer's new limit.

    Streams are taken from the front of the relevant blocked list for as
    long as their index (``stream_id // 4``) is below the peer's maximum
    stream count. Each released stream is unblocked and given the peer's
    initial stream-level send credit. When no stream of either kind remains
    blocked, the pending STREAMS_BLOCKED signal is cleared.
    """
    if is_unidirectional:
        blocked = self._streams_blocked_uni
        stream_limit = self._remote_max_streams_uni
        initial_credit = self._remote_max_stream_data_uni
    else:
        blocked = self._streams_blocked_bidi
        stream_limit = self._remote_max_streams_bidi
        initial_credit = self._remote_max_stream_data_bidi_remote

    while blocked and blocked[0].stream_id // 4 < stream_limit:
        released = blocked.pop(0)
        released.is_blocked = False
        released.max_stream_data_remote = initial_credit

    if not (self._streams_blocked_bidi or self._streams_blocked_uni):
        self._streams_blocked_pending = False
def _update_traffic_key(
    self,
    direction: tls.Direction,
    epoch: tls.Epoch,
    cipher_suite: tls.CipherSuite,
    secret: bytes,
) -> None:
    """
    TLS callback invoked when a new traffic secret becomes available.

    For a client, the version carried by crypto packets is adopted as the
    negotiated version (once). The secret is written to the configured
    secrets log file, if any. Finally, the send or receive half of the
    epoch's crypto pair is set up with it.
    """
    if (
        self._is_client
        and self._crypto_packet_version is not None
        and not self._version_negotiated_compatible
    ):
        self._version = self._crypto_packet_version
        self._version_negotiated_compatible = True
        self._logger.info(
            "Negotiated protocol version %s", pretty_protocol_version(self._version)
        )

    log_file = self._configuration.secrets_log_file
    if log_file is not None:
        # SECRETS_LABELS row 0 holds client labels, row 1 server labels.
        # Our encryption secret is our own; our decryption secret is the
        # peer's.
        is_server_secret = self._is_client == (direction == tls.Direction.DECRYPT)
        label = SECRETS_LABELS[is_server_secret][epoch.value]
        log_file.write(
            "%s %s %s\n" % (label, self.tls.client_random.hex(), secret.hex())
        )
        log_file.flush()

    pair = self._cryptos[epoch]
    if direction == tls.Direction.ENCRYPT:
        target = pair.send
    else:
        target = pair.recv
    target.setup(cipher_suite=cipher_suite, secret=secret, version=self._version)
def _add_local_challenge(self, challenge: bytes, network_path: QuicNetworkPath):
    """
    Remember a PATH_CHALLENGE payload we sent and the path it was sent on.

    At most MAX_LOCAL_CHALLENGES entries are kept. Because dictionaries
    preserve insertion order, the oldest entries are evicted first.
    """
    challenges = self._local_challenges
    challenges[challenge] = network_path
    while len(challenges) > MAX_LOCAL_CHALLENGES:
        oldest = next(iter(challenges))
        del challenges[oldest]
def _write_application(
    self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float
) -> None:
    """
    Build application-level (0-RTT or 1-RTT) packets.

    1-RTT send keys are preferred. 0-RTT keys are used when 1-RTT keys are
    not yet valid. Nothing is written when neither is valid.

    Packets are built in a loop until one of two things happens:

    - the builder produces an empty packet;
    - the pacer asks us to wait, unless an ACK is overdue.

    Connection-level control frames are only written once the handshake is
    complete. Streams that sent data are moved to the back of
    ``_streams_queue``, and finished streams are discarded.
    """
    # Select keys and packet type. Only 1-RTT carries a crypto stream here.
    crypto_stream: Optional[QuicStream] = None
    if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid():
        crypto = self._cryptos[tls.Epoch.ONE_RTT]
        crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT]
        packet_type = QuicPacketType.ONE_RTT
    elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid():
        crypto = self._cryptos[tls.Epoch.ZERO_RTT]
        packet_type = QuicPacketType.ZERO_RTT
    else:
        return
    # Both 0-RTT and 1-RTT packets use the ONE_RTT packet number space.
    space = self._spaces[tls.Epoch.ONE_RTT]

    while True:
        # apply pacing, except if we have ACKs to send
        # (pacing is skipped only when an ACK deadline has already passed).
        if space.ack_at is None or space.ack_at >= now:
            self._pacing_at = self._loss._pacer.next_send_time(now=now)
            if self._pacing_at is not None:
                break
        builder.start_packet(packet_type, crypto)

        if self._handshake_complete:
            # ACK
            if space.ack_at is not None and space.ack_at <= now:
                self._write_ack_frame(builder=builder, space=space, now=now)

            # HANDSHAKE_DONE
            if self._handshake_done_pending:
                self._write_handshake_done_frame(builder=builder)
                self._handshake_done_pending = False

            # PATH CHALLENGE
            # Sent once per unvalidated path; the payload is remembered so the
            # matching PATH_RESPONSE can be recognized.
            if not (network_path.is_validated or network_path.local_challenge_sent):
                challenge = os.urandom(8)
                self._add_local_challenge(
                    challenge=challenge, network_path=network_path
                )
                self._write_path_challenge_frame(
                    builder=builder, challenge=challenge
                )
                network_path.local_challenge_sent = True

            # PATH RESPONSE
            # Answer every challenge the peer sent on this path.
            while len(network_path.remote_challenges) > 0:
                challenge = network_path.remote_challenges.popleft()
                self._write_path_response_frame(
                    builder=builder, challenge=challenge
                )

            # NEW_CONNECTION_ID
            for connection_id in self._host_cids:
                if not connection_id.was_sent:
                    self._write_new_connection_id_frame(
                        builder=builder, connection_id=connection_id
                    )

            # RETIRE_CONNECTION_ID
            # Iterate over a copy; each written entry is removed from the
            # front of the queue (lost frames re-queue via the delivery
            # handler).
            for sequence_number in self._retire_connection_ids[:]:
                self._write_retire_connection_id_frame(
                    builder=builder, sequence_number=sequence_number
                )
                self._retire_connection_ids.pop(0)

            # STREAMS_BLOCKED
            if self._streams_blocked_pending:
                if self._streams_blocked_bidi:
                    self._write_streams_blocked_frame(
                        builder=builder,
                        frame_type=QuicFrameType.STREAMS_BLOCKED_BIDI,
                        limit=self._remote_max_streams_bidi,
                    )
                if self._streams_blocked_uni:
                    self._write_streams_blocked_frame(
                        builder=builder,
                        frame_type=QuicFrameType.STREAMS_BLOCKED_UNI,
                        limit=self._remote_max_streams_uni,
                    )
                self._streams_blocked_pending = False

            # MAX_DATA and MAX_STREAMS
            self._write_connection_limits(builder=builder, space=space)

        # stream-level limits
        for stream in self._streams.values():
            self._write_stream_limits(builder=builder, space=space, stream=stream)

        # PING (user-request)
        # _write_ping_frame snapshots the uids before the list is cleared.
        if self._ping_pending:
            self._write_ping_frame(builder, self._ping_pending)
            self._ping_pending.clear()

        # PING (probe)
        if self._probe_pending:
            self._write_ping_frame(builder, comment="probe")
            self._probe_pending = False

        # CRYPTO
        if crypto_stream is not None and not crypto_stream.sender.buffer_is_empty:
            self._write_crypto_frame(
                builder=builder, space=space, stream=crypto_stream
            )

        # DATAGRAM
        # A datagram is only dequeued once it was written; if it does not fit,
        # the builder raises QuicPacketBuilderStop and it stays queued.
        while self._datagrams_pending:
            try:
                self._write_datagram_frame(
                    builder=builder,
                    data=self._datagrams_pending[0],
                    frame_type=QuicFrameType.DATAGRAM_WITH_LENGTH,
                )
                self._datagrams_pending.popleft()
            except QuicPacketBuilderStop:
                break

        sent: Set[QuicStream] = set()
        discarded: Set[QuicStream] = set()
        try:
            for stream in self._streams_queue:
                # if the stream is finished, discard it
                if stream.is_finished:
                    self._logger.debug("Stream %d discarded", stream.stream_id)
                    self._streams.pop(stream.stream_id)
                    self._streams_finished.add(stream.stream_id)
                    discarded.add(stream)
                    continue

                if stream.receiver.stop_pending:
                    # STOP_SENDING
                    self._write_stop_sending_frame(builder=builder, stream=stream)

                if stream.sender.reset_pending:
                    # RESET_STREAM
                    self._write_reset_stream_frame(builder=builder, stream=stream)
                elif not stream.is_blocked and not stream.sender.buffer_is_empty:
                    # STREAM
                    # The offset is capped by both the remaining
                    # connection-level credit and the stream-level credit.
                    used = self._write_stream_frame(
                        builder=builder,
                        space=space,
                        stream=stream,
                        max_offset=min(
                            stream.sender.highest_offset
                            + self._remote_max_data
                            - self._remote_max_data_used,
                            stream.max_stream_data_remote,
                        ),
                    )
                    self._remote_max_data_used += used
                    if used > 0:
                        sent.add(stream)

        finally:
            # Make a new stream service order, putting served ones at the end.
            #
            # This method of updating the streams queue ensures that discarded
            # streams are removed and ones which sent are moved to the end even
            # if an exception occurs in the loop.
            self._streams_queue = [
                stream
                for stream in self._streams_queue
                if not (stream in discarded or stream in sent)
            ]
            self._streams_queue.extend(sent)

        if builder.packet_is_empty:
            break
        else:
            self._loss._pacer.update_after_send(now=now)
def _write_handshake(
    self, builder: QuicPacketBuilder, epoch: tls.Epoch, now: float
) -> None:
    """
    Build Initial or Handshake packets for ``epoch``.

    Nothing is written when the epoch's send keys are not valid. Each packet
    may carry, in order:

    - a pending ACK;
    - CRYPTO data;
    - a probe PING.

    A probe PING is only sent while the handshake is in progress, and only
    in the Handshake epoch, or in the Initial epoch while Handshake keys
    are not yet usable. Writing CRYPTO data already satisfies a pending
    probe. Packets are built until one comes out empty.
    """
    crypto = self._cryptos[epoch]
    if not crypto.send.is_valid():
        return

    crypto_stream = self._crypto_streams[epoch]
    space = self._spaces[epoch]
    if epoch == tls.Epoch.INITIAL:
        packet_type = QuicPacketType.INITIAL
    else:
        packet_type = QuicPacketType.HANDSHAKE

    while True:
        builder.start_packet(packet_type, crypto)

        if space.ack_at is not None:
            self._write_ack_frame(builder=builder, space=space, now=now)

        if not crypto_stream.sender.buffer_is_empty and self._write_crypto_frame(
            builder=builder, space=space, stream=crypto_stream
        ):
            self._probe_pending = False

        if self._probe_pending and not self._handshake_complete:
            handshake_keys_ready = self._cryptos[tls.Epoch.HANDSHAKE].send.is_valid()
            if epoch == tls.Epoch.HANDSHAKE or not handshake_keys_ready:
                self._write_ping_frame(builder, comment="probe")
                self._probe_pending = False

        if builder.packet_is_empty:
            break
def _write_ack_frame(
    self, builder: QuicPacketBuilder, space: QuicPacketSpace, now: float
) -> None:
    """
    Write an ACK frame for ``space`` and clear its pending ACK deadline.

    The ACK delay is the time since the largest packet was received. It is
    expressed in microseconds and scaled down by our ack_delay_exponent.
    When the frame holds more than one range and the packet number is a
    multiple of 8, a PING is added so the peer acknowledges this packet
    (ACK-of-ACK).
    """
    ack_delay = now - space.largest_received_time
    encoded_delay = int(ack_delay * 1000000) >> self._local_ack_delay_exponent

    frame_buf = builder.start_frame(
        QuicFrameType.ACK,
        capacity=ACK_FRAME_CAPACITY,
        handler=self._on_ack_delivery,
        handler_args=(space, space.largest_received_packet),
    )
    range_count = push_ack_frame(frame_buf, space.ack_queue, encoded_delay)
    space.ack_at = None

    quic_logger = self._quic_logger
    if quic_logger is not None:
        builder.quic_logger_frames.append(
            quic_logger.encode_ack_frame(ranges=space.ack_queue, delay=ack_delay)
        )

    needs_ack_of_ack = range_count > 1 and builder.packet_number % 8 == 0
    if needs_ack_of_ack:
        self._write_ping_frame(builder, comment="ACK-of-ACK trigger")
def _write_connection_close_frame(
    self,
    builder: QuicPacketBuilder,
    epoch: tls.Epoch,
    error_code: int,
    frame_type: Optional[int],
    reason_phrase: str,
) -> None:
    """
    Write a CONNECTION_CLOSE frame.

    ``frame_type is None`` selects an application close. In the Initial and
    Handshake epochs, an application close is instead sent as a transport
    close with APPLICATION_ERROR, frame type PADDING and an empty reason.
    """
    if frame_type is None and epoch in (tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE):
        error_code = QuicErrorCode.APPLICATION_ERROR
        frame_type = QuicFrameType.PADDING
        reason_phrase = ""

    reason_bytes = reason_phrase.encode("utf8")
    reason_length = len(reason_bytes)
    is_application_close = frame_type is None

    if is_application_close:
        buf = builder.start_frame(
            QuicFrameType.APPLICATION_CLOSE,
            capacity=APPLICATION_CLOSE_FRAME_CAPACITY + reason_length,
        )
    else:
        buf = builder.start_frame(
            QuicFrameType.TRANSPORT_CLOSE,
            capacity=TRANSPORT_CLOSE_FRAME_CAPACITY + reason_length,
        )

    buf.push_uint_var(error_code)
    if not is_application_close:
        # Only the transport variant carries the offending frame type.
        buf.push_uint_var(frame_type)
    buf.push_uint_var(reason_length)
    buf.push_bytes(reason_bytes)

    if self._quic_logger is not None:
        builder.quic_logger_frames.append(
            self._quic_logger.encode_connection_close_frame(
                error_code=error_code,
                frame_type=frame_type,
                reason_phrase=reason_phrase,
            )
        )
def _write_connection_limits(
    self, builder: QuicPacketBuilder, space: QuicPacketSpace
) -> None:
    """
    Raise MAX_DATA or MAX_STREAMS if needed.

    Each connection-level limit is doubled once more than half of it has
    been used. A MAX_DATA / MAX_STREAMS frame is written whenever the
    current value differs from the last value sent. ``space`` is unused.
    """
    local_limits = (
        self._local_max_data,
        self._local_max_streams_bidi,
        self._local_max_streams_uni,
    )
    for limit in local_limits:
        if 2 * limit.used > limit.value:
            limit.value *= 2
            self._logger.debug("Local %s raised to %d", limit.name, limit.value)

        if limit.value == limit.sent:
            continue

        buf = builder.start_frame(
            limit.frame_type,
            capacity=CONNECTION_LIMIT_FRAME_CAPACITY,
            handler=self._on_connection_limit_delivery,
            handler_args=(limit,),
        )
        buf.push_uint_var(limit.value)
        limit.sent = limit.value

        if self._quic_logger is not None:
            builder.quic_logger_frames.append(
                self._quic_logger.encode_connection_limit_frame(
                    frame_type=limit.frame_type,
                    maximum=limit.value,
                )
            )
def _write_crypto_frame(
    self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream
) -> bool:
    """
    Write one CRYPTO frame from ``stream``'s pending data, if any fits.

    The overhead reserved is 3 bytes (frame type plus a fixed 2-byte length)
    plus the encoded size of the next offset.

    :return: True if a frame was written, False if the sender had nothing to
        send in the remaining flight space.
    """
    sender = stream.sender
    frame_overhead = 3 + size_uint_var(sender.next_offset)
    frame = sender.get_frame(builder.remaining_flight_space - frame_overhead)
    if frame is None:
        return False

    data_length = len(frame.data)
    buf = builder.start_frame(
        QuicFrameType.CRYPTO,
        capacity=frame_overhead,
        handler=sender.on_data_delivery,
        handler_args=(frame.offset, frame.offset + data_length, False),
    )
    buf.push_uint_var(frame.offset)
    # Length encoded as a 2-byte variable-length integer (0x4000 prefix),
    # matching the overhead reserved above.
    buf.push_uint16(data_length | 0x4000)
    buf.push_bytes(frame.data)

    if self._quic_logger is not None:
        builder.quic_logger_frames.append(
            self._quic_logger.encode_crypto_frame(frame)
        )
    return True
def _write_datagram_frame(
    self, builder: QuicPacketBuilder, data: bytes, frame_type: QuicFrameType
) -> bool:
    """
    Write a DATAGRAM frame carrying ``data`` with an explicit length.

    Only DATAGRAM_WITH_LENGTH is supported. The function always returns
    True once the frame is written. Callers catch QuicPacketBuilderStop,
    presumably raised by ``builder.start_frame`` when the frame does not fit.
    """
    assert frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH
    payload_length = len(data)
    # frame type (1 byte) + varint length + payload
    buf = builder.start_frame(
        frame_type, capacity=1 + size_uint_var(payload_length) + payload_length
    )
    buf.push_uint_var(payload_length)
    buf.push_bytes(data)

    if self._quic_logger is not None:
        builder.quic_logger_frames.append(
            self._quic_logger.encode_datagram_frame(length=payload_length)
        )
    return True
def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None:
    """
    Write a HANDSHAKE_DONE frame, which has no payload.

    Loss is handled by _on_handshake_done_delivery.
    """
    builder.start_frame(
        QuicFrameType.HANDSHAKE_DONE,
        capacity=HANDSHAKE_DONE_FRAME_CAPACITY,
        handler=self._on_handshake_done_delivery,
    )

    quic_logger = self._quic_logger
    if quic_logger is not None:
        builder.quic_logger_frames.append(quic_logger.encode_handshake_done_frame())
def _write_new_connection_id_frame(
    self, builder: QuicPacketBuilder, connection_id: QuicConnectionId
) -> None:
    """
    Announce a locally issued connection ID with a NEW_CONNECTION_ID frame.

    The frame fields are written in this order:

    - sequence number;
    - retire-prior-to (always 0 for now);
    - CID length and CID bytes;
    - stateless reset token.

    The ID is then marked as sent and a ConnectionIdIssued event is queued.
    """
    retire_prior_to = 0  # FIXME
    buf = builder.start_frame(
        QuicFrameType.NEW_CONNECTION_ID,
        capacity=NEW_CONNECTION_ID_FRAME_CAPACITY,
        handler=self._on_new_connection_id_delivery,
        handler_args=(connection_id,),
    )
    cid = connection_id.cid
    buf.push_uint_var(connection_id.sequence_number)
    buf.push_uint_var(retire_prior_to)
    buf.push_uint8(len(cid))
    buf.push_bytes(cid)
    buf.push_bytes(connection_id.stateless_reset_token)

    connection_id.was_sent = True
    self._events.append(events.ConnectionIdIssued(connection_id=cid))

    if self._quic_logger is not None:
        builder.quic_logger_frames.append(
            self._quic_logger.encode_new_connection_id_frame(
                connection_id=cid,
                retire_prior_to=retire_prior_to,
                sequence_number=connection_id.sequence_number,
                stateless_reset_token=connection_id.stateless_reset_token,
            )
        )
def _write_path_challenge_frame(
    self, builder: QuicPacketBuilder, challenge: bytes
) -> None:
    """
    Write a PATH_CHALLENGE frame whose payload is ``challenge``.
    """
    frame_buf = builder.start_frame(
        QuicFrameType.PATH_CHALLENGE, capacity=PATH_CHALLENGE_FRAME_CAPACITY
    )
    frame_buf.push_bytes(challenge)

    quic_logger = self._quic_logger
    if quic_logger is not None:
        builder.quic_logger_frames.append(
            quic_logger.encode_path_challenge_frame(data=challenge)
        )
def _write_path_response_frame(
    self, builder: QuicPacketBuilder, challenge: bytes
) -> None:
    """
    Write a PATH_RESPONSE frame echoing the peer's ``challenge`` payload.
    """
    frame_buf = builder.start_frame(
        QuicFrameType.PATH_RESPONSE, capacity=PATH_RESPONSE_FRAME_CAPACITY
    )
    frame_buf.push_bytes(challenge)

    quic_logger = self._quic_logger
    if quic_logger is not None:
        builder.quic_logger_frames.append(
            quic_logger.encode_path_response_frame(data=challenge)
        )
def _write_ping_frame(
    self,
    builder: QuicPacketBuilder,
    uids: Optional[Sequence[int]] = None,
    comment: str = "",
) -> None:
    """
    Write a PING frame into the packet being built.

    :param builder: packet builder receiving the frame.
    :param uids: identifiers of the user-requested pings carried by this
        frame. They are handed to _on_ping_delivery, which either reports
        them as acknowledged or re-queues them. Omit (or pass an empty
        sequence) for internally generated pings such as probes.
    :param comment: short description included in the debug log.
    """
    # A ``None`` default avoids the shared mutable ``[]`` default. The uids
    # are snapshotted as a tuple because the caller may clear its list right
    # after this call (see the user-requested PING path).
    acked_uids = tuple(uids) if uids is not None else ()
    builder.start_frame(
        QuicFrameType.PING,
        capacity=PING_FRAME_CAPACITY,
        handler=self._on_ping_delivery,
        handler_args=(acked_uids,),
    )
    self._logger.debug(
        "Sending PING%s in packet %d",
        " (%s)" % comment if comment else "",
        builder.packet_number,
    )

    # log frame
    if self._quic_logger is not None:
        builder.quic_logger_frames.append(self._quic_logger.encode_ping_frame())
def _write_reset_stream_frame(
    self,
    builder: QuicPacketBuilder,
    stream: QuicStream,
) -> None:
    """
    Append a RESET_STREAM frame for `stream`.

    The frame fields (stream ID, error code, final size) are taken from
    `stream.sender.get_reset_frame()`, and `stream.sender.on_reset_delivery`
    is registered as the frame's delivery handler.
    """
    sender = stream.sender
    out = builder.start_frame(
        frame_type=QuicFrameType.RESET_STREAM,
        capacity=RESET_STREAM_FRAME_CAPACITY,
        handler=sender.on_reset_delivery,
    )
    reset = sender.get_reset_frame()
    # Wire order: stream ID, error code, final size - each a varint.
    for field_value in (reset.stream_id, reset.error_code, reset.final_size):
        out.push_uint_var(field_value)

    qlog = self._quic_logger
    if qlog is None:
        return
    builder.quic_logger_frames.append(
        qlog.encode_reset_stream_frame(
            stream_id=reset.stream_id,
            error_code=reset.error_code,
            final_size=reset.final_size,
        )
    )
def _write_retire_connection_id_frame(
    self, builder: QuicPacketBuilder, sequence_number: int
) -> None:
    """
    Append a RETIRE_CONNECTION_ID frame for `sequence_number`.

    `self._on_retire_connection_id_delivery` is registered as the delivery
    handler with `handler_args=(sequence_number,)`.
    """
    out = builder.start_frame(
        QuicFrameType.RETIRE_CONNECTION_ID,
        capacity=RETIRE_CONNECTION_ID_CAPACITY,
        handler=self._on_retire_connection_id_delivery,
        handler_args=(sequence_number,),
    )
    out.push_uint_var(sequence_number)

    qlog = self._quic_logger
    if qlog is None:
        return
    builder.quic_logger_frames.append(
        qlog.encode_retire_connection_id_frame(sequence_number)
    )
def _write_stop_sending_frame(
    self,
    builder: QuicPacketBuilder,
    stream: QuicStream,
) -> None:
    """
    Append a STOP_SENDING frame for `stream`.

    The stream ID and error code come from `stream.receiver.get_stop_frame()`,
    and `stream.receiver.on_stop_sending_delivery` is registered as the
    frame's delivery handler.
    """
    receiver = stream.receiver
    out = builder.start_frame(
        frame_type=QuicFrameType.STOP_SENDING,
        capacity=STOP_SENDING_FRAME_CAPACITY,
        handler=receiver.on_stop_sending_delivery,
    )
    stop = receiver.get_stop_frame()
    # Wire order: stream ID, then error code - both varints.
    for field_value in (stop.stream_id, stop.error_code):
        out.push_uint_var(field_value)

    qlog = self._quic_logger
    if qlog is None:
        return
    builder.quic_logger_frames.append(
        qlog.encode_stop_sending_frame(
            stream_id=stop.stream_id, error_code=stop.error_code
        )
    )
def _write_stream_frame(
    self,
    builder: QuicPacketBuilder,
    space: QuicPacketSpace,
    stream: QuicStream,
    max_offset: int,
) -> int:
    """
    Append at most one STREAM frame carrying data from `stream`.

    The payload is limited by the room left in the packet
    (`builder.remaining_flight_space` minus the frame header) and by
    `max_offset` (our peer's MAX_DATA constraint). Both limits are passed to
    `stream.sender.get_frame()`.

    `space` is part of the signature but is not used here.

    Returns how far `stream.sender.highest_offset` advanced while building
    the frame. This is 0 when no frame was produced or the highest offset
    did not move.
    """
    sender = stream.sender
    next_offset = sender.next_offset

    # Header budget: 3 bytes (presumably 1-byte frame type + fixed 2-byte
    # length), the stream ID, and the offset unless it is zero.
    header_size = 3 + size_uint_var(stream.stream_id)
    if next_offset:
        header_size += size_uint_var(next_offset)

    # Snapshot before get_frame() so the newly-sent byte count can be derived.
    highest_before = sender.highest_offset
    frame = sender.get_frame(builder.remaining_flight_space - header_size, max_offset)
    if frame is None:
        return 0

    # STREAM type bits: 0x02 = explicit length (always set here),
    # 0x04 = offset field present, 0x01 = FIN.
    frame_type = (
        QuicFrameType.STREAM_BASE
        | 2
        | (4 if frame.offset else 0)
        | (1 if frame.fin else 0)
    )
    data_end = frame.offset + len(frame.data)
    out = builder.start_frame(
        frame_type,
        capacity=header_size,
        handler=sender.on_data_delivery,
        handler_args=(frame.offset, data_end, frame.fin),
    )
    out.push_uint_var(stream.stream_id)
    if frame.offset:
        out.push_uint_var(frame.offset)
    # Length encoded as a two-byte variable-length integer (0b01 prefix).
    out.push_uint16(len(frame.data) | 0x4000)
    out.push_bytes(frame.data)

    qlog = self._quic_logger
    if qlog is not None:
        builder.quic_logger_frames.append(
            qlog.encode_stream_frame(frame, stream_id=stream.stream_id)
        )
    return sender.highest_offset - highest_before
def _write_stream_limits(
    self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream
) -> None:
    """
    Raise and advertise the stream's local MAX_STREAM_DATA limit when needed.

    The local limit is doubled once the highest received offset exceeds half
    of it. A MAX_STREAM_DATA frame is written whenever the limit differs from
    the last value sent, and the sent value is updated immediately.

    A zero `stream.max_stream_data_local` only happens for locally created
    unidirectional streams. Those are skipped so they do not emit spurious
    log lines. `space` is not used.
    """
    limit = stream.max_stream_data_local
    if limit and stream.receiver.highest_offset * 2 > limit:
        limit *= 2
        stream.max_stream_data_local = limit
        self._logger.debug(
            "Stream %d local max_stream_data raised to %d",
            stream.stream_id,
            limit,
        )

    if stream.max_stream_data_local_sent == limit:
        return

    out = builder.start_frame(
        QuicFrameType.MAX_STREAM_DATA,
        capacity=MAX_STREAM_DATA_FRAME_CAPACITY,
        handler=self._on_max_stream_data_delivery,
        handler_args=(stream,),
    )
    out.push_uint_var(stream.stream_id)
    out.push_uint_var(limit)
    stream.max_stream_data_local_sent = limit

    qlog = self._quic_logger
    if qlog is not None:
        builder.quic_logger_frames.append(
            qlog.encode_max_stream_data_frame(
                maximum=limit, stream_id=stream.stream_id
            )
        )
def _write_streams_blocked_frame(
    self, builder: QuicPacketBuilder, frame_type: QuicFrameType, limit: int
) -> None:
    """
    Append a STREAMS_BLOCKED frame of type `frame_type` carrying `limit`.

    In the logged representation, the frame is marked unidirectional exactly
    when `frame_type` equals `QuicFrameType.STREAMS_BLOCKED_UNI`.
    """
    out = builder.start_frame(frame_type, capacity=STREAMS_BLOCKED_CAPACITY)
    out.push_uint_var(limit)

    qlog = self._quic_logger
    if qlog is None:
        return
    unidirectional = frame_type == QuicFrameType.STREAMS_BLOCKED_UNI
    builder.quic_logger_frames.append(
        qlog.encode_streams_blocked_frame(
            is_unidirectional=unidirectional,
            limit=limit,
        )
    )