Files
mofixx a5df3861fd Code
2025-08-08 10:41:30 +02:00

129 lines
3.8 KiB
Python

import abc
from typing import Any, Dict, Iterable, Optional, Protocol
from ..packet_builder import QuicSentPacket
# Minimum spacing, in seconds, between two RTT samples accepted by
# QuicRttMonitor.is_rtt_increasing.
K_GRANULARITY = 1e-3
# Initial congestion window, expressed as a number of max-size datagrams.
K_INITIAL_WINDOW = 10
# NOTE(review): not referenced in this chunk; presumably the minimum
# congestion window in datagrams used by concrete algorithms — confirm.
K_MINIMUM_WINDOW = 2


class QuicCongestionControl(abc.ABC):
    """
    Abstract base for congestion control algorithms.

    Concrete algorithms implement the abstract ``on_*`` event hooks and are
    expected to keep ``congestion_window``, ``bytes_in_flight`` and
    ``ssthresh`` up to date; :meth:`get_log_data` reports those values.
    """

    # Class-level defaults; ``__init__`` replaces ``congestion_window`` per
    # instance.
    bytes_in_flight: int = 0
    congestion_window: int = 0
    ssthresh: Optional[int] = None  # omitted from log data while None

    def __init__(self, *, max_datagram_size: int) -> None:
        """Set the starting window to ``K_INITIAL_WINDOW * max_datagram_size``."""
        self.congestion_window = K_INITIAL_WINDOW * max_datagram_size

    @abc.abstractmethod
    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        """Handle acknowledgement of ``packet`` observed at time ``now``."""

    @abc.abstractmethod
    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        """Handle transmission of ``packet``."""

    @abc.abstractmethod
    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        """Handle ``packets`` that have expired."""

    @abc.abstractmethod
    def on_packets_lost(
        self, *, now: float, packets: Iterable[QuicSentPacket]
    ) -> None:
        """Handle ``packets`` declared lost at time ``now``."""

    @abc.abstractmethod
    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        """Handle a round-trip time sample ``rtt`` taken at time ``now``."""

    def get_log_data(self) -> Dict[str, Any]:
        """
        Return a snapshot of the controller state for logging.

        The result always has ``cwnd`` and ``bytes_in_flight``; ``ssthresh``
        is added only when it is not None.
        """
        snapshot: Dict[str, Any] = dict(
            cwnd=self.congestion_window,
            bytes_in_flight=self.bytes_in_flight,
        )
        if self.ssthresh is not None:
            snapshot["ssthresh"] = self.ssthresh
        return snapshot


class QuicCongestionControlFactory(Protocol):
    """
    Structural type for callables that build a congestion controller.

    Any callable that accepts a keyword-only ``max_datagram_size`` and returns
    a :class:`QuicCongestionControl` satisfies this protocol.
    """

    def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...


class QuicRttMonitor:
    """
    Round-trip time monitor used by HyStart to detect rising RTT.

    RTT samples go into a five-slot ring buffer. Once the buffer has wrapped
    at least once, each new sample refreshes the window minimum and maximum.
    :meth:`is_rtt_increasing` compares the window minimum with a filtered
    baseline (the lowest window maximum seen so far).
    """

    def __init__(self) -> None:
        window = 5
        self._increases = 0  # consecutive windows that met the increase test
        self._last_time = None
        self._ready = False  # True once the ring buffer has wrapped
        self._size = window

        # Lowest window maximum observed; baseline for increase detection.
        self._filtered_min: Optional[float] = None

        self._sample_idx = 0  # slot the next sample overwrites
        self._sample_max: Optional[float] = None
        self._sample_min: Optional[float] = None
        self._sample_time = 0.0  # `now` of the last accepted sample
        self._samples = [0.0] * window

    def add_rtt(self, *, rtt: float) -> None:
        """
        Store ``rtt`` in the ring buffer, overwriting the oldest slot.

        After the buffer has wrapped once, recompute the window min and max
        over all slots.
        """
        slot = self._sample_idx
        self._samples[slot] = rtt
        slot += 1
        if slot >= self._size:
            slot = 0
            self._ready = True
        self._sample_idx = slot

        if not self._ready:
            return
        self._sample_min = min(self._samples)
        self._sample_max = max(self._samples)

    def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
        """
        Feed one RTT sample and report whether RTT has risen persistently.

        A sample is ignored unless ``now`` exceeds the previous accepted
        sample time by more than ``K_GRANULARITY``. Returns True once
        ``self._size`` consecutive windows have a minimum at least 25% above
        the filtered baseline. A smaller positive rise resets that count.
        """
        if not (now > self._sample_time + K_GRANULARITY):
            return False

        self.add_rtt(rtt=rtt)
        self._sample_time = now
        if not self._ready:
            return False

        # Lower the baseline whenever the current window's maximum undercuts it.
        baseline = self._filtered_min
        if baseline is None or baseline > self._sample_max:
            baseline = self._sample_max
            self._filtered_min = baseline

        delta = self._sample_min - baseline
        # delta * 4 >= baseline  <=>  window minimum >= 1.25 * baseline
        if delta * 4 >= baseline:
            self._increases += 1
            return self._increases >= self._size
        if delta > 0:
            self._increases = 0
        return False


# Registry of congestion control factories, keyed by algorithm name.
_factories: Dict[str, QuicCongestionControlFactory] = dict()


def create_congestion_control(
    name: str, *, max_datagram_size: int
) -> QuicCongestionControl:
    """
    Create an instance of the `name` congestion control algorithm.

    The factory registered under `name` is called with the keyword-only
    `max_datagram_size` and its result is returned.

    Raises `Exception` if no algorithm is registered under `name`.
    """
    try:
        factory = _factories[name]
    except KeyError:
        # `from None` hides the internal KeyError so the traceback shows only
        # the descriptive error below instead of a chained "during handling
        # of the above exception" report.
        raise Exception(
            f"Unknown congestion control algorithm: {name}"
        ) from None
    return factory(max_datagram_size=max_datagram_size)


def register_congestion_control(
    name: str, factory: QuicCongestionControlFactory
) -> None:
    """
    Make `factory` available under `name` for `create_congestion_control`.

    Registering a name that is already present replaces its previous factory.
    """
    _factories.update({name: factory})