Code
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,128 @@
|
||||
import abc
|
||||
from typing import Any, Dict, Iterable, Optional, Protocol
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
|
||||
K_GRANULARITY = 0.001 # seconds
|
||||
K_INITIAL_WINDOW = 10
|
||||
K_MINIMUM_WINDOW = 2
|
||||
|
||||
|
||||
class QuicCongestionControl(abc.ABC):
|
||||
"""
|
||||
Base class for congestion control implementations.
|
||||
"""
|
||||
|
||||
bytes_in_flight: int = 0
|
||||
congestion_window: int = 0
|
||||
ssthresh: Optional[int] = None
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_lost(
|
||||
self, *, now: float, packets: Iterable[QuicSentPacket]
|
||||
) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ...
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
|
||||
if self.ssthresh is not None:
|
||||
data["ssthresh"] = self.ssthresh
|
||||
return data
|
||||
|
||||
|
||||
class QuicCongestionControlFactory(Protocol):
|
||||
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
|
||||
|
||||
|
||||
class QuicRttMonitor:
|
||||
"""
|
||||
Roundtrip time monitor for HyStart.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._increases = 0
|
||||
self._last_time = None
|
||||
self._ready = False
|
||||
self._size = 5
|
||||
|
||||
self._filtered_min: Optional[float] = None
|
||||
|
||||
self._sample_idx = 0
|
||||
self._sample_max: Optional[float] = None
|
||||
self._sample_min: Optional[float] = None
|
||||
self._sample_time = 0.0
|
||||
self._samples = [0.0 for i in range(self._size)]
|
||||
|
||||
def add_rtt(self, *, rtt: float) -> None:
|
||||
self._samples[self._sample_idx] = rtt
|
||||
self._sample_idx += 1
|
||||
|
||||
if self._sample_idx >= self._size:
|
||||
self._sample_idx = 0
|
||||
self._ready = True
|
||||
|
||||
if self._ready:
|
||||
self._sample_max = self._samples[0]
|
||||
self._sample_min = self._samples[0]
|
||||
for sample in self._samples[1:]:
|
||||
if sample < self._sample_min:
|
||||
self._sample_min = sample
|
||||
elif sample > self._sample_max:
|
||||
self._sample_max = sample
|
||||
|
||||
def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
|
||||
if now > self._sample_time + K_GRANULARITY:
|
||||
self.add_rtt(rtt=rtt)
|
||||
self._sample_time = now
|
||||
|
||||
if self._ready:
|
||||
if self._filtered_min is None or self._filtered_min > self._sample_max:
|
||||
self._filtered_min = self._sample_max
|
||||
|
||||
delta = self._sample_min - self._filtered_min
|
||||
if delta * 4 >= self._filtered_min:
|
||||
self._increases += 1
|
||||
if self._increases >= self._size:
|
||||
return True
|
||||
elif delta > 0:
|
||||
self._increases = 0
|
||||
return False
|
||||
|
||||
|
||||
_factories: Dict[str, QuicCongestionControlFactory] = {}
|
||||
|
||||
|
||||
def create_congestion_control(
|
||||
name: str, *, max_datagram_size: int
|
||||
) -> QuicCongestionControl:
|
||||
"""
|
||||
Create an instance of the `name` congestion control algorithm.
|
||||
"""
|
||||
try:
|
||||
factory = _factories[name]
|
||||
except KeyError:
|
||||
raise Exception(f"Unknown congestion control algorithm: {name}")
|
||||
return factory(max_datagram_size=max_datagram_size)
|
||||
|
||||
|
||||
def register_congestion_control(
|
||||
name: str, factory: QuicCongestionControlFactory
|
||||
) -> None:
|
||||
"""
|
||||
Register a congestion control algorithm named `name`.
|
||||
"""
|
||||
_factories[name] = factory
|
||||
@ -0,0 +1,212 @@
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_INITIAL_WINDOW,
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
|
||||
K_CUBIC_C = 0.4
|
||||
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
|
||||
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
|
||||
|
||||
|
||||
def better_cube_root(x: float) -> float:
|
||||
if x < 0:
|
||||
# avoid precision errors that make the cube root returns an imaginary number
|
||||
return -((-x) ** (1.0 / 3.0))
|
||||
else:
|
||||
return (x) ** (1.0 / 3.0)
|
||||
|
||||
|
||||
class CubicCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
Cubic congestion control implementation for aioquic
|
||||
"""
|
||||
|
||||
def __init__(self, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
# increase by one segment
|
||||
self.additive_increase_factor: int = max_datagram_size
|
||||
self._max_datagram_size: int = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
self.rtt = 0.02 # starting RTT is considered to be 20ms
|
||||
|
||||
self.reset()
|
||||
|
||||
self.last_ack = 0.0
|
||||
|
||||
def W_cubic(self, t) -> int:
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
|
||||
return int(target_segments * self._max_datagram_size)
|
||||
|
||||
def is_reno_friendly(self, t) -> bool:
|
||||
return self.W_cubic(t) < self._W_est
|
||||
|
||||
def is_concave(self) -> bool:
|
||||
return self.congestion_window < self._W_max
|
||||
|
||||
def reset(self) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
|
||||
self.ssthresh = None
|
||||
|
||||
self._first_slow_start = True
|
||||
self._starting_congestion_avoidance = False
|
||||
self.K: float = 0.0
|
||||
self._W_est = 0
|
||||
self._cwnd_epoch = 0
|
||||
self._t_epoch = 0.0
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
self.last_ack = packet.sent_time
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
if self._first_slow_start and not self._starting_congestion_avoidance:
|
||||
# exiting slow start without having a loss
|
||||
self._first_slow_start = False
|
||||
self._W_max = self.congestion_window
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
# initialize the variables used at start of congestion avoidance
|
||||
if self._starting_congestion_avoidance:
|
||||
self._starting_congestion_avoidance = False
|
||||
self._first_slow_start = False
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
self._W_est = int(
|
||||
self._W_est
|
||||
+ self.additive_increase_factor
|
||||
* (packet.sent_bytes / self.congestion_window)
|
||||
)
|
||||
|
||||
t = now - self._t_epoch
|
||||
|
||||
target: int = 0
|
||||
W_cubic = self.W_cubic(t + self.rtt)
|
||||
if W_cubic < self.congestion_window:
|
||||
target = self.congestion_window
|
||||
elif W_cubic > 1.5 * self.congestion_window:
|
||||
target = int(self.congestion_window * 1.5)
|
||||
else:
|
||||
target = W_cubic
|
||||
|
||||
if self.is_reno_friendly(t):
|
||||
# reno friendly region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
|
||||
self.congestion_window = self._W_est
|
||||
elif self.is_concave():
|
||||
# concave region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# convex region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
if self.last_ack == 0.0:
|
||||
return
|
||||
elapsed_idle = packet.sent_time - self.last_ack
|
||||
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
|
||||
self.reset()
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
|
||||
# Normal congestion handle, can't be used in same time as fast convergence
|
||||
# self._W_max = self.congestion_window
|
||||
|
||||
# fast convergence
|
||||
if self._W_max is not None and self.congestion_window < self._W_max:
|
||||
self._W_max = int(
|
||||
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
|
||||
)
|
||||
else:
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
# normal congestion MD
|
||||
flight_size = self.bytes_in_flight
|
||||
new_ssthresh = max(
|
||||
int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = new_ssthresh
|
||||
self.congestion_window = max(
|
||||
self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
|
||||
)
|
||||
|
||||
# restart a new congestion avoidance phase
|
||||
self._starting_congestion_avoidance = True
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
self.rtt = rtt
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
rtt=rtt, now=now
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = super().get_log_data()
|
||||
|
||||
data["cubic-wmax"] = int(self._W_max)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
register_congestion_control("cubic", CubicCongestionControl)
|
||||
@ -0,0 +1,77 @@
|
||||
from typing import Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
K_LOSS_REDUCTION_FACTOR = 0.5
|
||||
|
||||
|
||||
class RenoCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
New Reno congestion control.
|
||||
"""
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
self._max_datagram_size = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
self._congestion_stash = 0
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
# don't increase window in congestion recovery
|
||||
if packet.sent_time <= self._congestion_recovery_start_time:
|
||||
return
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
self._congestion_stash += packet.sent_bytes
|
||||
count = self._congestion_stash // self.congestion_window
|
||||
if count:
|
||||
self._congestion_stash -= count * self.congestion_window
|
||||
self.congestion_window += count * self._max_datagram_size
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
self.congestion_window = max(
|
||||
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
# TODO : collapse congestion window if persistent congestion
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
now=now, rtt=rtt
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
|
||||
register_congestion_control("reno", RenoCongestionControl)
|
||||
Reference in New Issue
Block a user