This commit is contained in:
mofixx
2025-08-08 10:41:30 +02:00
parent 4444be3799
commit a5df3861fd
1674 changed files with 234266 additions and 0 deletions

View File

@ -0,0 +1,3 @@
from .client import connect # noqa
from .protocol import QuicConnectionProtocol # noqa
from .server import serve # noqa

View File

@ -0,0 +1,98 @@
import asyncio
import socket
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable, Optional, cast
from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection, QuicTokenHandler
from ..tls import SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["connect"]
@asynccontextmanager
async def connect(
host: str,
port: int,
*,
configuration: Optional[QuicConfiguration] = None,
create_protocol: Optional[Callable] = QuicConnectionProtocol,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
"""
Connect to a QUIC server at the given `host` and `port`.
:meth:`connect()` returns an awaitable. Awaiting it yields a
:class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
create streams.
:func:`connect` also accepts the following optional arguments:
* ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
configuration object.
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is received.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
* ``wait_connected`` indicates whether the context manager should wait for the
connection to be established before yielding the
:class:`~aioquic.asyncio.QuicConnectionProtocol`. By default this is `True` but
you can set it to `False` if you want to immediately start sending data using
0-RTT.
* ``local_port`` is the UDP port number that this client wants to bind.
"""
loop = asyncio.get_event_loop()
local_host = "::"
# lookup remote address
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
addr = infos[0][4]
if len(addr) == 2:
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
# prepare QUIC connection
if configuration is None:
configuration = QuicConfiguration(is_client=True)
if configuration.server_name is None:
configuration.server_name = host
connection = QuicConnection(
configuration=configuration,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
# explicitly enable IPv4/IPv6 dual stack
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
completed = False
try:
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
sock.bind((local_host, local_port, 0, 0))
completed = True
finally:
if not completed:
sock.close()
# connect
transport, protocol = await loop.create_datagram_endpoint(
lambda: create_protocol(connection, stream_handler=stream_handler),
sock=sock,
)
protocol = cast(QuicConnectionProtocol, protocol)
try:
protocol.connect(addr, transmit=wait_connected)
if wait_connected:
await protocol.wait_connected()
yield protocol
finally:
protocol.close()
await protocol.wait_closed()
transport.close()

View File

@ -0,0 +1,272 @@
import asyncio
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import QuicErrorCode
QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
):
loop = asyncio.get_event_loop()
self._closed = asyncio.Event()
self._connected = False
self._connected_waiter: Optional[asyncio.Future[None]] = None
self._loop = loop
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
self._quic = quic
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
self._timer: Optional[asyncio.TimerHandle] = None
self._timer_at: Optional[float] = None
self._transmit_task: Optional[asyncio.Handle] = None
self._transport: Optional[asyncio.DatagramTransport] = None
# callbacks
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
self._connection_terminated_handler: Callable[[], None] = lambda: None
if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None
def change_connection_id(self) -> None:
"""
Change the connection ID used to communicate with the peer.
The previous connection ID will be retired.
"""
self._quic.change_connection_id()
self.transmit()
def close(
self,
error_code: int = QuicErrorCode.NO_ERROR,
reason_phrase: str = "",
) -> None:
"""
Close the connection.
:param error_code: An error code indicating why the connection is
being closed.
:param reason_phrase: A human-readable explanation of why the
connection is being closed.
"""
self._quic.close(
error_code=error_code,
reason_phrase=reason_phrase,
)
self.transmit()
def connect(self, addr: NetworkAddress, transmit=True) -> None:
"""
Initiate the TLS handshake.
This method can only be called for clients and a single time.
"""
self._quic.connect(addr, now=self._loop.time())
if transmit:
self.transmit()
async def create_stream(
self, is_unidirectional: bool = False
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
"""
stream_id = self._quic.get_next_available_stream_id(
is_unidirectional=is_unidirectional
)
return self._create_stream(stream_id)
def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._quic.request_key_update()
self.transmit()
async def ping(self) -> None:
"""
Ping the peer and wait for the response.
"""
waiter = self._loop.create_future()
uid = id(waiter)
self._ping_waiters[uid] = waiter
self._quic.send_ping(uid)
self.transmit()
await asyncio.shield(waiter)
def transmit(self) -> None:
"""
Send pending datagrams to the peer and arm the timer if needed.
This method is called automatically when data is received from the peer
or when a timer goes off. If you interact directly with the underlying
:class:`~aioquic.quic.connection.QuicConnection`, make sure you call this
method whenever data needs to be sent out to the network.
"""
self._transmit_task = None
# send datagrams
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
self._transport.sendto(data, addr)
# re-arm timer
timer_at = self._quic.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at
async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()
async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
assert self._connected_waiter is None, "already awaiting connected"
if not self._connected:
self._connected_waiter = self._loop.create_future()
await asyncio.shield(self._connected_waiter)
# asyncio.Transport
def connection_made(self, transport: asyncio.BaseTransport) -> None:
""":meta private:"""
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
""":meta private:"""
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
self._process_events()
self.transmit()
# overridable
def quic_event_received(self, event: events.QuicEvent) -> None:
"""
Called when a QUIC event is received.
Reimplement this in your subclass to handle the events.
"""
# FIXME: move this to a subclass
if isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()
# private
def _create_stream(
self, stream_id: int
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
protocol = asyncio.streams.StreamReaderProtocol(reader)
writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop)
self._stream_readers[stream_id] = reader
return reader, writer
def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())
self._timer = None
self._timer_at = None
self._quic.handle_timer(now=now)
self._process_events()
self.transmit()
def _process_events(self) -> None:
event = self._quic.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
self._connection_terminated_handler()
# abort connection waiter
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected_waiter = None
waiter.set_exception(ConnectionError)
# abort ping waiters
for waiter in self._ping_waiters.values():
waiter.set_exception(ConnectionError)
self._ping_waiters.clear()
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected = True
self._connected_waiter = None
waiter.set_result(None)
elif isinstance(event, events.PingAcknowledged):
waiter = self._ping_waiters.pop(event.uid, None)
if waiter is not None:
waiter.set_result(None)
self.quic_event_received(event)
event = self._quic.next_event()
def _transmit_soon(self) -> None:
if self._transmit_task is None:
self._transmit_task = self._loop.call_soon(self.transmit)
class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
self.protocol = protocol
self.stream_id = stream_id
self._closing = False
def can_write_eof(self) -> bool:
return True
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "stream_id":
return self.stream_id
def write(self, data):
self.protocol._quic.send_stream_data(self.stream_id, data)
self.protocol._transmit_soon()
def write_eof(self):
if self._closing:
return
self._closing = True
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._transmit_soon()
def close(self):
self.write_eof()
def is_closing(self) -> bool:
return self._closing

View File

@ -0,0 +1,215 @@
import asyncio
import os
from functools import partial
from typing import Callable, Dict, Optional, Text, Union, cast
from ..buffer import Buffer
from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
QuicPacketType,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from ..quic.retry import QuicRetryTokenHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["serve"]
class QuicServer(asyncio.DatagramProtocol):
def __init__(
self,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
retry: bool = False,
stream_handler: Optional[QuicStreamHandler] = None,
) -> None:
self._configuration = configuration
self._create_protocol = create_protocol
self._loop = asyncio.get_event_loop()
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: Optional[asyncio.DatagramTransport] = None
self._stream_handler = stream_handler
if retry:
self._retry = QuicRetryTokenHandler()
else:
self._retry = None
def close(self):
for protocol in set(self._protocols.values()):
protocol.close()
self._protocols.clear()
self._transport.close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
data = cast(bytes, data)
buf = Buffer(data=data)
try:
header = pull_quic_header(
buf, host_cid_length=self._configuration.connection_id_length
)
except ValueError:
return
# version negotiation
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
self._transport.sendto(
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self._configuration.supported_versions,
),
addr,
)
return
protocol = self._protocols.get(header.destination_cid, None)
original_destination_connection_id: Optional[bytes] = None
retry_source_connection_id: Optional[bytes] = None
if (
protocol is None
and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE
and header.packet_type == QuicPacketType.INITIAL
):
# retry
if self._retry is not None:
if not header.token:
# create a retry token
source_cid = os.urandom(8)
self._transport.sendto(
encode_quic_retry(
version=header.version,
source_cid=source_cid,
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self._retry.create_token(
addr, header.destination_cid, source_cid
),
),
addr,
)
return
else:
# validate retry token
try:
(
original_destination_connection_id,
retry_source_connection_id,
) = self._retry.validate_token(addr, header.token)
except ValueError:
return
else:
original_destination_connection_id = header.destination_cid
# create new connection
connection = QuicConnection(
configuration=self._configuration,
original_destination_connection_id=original_destination_connection_id,
retry_source_connection_id=retry_source_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
)
protocol = self._create_protocol(
connection, stream_handler=self._stream_handler
)
protocol.connection_made(self._transport)
# register callbacks
protocol._connection_id_issued_handler = partial(
self._connection_id_issued, protocol=protocol
)
protocol._connection_id_retired_handler = partial(
self._connection_id_retired, protocol=protocol
)
protocol._connection_terminated_handler = partial(
self._connection_terminated, protocol=protocol
)
self._protocols[header.destination_cid] = protocol
self._protocols[connection.host_cid] = protocol
if protocol is not None:
protocol.datagram_received(data, addr)
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
self._protocols[cid] = protocol
def _connection_id_retired(
self, cid: bytes, protocol: QuicConnectionProtocol
) -> None:
assert self._protocols[cid] == protocol
del self._protocols[cid]
def _connection_terminated(self, protocol: QuicConnectionProtocol):
for cid, proto in list(self._protocols.items()):
if proto == protocol:
del self._protocols[cid]
async def serve(
host: str,
port: int,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
retry: bool = False,
stream_handler: QuicStreamHandler = None,
) -> QuicServer:
"""
Start a QUIC server at the given `host` and `port`.
:func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
containing TLS certificate and private key as the ``configuration`` argument.
:func:`serve` also accepts the following optional arguments:
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
engine when a session ticket is presented by the peer. It should return
the session ticket with the specified ID or `None` if it is not found.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is issued. It should store the session
ticket for future lookup.
* ``retry`` specifies whether client addresses should be validated prior to
the cryptographic handshake using a retry packet.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
"""
loop = asyncio.get_event_loop()
_, protocol = await loop.create_datagram_endpoint(
lambda: QuicServer(
configuration=configuration,
create_protocol=create_protocol,
session_ticket_fetcher=session_ticket_fetcher,
session_ticket_handler=session_ticket_handler,
retry=retry,
stream_handler=stream_handler,
),
local_addr=(host, port),
)
return protocol