287 lines
11 KiB
Python
287 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
import argparse
|
||
import asyncio
|
||
import time
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Tuple
|
||
|
||
from aioquic.quic.configuration import QuicConfiguration
|
||
from aioquic.quic.connection import QuicConnection
|
||
from aioquic.tls import SessionTicket
|
||
from aioquic.asyncio.client import connect
|
||
from aioquic.quic.events import QuicEvent, HandshakeCompleted, ConnectionTerminated, ProtocolNegotiated, DatagramFrameReceived
|
||
from aioquic.h0.connection import H0_ALPN
|
||
from aioquic.h3.connection import H3_ALPN
|
||
|
||
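# RFC references:
# [1] RFC 9000 (QUIC transport): address validation and the anti-amplification
#     limit -- before the client's address is validated, a server must not send
#     more than 3x the bytes it has received (Section 8.1); client Initial
#     datagrams are padded to at least 1200 bytes (Section 14.1); Retry and
#     NEW_TOKEN address-validation tokens (Sections 8.1.2-8.1.3).
# [2] RFC 9001 (Using TLS to Secure QUIC): how the TLS handshake is carried in
#     QUIC packets, and the Retry packet integrity tag (Section 5.8).
# [3] RFC 8446 (TLS 1.3): the handshake protocol QUIC uses; its session tickets
#     enable resumption but are distinct from QUIC address-validation tokens.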
|
||
|
||
@dataclass
class FlowReport:
    """Outcome of a single QUIC handshake probe.

    Byte counts are UDP payload sizes measured at the client's socket.
    "AV" means address validation; the probe approximates the AV moment with
    the client's HandshakeCompleted event.
    """

    host: str  # target host as given on the command line
    port: int  # target UDP port
    alpn: str  # ALPN protocol offered by the client
    used_retry: bool  # True if a Retry packet was detected during the attempt
    token_present: bool  # True if an address-validation token was supplied
    address_validated_ts: Optional[float]  # time.time() when AV was assumed; None if never
    client_bytes_sent_before_av: int  # bytes the client sent before AV
    server_bytes_recv_before_av: int  # bytes the client received from the server before AV
    client_bytes_sent_total: int  # all bytes the client sent
    server_bytes_recv_total: int  # all bytes the client received from the server
    compliant_before_av: Optional[bool]  # server-before-AV <= 3x client-before-AV; None if no data
    error: Optional[str]  # failure description, or None
    handshake_completed: bool  # True once the TLS handshake completed on the client
    protocol: Optional[str]  # ALPN protocol actually negotiated, if any
    # Token carried by the server's Retry packet, if one was captured. It has a
    # default so existing constructor calls keep working.
    retry_token: Optional[bytes] = None
|
||
|
||
class ByteMeteringTransport(asyncio.DatagramTransport):
    """DatagramTransport wrapper that counts the UDP payload bytes sent through it.

    ``sendto`` is metered; every other transport operation is forwarded to the
    wrapped transport.
    """

    def __init__(self, inner: asyncio.DatagramTransport):
        # No super().__init__(): extras must come from the inner transport,
        # not from an empty dict owned by this wrapper.
        self._inner = inner
        self.bytes_sent = 0  # total payload bytes successfully handed to the inner transport

    def sendto(self, data: bytes, addr=None):
        # Forward first so a send that raises is not counted.
        self._inner.sendto(data, addr)
        self.bytes_sent += len(data)

    # asyncio.BaseTransport / DatagramTransport define the following as stubs
    # that raise NotImplementedError. Because they are found by normal lookup,
    # __getattr__ never sees them, so they must be forwarded explicitly.
    def close(self):
        self._inner.close()

    def is_closing(self):
        return self._inner.is_closing()

    def abort(self):
        self._inner.abort()

    def get_extra_info(self, name, default=None):
        return self._inner.get_extra_info(name, default)

    def get_protocol(self):
        return self._inner.get_protocol()

    def set_protocol(self, protocol):
        self._inner.set_protocol(protocol)

    def __getattr__(self, item):
        # Reached only for attributes not found normally; forward them to the
        # inner transport. Guard against infinite recursion when _inner itself
        # is missing (e.g. on a half-constructed object).
        if item == "_inner":
            raise AttributeError(item)
        return getattr(self._inner, item)
|
||
|
||
class MeteringProtocol(asyncio.DatagramProtocol):
    """UDP protocol that feeds server datagrams into a QuicConnection and meters traffic.

    Outbound bytes are counted by wrapping the transport in ByteMeteringTransport
    (send through ``self.transport``). Inbound payload sizes are reported to the
    ``on_datagram_recv`` callback before each datagram is handed to QUIC.
    """

    def __init__(self, quic: QuicConnection, on_datagram_recv, clock=time.time):
        self.quic = quic
        self.on_datagram_recv = on_datagram_recv
        # Timestamp source passed to receive_datagram(). It must be the same
        # clock the caller uses for connect()/datagrams_to_send()/handle_timer();
        # mixing clocks makes QUIC timers fire at the wrong time.
        self.clock = clock
        self.transport: Optional[ByteMeteringTransport] = None
        # Most recent socket-level error reported by asyncio, if any.
        self.last_error: Optional[Exception] = None

    def connection_made(self, transport):
        # Wrap the real transport so every datagram sent via self.transport is
        # counted. The caller pumps quic.datagrams_to_send() into it.
        self.transport = ByteMeteringTransport(transport)

    def datagram_received(self, data, addr):
        # Count server->client UDP payload bytes, then let QUIC parse them.
        self.on_datagram_recv(len(data))
        self.quic.receive_datagram(data, addr, now=self.clock())

    def error_received(self, exc):
        # Record instead of silently dropping, so the caller can report it.
        self.last_error = exc

    def connection_lost(self, exc):
        if exc is not None:
            self.last_error = exc


async def quic_attempt(host: str, port: int, alpn: str, sni: Optional[str], timeout: float, token: Optional[bytes]) -> FlowReport:
    """Run one client handshake against ``host:port`` and meter UDP bytes both ways.

    The QuicConnection is driven by hand on a single UDP socket owned by
    MeteringProtocol. Every datagram in either direction is therefore counted,
    and one clock (time.time) is used for all QUIC timestamps.

    Address validation (AV) is observed only from the client side and is
    approximated by the client's HandshakeCompleted event. The server may
    validate earlier (RFC 9000 Section 8.1), so the pre-AV window can include
    bytes the server was already allowed to send.

    Args:
        host: Server name or address to resolve and connect to.
        port: Server UDP port.
        alpn: ALPN protocol to offer (e.g. "h3").
        sni: TLS server name; ``host`` is used when None or empty.
        timeout: Upper bound, in seconds, on how long the connection is driven.
        token: Address-validation token to present in the Initial, if any.

    Returns:
        A FlowReport. Exceptions raised while probing are not propagated; they
        are recorded as ``"<ExceptionType>: <message>"`` in ``error``. A Retry
        token, if one was captured, is attached as ``report.retry_token``.
    """
    import ssl  # local import: only needed to choose the TLS verify mode

    clock = time.time

    cfg = QuicConfiguration(
        is_client=True,
        alpn_protocols=[alpn],
        # With verify_mode=None aioquic falls back to its client default
        # (certificate verification). A probe must also reach self-signed test
        # servers, so turn verification off explicitly.
        verify_mode=ssl.CERT_NONE,
        server_name=sni or host,
    )
    if token:
        # Presented in the client's Initial packets. This assumes the installed
        # aioquic's QuicConfiguration has a ``token`` field -- TODO confirm.
        cfg.token = token

    # Bookkeeping
    server_recv_total = 0
    server_recv_before_av = 0
    client_sent_before_av = 0
    client_sent_total = 0
    address_validated_ts: Optional[float] = None
    handshake_completed = False
    proto: Optional[str] = None
    used_retry = False
    retry_token: Optional[bytes] = None
    token_present = bool(token)  # True exactly when a token was injected above
    error: Optional[str] = None

    def on_dgram_recv(n: int) -> None:
        nonlocal server_recv_total
        server_recv_total += n

    loop = asyncio.get_running_loop()
    quic = QuicConnection(configuration=cfg)
    transport: Optional[asyncio.DatagramTransport] = None
    try:
        # One connected UDP socket; the protocol feeds replies into ``quic``.
        transport, protocol = await loop.create_datagram_endpoint(
            lambda: MeteringProtocol(quic, on_dgram_recv, clock=clock),
            remote_addr=(host, port),
        )
        metered = protocol.transport
        if metered is None:
            raise RuntimeError("datagram endpoint did not call connection_made")

        def flush() -> None:
            # Send everything QUIC has queued. The socket is connected, so no
            # address is passed (asyncio rejects one differing from the peer).
            for datagram, _addr in quic.datagrams_to_send(now=clock()):
                metered.sendto(datagram)

        # Start the handshake using the resolved peer address, so replies from
        # it map onto the same network path. aioquic pads client Initial
        # datagrams to >= 1200 bytes (RFC 9000 Section 14.1).
        peer = transport.get_extra_info("peername") or (host, port)
        quic.connect(peer, now=clock())
        flush()

        deadline = clock() + timeout
        terminated = False
        while not terminated and clock() < deadline:
            # Yield so datagram_received() can run for any pending replies.
            await asyncio.sleep(0.001)

            # Snapshot the pre-AV counters before handling events, so the
            # datagram that completes the handshake still counts as pre-AV.
            if address_validated_ts is None:
                client_sent_before_av = metered.bytes_sent
                server_recv_before_av = server_recv_total

            # Fire QUIC timers (loss detection, idle timeout) only once due.
            timer_at = quic.get_timer()
            if timer_at is not None and clock() >= timer_at:
                quic.handle_timer(now=clock())

            for event in iter(quic.next_event, None):
                if isinstance(event, ProtocolNegotiated):
                    proto = event.alpn_protocol
                elif isinstance(event, HandshakeCompleted):
                    handshake_completed = True
                    proto = proto or event.alpn_protocol
                    if address_validated_ts is None:
                        address_validated_ts = clock()
                elif isinstance(event, ConnectionTerminated):
                    terminated = True
                    if event.error_code:
                        error = (
                            f"ConnectionTerminated: error_code=0x{event.error_code:x} "
                            f"reason={event.reason_phrase!r}"
                        )

            flush()
            client_sent_total = metered.bytes_sent

            # aioquic keeps Retry state private; getattr makes a renamed
            # attribute read as "no Retry seen" instead of crashing.
            if getattr(quic, "_retry_count", 0):
                used_retry = True

            # Stop once the handshake has had 100 ms to settle.
            if handshake_completed and clock() - address_validated_ts > 0.1:
                break

        if address_validated_ts is None:
            # AV never observed: the whole run counts as pre-AV.
            client_sent_before_av = metered.bytes_sent
            server_recv_before_av = server_recv_total

        if used_retry:
            # Token carried by the server's Retry (private attribute, guarded).
            retry_token = getattr(quic, "_peer_token", None) or None

        # Close the connection (error code 0 = NO_ERROR) and flush the close.
        quic.close(error_code=0x0)
        flush()
        await asyncio.sleep(0.05)
        client_sent_total = metered.bytes_sent

        if error is None and not handshake_completed and protocol.last_error is not None:
            exc = protocol.last_error
            error = f"{type(exc).__name__}: {exc}"
    except Exception as e:
        # Probe boundary: record the failure instead of propagating it. Include
        # the type, because str(e) alone can be empty (e.g. TimeoutError).
        error = f"{type(e).__name__}: {e}"
    finally:
        if transport is not None:
            transport.close()

    compliant: Optional[bool] = None
    if client_sent_before_av > 0:
        # RFC 9000 3x rule [1]: before validation, the server may send at most
        # three times the bytes it has received.
        compliant = server_recv_before_av <= 3 * client_sent_before_av

    report = FlowReport(
        host=host,
        port=port,
        alpn=alpn,
        used_retry=used_retry,
        token_present=token_present,
        address_validated_ts=address_validated_ts,
        client_bytes_sent_before_av=client_sent_before_av,
        server_bytes_recv_before_av=server_recv_before_av,
        client_bytes_sent_total=client_sent_total,
        server_bytes_recv_total=server_recv_total,
        compliant_before_av=compliant,
        error=error,
        handshake_completed=handshake_completed,
        protocol=proto,
    )
    # main() reads this with getattr(report, "retry_token", None). FlowReport is
    # a plain (non-slots) dataclass, so the attribute can be set after construction.
    report.retry_token = retry_token
    return report
|
||
|
||
def print_report(title: str, report: FlowReport):
    """Write a human-readable summary of *report* to stdout, then a blank line.

    The printed budget is three times the client's pre-AV byte count, per the
    anti-amplification rule.
    """
    verdict = report.compliant_before_av
    if verdict is None:
        compliance = "Unknown (insufficient data)"
    else:
        compliance = "COMPLIANT" if verdict else "VIOLATION"

    amplification_budget = 3 * report.client_bytes_sent_before_av
    validated = report.address_validated_ts is not None

    lines = [
        f"=== {title} ===",
        f"Target: {report.host}:{report.port}, ALPN={report.alpn}",
        f"Protocol negotiated: {report.protocol}",
        f"Retry observed: {report.used_retry}, Token provided: {report.token_present}",
        f"Handshake completed: {report.handshake_completed}, AddressValidated: {validated}",
        f"Client bytes sent before AV: {report.client_bytes_sent_before_av}",
        f"Server bytes received before AV: {report.server_bytes_recv_before_av}",
        f"Anti-Amplification budget (3× client-before-AV): {amplification_budget}",
        f"Compliance before AV: {compliance}",
        f"Totals: client_sent={report.client_bytes_sent_total}, server_recv={report.server_bytes_recv_total}",
    ]
    if report.error:
        lines.append(f"Error: {report.error}")

    print("\n".join(lines))
    print()  # blank separator line between reports
|
||
|
||
async def main():
    """Command-line entry point: probe a QUIC server, optionally re-probing with a token.

    Attempt 1 always runs without a token. Attempt 2 runs only when all three
    hold: attempt 1 reported no error, attempt 1 observed a Retry, and
    ``--no-retry`` was not given.
    """
    parser = argparse.ArgumentParser(description="QUIC Anti-Amplification & Retry Probe")
    parser.add_argument("--host", required=True, help="server host name or address")
    parser.add_argument("--port", type=int, default=443, help="server UDP port (default: 443)")
    parser.add_argument("--alpn", default="h3", help="ALPN protocol to offer (default: h3)")
    parser.add_argument("--sni", default=None, help="TLS server name (default: the --host value)")
    parser.add_argument("--timeout", type=float, default=5.0, help="seconds per attempt (default: 5.0)")
    parser.add_argument("--no-retry", action="store_true", help="skip the second, token-bearing attempt")
    args = parser.parse_args()
    if args.timeout <= 0:
        parser.error("--timeout must be positive")

    # First attempt: no token
    report1 = await quic_attempt(
        host=args.host,
        port=args.port,
        alpn=args.alpn,
        sni=args.sni,
        timeout=args.timeout,
        token=None,
    )
    print_report("Attempt 1 (no token)", report1)

    if report1.error:
        return
    if args.no_retry or not report1.used_retry:
        # Either the server sent no Retry, or the user opted out of attempt 2.
        return

    # The report may carry the token from the server's Retry as ``retry_token``.
    # getattr with a default keeps this safe when it does not. If no token is
    # available, attempt 2 simply repeats the probe without one.
    # NOTE: RFC 9000 says clients must not reuse a Retry token on a new
    # connection; the probe does so deliberately to see how the server reacts,
    # so expect the server to ignore or reject it.
    retry_token: Optional[bytes] = getattr(report1, "retry_token", None)

    # Second attempt: with the token if we have one
    report2 = await quic_attempt(
        host=args.host,
        port=args.port,
        alpn=args.alpn,
        sni=args.sni,
        timeout=args.timeout,
        token=retry_token,
    )
    print_report("Attempt 2 (with token if available)", report2)
|
||
if __name__ == "__main__":
    # Script entry point: run the async CLI driver in a new event loop until it finishes.
    asyncio.run(main())
|