Code
This commit is contained in:
@ -0,0 +1,272 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
|
||||
|
||||
from ..quic import events
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import QuicErrorCode
|
||||
|
||||
QuicConnectionIdHandler = Callable[[bytes], None]
|
||||
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
|
||||
|
||||
|
||||
class QuicConnectionProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._closed = asyncio.Event()
|
||||
self._connected = False
|
||||
self._connected_waiter: Optional[asyncio.Future[None]] = None
|
||||
self._loop = loop
|
||||
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
|
||||
self._quic = quic
|
||||
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
|
||||
self._timer: Optional[asyncio.TimerHandle] = None
|
||||
self._timer_at: Optional[float] = None
|
||||
self._transmit_task: Optional[asyncio.Handle] = None
|
||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||
|
||||
# callbacks
|
||||
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_terminated_handler: Callable[[], None] = lambda: None
|
||||
if stream_handler is not None:
|
||||
self._stream_handler = stream_handler
|
||||
else:
|
||||
self._stream_handler = lambda r, w: None
|
||||
|
||||
def change_connection_id(self) -> None:
|
||||
"""
|
||||
Change the connection ID used to communicate with the peer.
|
||||
|
||||
The previous connection ID will be retired.
|
||||
"""
|
||||
self._quic.change_connection_id()
|
||||
self.transmit()
|
||||
|
||||
def close(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Close the connection.
|
||||
|
||||
:param error_code: An error code indicating why the connection is
|
||||
being closed.
|
||||
:param reason_phrase: A human-readable explanation of why the
|
||||
connection is being closed.
|
||||
"""
|
||||
self._quic.close(
|
||||
error_code=error_code,
|
||||
reason_phrase=reason_phrase,
|
||||
)
|
||||
self.transmit()
|
||||
|
||||
def connect(self, addr: NetworkAddress, transmit=True) -> None:
|
||||
"""
|
||||
Initiate the TLS handshake.
|
||||
|
||||
This method can only be called for clients and a single time.
|
||||
"""
|
||||
self._quic.connect(addr, now=self._loop.time())
|
||||
if transmit:
|
||||
self.transmit()
|
||||
|
||||
async def create_stream(
|
||||
self, is_unidirectional: bool = False
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
"""
|
||||
Create a QUIC stream and return a pair of (reader, writer) objects.
|
||||
|
||||
The returned reader and writer objects are instances of
|
||||
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
|
||||
"""
|
||||
stream_id = self._quic.get_next_available_stream_id(
|
||||
is_unidirectional=is_unidirectional
|
||||
)
|
||||
return self._create_stream(stream_id)
|
||||
|
||||
def request_key_update(self) -> None:
|
||||
"""
|
||||
Request an update of the encryption keys.
|
||||
"""
|
||||
self._quic.request_key_update()
|
||||
self.transmit()
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""
|
||||
Ping the peer and wait for the response.
|
||||
"""
|
||||
waiter = self._loop.create_future()
|
||||
uid = id(waiter)
|
||||
self._ping_waiters[uid] = waiter
|
||||
self._quic.send_ping(uid)
|
||||
self.transmit()
|
||||
await asyncio.shield(waiter)
|
||||
|
||||
def transmit(self) -> None:
|
||||
"""
|
||||
Send pending datagrams to the peer and arm the timer if needed.
|
||||
|
||||
This method is called automatically when data is received from the peer
|
||||
or when a timer goes off. If you interact directly with the underlying
|
||||
:class:`~aioquic.quic.connection.QuicConnection`, make sure you call this
|
||||
method whenever data needs to be sent out to the network.
|
||||
"""
|
||||
self._transmit_task = None
|
||||
|
||||
# send datagrams
|
||||
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
|
||||
self._transport.sendto(data, addr)
|
||||
|
||||
# re-arm timer
|
||||
timer_at = self._quic.get_timer()
|
||||
if self._timer is not None and self._timer_at != timer_at:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
if self._timer is None and timer_at is not None:
|
||||
self._timer = self._loop.call_at(timer_at, self._handle_timer)
|
||||
self._timer_at = timer_at
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""
|
||||
Wait for the connection to be closed.
|
||||
"""
|
||||
await self._closed.wait()
|
||||
|
||||
async def wait_connected(self) -> None:
|
||||
"""
|
||||
Wait for the TLS handshake to complete.
|
||||
"""
|
||||
assert self._connected_waiter is None, "already awaiting connected"
|
||||
if not self._connected:
|
||||
self._connected_waiter = self._loop.create_future()
|
||||
await asyncio.shield(self._connected_waiter)
|
||||
|
||||
# asyncio.Transport
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
""":meta private:"""
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
|
||||
""":meta private:"""
|
||||
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
# overridable
|
||||
|
||||
def quic_event_received(self, event: events.QuicEvent) -> None:
|
||||
"""
|
||||
Called when a QUIC event is received.
|
||||
|
||||
Reimplement this in your subclass to handle the events.
|
||||
"""
|
||||
# FIXME: move this to a subclass
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
for reader in self._stream_readers.values():
|
||||
reader.feed_eof()
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
reader = self._stream_readers.get(event.stream_id, None)
|
||||
if reader is None:
|
||||
reader, writer = self._create_stream(event.stream_id)
|
||||
self._stream_handler(reader, writer)
|
||||
reader.feed_data(event.data)
|
||||
if event.end_stream:
|
||||
reader.feed_eof()
|
||||
|
||||
# private
|
||||
|
||||
def _create_stream(
|
||||
self, stream_id: int
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
adapter = QuicStreamAdapter(self, stream_id)
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.streams.StreamReaderProtocol(reader)
|
||||
writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop)
|
||||
self._stream_readers[stream_id] = reader
|
||||
return reader, writer
|
||||
|
||||
def _handle_timer(self) -> None:
|
||||
now = max(self._timer_at, self._loop.time())
|
||||
self._timer = None
|
||||
self._timer_at = None
|
||||
self._quic.handle_timer(now=now)
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
def _process_events(self) -> None:
|
||||
event = self._quic.next_event()
|
||||
while event is not None:
|
||||
if isinstance(event, events.ConnectionIdIssued):
|
||||
self._connection_id_issued_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
self._connection_id_retired_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionTerminated):
|
||||
self._connection_terminated_handler()
|
||||
|
||||
# abort connection waiter
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected_waiter = None
|
||||
waiter.set_exception(ConnectionError)
|
||||
|
||||
# abort ping waiters
|
||||
for waiter in self._ping_waiters.values():
|
||||
waiter.set_exception(ConnectionError)
|
||||
self._ping_waiters.clear()
|
||||
|
||||
self._closed.set()
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected = True
|
||||
self._connected_waiter = None
|
||||
waiter.set_result(None)
|
||||
elif isinstance(event, events.PingAcknowledged):
|
||||
waiter = self._ping_waiters.pop(event.uid, None)
|
||||
if waiter is not None:
|
||||
waiter.set_result(None)
|
||||
self.quic_event_received(event)
|
||||
event = self._quic.next_event()
|
||||
|
||||
def _transmit_soon(self) -> None:
|
||||
if self._transmit_task is None:
|
||||
self._transmit_task = self._loop.call_soon(self.transmit)
|
||||
|
||||
|
||||
class QuicStreamAdapter(asyncio.Transport):
|
||||
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
|
||||
self.protocol = protocol
|
||||
self.stream_id = stream_id
|
||||
self._closing = False
|
||||
|
||||
def can_write_eof(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_extra_info(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get information about the underlying QUIC stream.
|
||||
"""
|
||||
if name == "stream_id":
|
||||
return self.stream_id
|
||||
|
||||
def write(self, data):
|
||||
self.protocol._quic.send_stream_data(self.stream_id, data)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def write_eof(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def close(self):
|
||||
self.write_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
Reference in New Issue
Block a user