This commit is contained in:
mofixx
2025-08-08 10:41:30 +02:00
parent 4444be3799
commit a5df3861fd
1674 changed files with 234266 additions and 0 deletions

View File

@ -0,0 +1 @@
__version__ = "1.2.0"

View File

@ -0,0 +1,422 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdint.h>
#define MODULE_NAME "aioquic._buffer"
static PyObject *BufferReadError;
static PyObject *BufferWriteError;
typedef struct {
PyObject_HEAD
uint8_t *base;
uint8_t *end;
uint8_t *pos;
} BufferObject;
static PyObject *BufferType;
#define CHECK_READ_BOUNDS(self, len) \
if (len < 0 || self->pos + len > self->end) { \
PyErr_SetString(BufferReadError, "Read out of bounds"); \
return NULL; \
}
#define CHECK_WRITE_BOUNDS(self, len) \
if (self->pos + len > self->end) { \
PyErr_SetString(BufferWriteError, "Write out of bounds"); \
return NULL; \
}
static int
Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs)
{
const char *kwlist[] = {"capacity", "data", NULL};
Py_ssize_t capacity = 0;
const unsigned char *data = NULL;
Py_ssize_t data_len = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ny#", (char**)kwlist, &capacity, &data, &data_len))
return -1;
if (data != NULL) {
self->base = malloc(data_len);
self->end = self->base + data_len;
memcpy(self->base, data, data_len);
} else {
self->base = malloc(capacity);
self->end = self->base + capacity;
}
self->pos = self->base;
return 0;
}
static void
Buffer_dealloc(BufferObject *self)
{
free(self->base);
PyTypeObject *tp = Py_TYPE(self);
freefunc free = PyType_GetSlot(tp, Py_tp_free);
free(self);
Py_DECREF(tp);
}
static PyObject *
Buffer_data_slice(BufferObject *self, PyObject *args)
{
Py_ssize_t start, stop;
if (!PyArg_ParseTuple(args, "nn", &start, &stop))
return NULL;
if (start < 0 || self->base + start > self->end ||
stop < 0 || self->base + stop > self->end ||
stop < start) {
PyErr_SetString(BufferReadError, "Read out of bounds");
return NULL;
}
return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start));
}
static PyObject *
Buffer_eof(BufferObject *self, PyObject *args)
{
if (self->pos == self->end)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject *
Buffer_pull_bytes(BufferObject *self, PyObject *args)
{
Py_ssize_t len;
if (!PyArg_ParseTuple(args, "n", &len))
return NULL;
CHECK_READ_BOUNDS(self, len);
PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len);
self->pos += len;
return o;
}
static PyObject *
Buffer_pull_uint8(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 1)
return PyLong_FromUnsignedLong(
(uint8_t)(*(self->pos++))
);
}
static PyObject *
Buffer_pull_uint16(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 2)
uint16_t value = (uint16_t)(*(self->pos)) << 8 |
(uint16_t)(*(self->pos + 1));
self->pos += 2;
return PyLong_FromUnsignedLong(value);
}
static PyObject *
Buffer_pull_uint32(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 4)
uint32_t value = (uint32_t)(*(self->pos)) << 24 |
(uint32_t)(*(self->pos + 1)) << 16 |
(uint32_t)(*(self->pos + 2)) << 8 |
(uint32_t)(*(self->pos + 3));
self->pos += 4;
return PyLong_FromUnsignedLong(value);
}
static PyObject *
Buffer_pull_uint64(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 8)
uint64_t value = (uint64_t)(*(self->pos)) << 56 |
(uint64_t)(*(self->pos + 1)) << 48 |
(uint64_t)(*(self->pos + 2)) << 40 |
(uint64_t)(*(self->pos + 3)) << 32 |
(uint64_t)(*(self->pos + 4)) << 24 |
(uint64_t)(*(self->pos + 5)) << 16 |
(uint64_t)(*(self->pos + 6)) << 8 |
(uint64_t)(*(self->pos + 7));
self->pos += 8;
return PyLong_FromUnsignedLongLong(value);
}
static PyObject *
Buffer_pull_uint_var(BufferObject *self, PyObject *args)
{
uint64_t value;
CHECK_READ_BOUNDS(self, 1)
switch (*(self->pos) >> 6) {
case 0:
value = *(self->pos++) & 0x3F;
break;
case 1:
CHECK_READ_BOUNDS(self, 2)
value = (uint16_t)(*(self->pos) & 0x3F) << 8 |
(uint16_t)(*(self->pos + 1));
self->pos += 2;
break;
case 2:
CHECK_READ_BOUNDS(self, 4)
value = (uint32_t)(*(self->pos) & 0x3F) << 24 |
(uint32_t)(*(self->pos + 1)) << 16 |
(uint32_t)(*(self->pos + 2)) << 8 |
(uint32_t)(*(self->pos + 3));
self->pos += 4;
break;
default:
CHECK_READ_BOUNDS(self, 8)
value = (uint64_t)(*(self->pos) & 0x3F) << 56 |
(uint64_t)(*(self->pos + 1)) << 48 |
(uint64_t)(*(self->pos + 2)) << 40 |
(uint64_t)(*(self->pos + 3)) << 32 |
(uint64_t)(*(self->pos + 4)) << 24 |
(uint64_t)(*(self->pos + 5)) << 16 |
(uint64_t)(*(self->pos + 6)) << 8 |
(uint64_t)(*(self->pos + 7));
self->pos += 8;
break;
}
return PyLong_FromUnsignedLongLong(value);
}
static PyObject *
Buffer_push_bytes(BufferObject *self, PyObject *args)
{
const unsigned char *data;
Py_ssize_t data_len;
if (!PyArg_ParseTuple(args, "y#", &data, &data_len))
return NULL;
CHECK_WRITE_BOUNDS(self, data_len)
memcpy(self->pos, data, data_len);
self->pos += data_len;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint8(BufferObject *self, PyObject *args)
{
uint8_t value;
if (!PyArg_ParseTuple(args, "B", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 1)
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint16(BufferObject *self, PyObject *args)
{
uint16_t value;
if (!PyArg_ParseTuple(args, "H", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 2)
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint32(BufferObject *self, PyObject *args)
{
uint32_t value;
if (!PyArg_ParseTuple(args, "I", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 4)
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint64(BufferObject *self, PyObject *args)
{
uint64_t value;
if (!PyArg_ParseTuple(args, "K", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 8)
*(self->pos++) = (value >> 56);
*(self->pos++) = (value >> 48);
*(self->pos++) = (value >> 40);
*(self->pos++) = (value >> 32);
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint_var(BufferObject *self, PyObject *args)
{
uint64_t value;
if (!PyArg_ParseTuple(args, "K", &value))
return NULL;
if (value <= 0x3F) {
CHECK_WRITE_BOUNDS(self, 1)
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFF) {
CHECK_WRITE_BOUNDS(self, 2)
*(self->pos++) = (value >> 8) | 0x40;
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFFFFFF) {
CHECK_WRITE_BOUNDS(self, 4)
*(self->pos++) = (value >> 24) | 0x80;
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFFFFFFFFFFFFFF) {
CHECK_WRITE_BOUNDS(self, 8)
*(self->pos++) = (value >> 56) | 0xC0;
*(self->pos++) = (value >> 48);
*(self->pos++) = (value >> 40);
*(self->pos++) = (value >> 32);
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
} else {
PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer");
return NULL;
}
}
static PyObject *
Buffer_seek(BufferObject *self, PyObject *args)
{
Py_ssize_t pos;
if (!PyArg_ParseTuple(args, "n", &pos))
return NULL;
if (pos < 0 || self->base + pos > self->end) {
PyErr_SetString(BufferReadError, "Seek out of bounds");
return NULL;
}
self->pos = self->base + pos;
Py_RETURN_NONE;
}
static PyObject *
Buffer_tell(BufferObject *self, PyObject *args)
{
return PyLong_FromSsize_t(self->pos - self->base);
}
static PyMethodDef Buffer_methods[] = {
{"data_slice", (PyCFunction)Buffer_data_slice, METH_VARARGS, ""},
{"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""},
{"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."},
{"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."},
{"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."},
{"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."},
{"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."},
{"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."},
{"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."},
{"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."},
{"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."},
{"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."},
{"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."},
{"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."},
{"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""},
{"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""},
{NULL}
};
static PyObject*
Buffer_capacity_getter(BufferObject* self, void *closure) {
return PyLong_FromSsize_t(self->end - self->base);
}
static PyObject*
Buffer_data_getter(BufferObject* self, void *closure) {
return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base);
}
static PyGetSetDef Buffer_getset[] = {
{"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL },
{"data", (getter) Buffer_data_getter, NULL, "", NULL },
{NULL}
};
static PyType_Slot BufferType_slots[] = {
{Py_tp_dealloc, Buffer_dealloc},
{Py_tp_methods, Buffer_methods},
{Py_tp_doc, "Buffer objects"},
{Py_tp_getset, Buffer_getset},
{Py_tp_init, Buffer_init},
{0, 0},
};
static PyType_Spec BufferType_spec = {
MODULE_NAME ".Buffer",
sizeof(BufferObject),
0,
Py_TPFLAGS_DEFAULT,
BufferType_slots
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
MODULE_NAME, /* m_name */
"Serialization utilities.", /* m_doc */
-1, /* m_size */
NULL, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear */
NULL, /* m_free */
};
PyMODINIT_FUNC
PyInit__buffer(void)
{
PyObject* m;
m = PyModule_Create(&moduledef);
if (m == NULL)
return NULL;
BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL);
Py_INCREF(BufferReadError);
PyModule_AddObject(m, "BufferReadError", BufferReadError);
BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL);
Py_INCREF(BufferWriteError);
PyModule_AddObject(m, "BufferWriteError", BufferWriteError);
BufferType = PyType_FromSpec(&BufferType_spec);
if (BufferType == NULL)
return NULL;
PyModule_AddObject(m, "Buffer", BufferType);
return m;
}

View File

@ -0,0 +1,27 @@
from typing import Optional
class BufferReadError(ValueError): ...
class BufferWriteError(ValueError): ...
class Buffer:
def __init__(self, capacity: Optional[int] = 0, data: Optional[bytes] = None): ...
@property
def capacity(self) -> int: ...
@property
def data(self) -> bytes: ...
def data_slice(self, start: int, end: int) -> bytes: ...
def eof(self) -> bool: ...
def seek(self, pos: int) -> None: ...
def tell(self) -> int: ...
def pull_bytes(self, length: int) -> bytes: ...
def pull_uint8(self) -> int: ...
def pull_uint16(self) -> int: ...
def pull_uint32(self) -> int: ...
def pull_uint64(self) -> int: ...
def pull_uint_var(self) -> int: ...
def push_bytes(self, value: bytes) -> None: ...
def push_uint8(self, value: int) -> None: ...
def push_uint16(self, value: int) -> None: ...
def push_uint32(self, v: int) -> None: ...
def push_uint64(self, v: int) -> None: ...
def push_uint_var(self, value: int) -> None: ...

View File

@ -0,0 +1,416 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#define MODULE_NAME "aioquic._crypto"
#define AEAD_KEY_LENGTH_MAX 32
#define AEAD_NONCE_LENGTH 12
#define AEAD_TAG_LENGTH 16
#define PACKET_LENGTH_MAX 1500
#define PACKET_NUMBER_LENGTH_MAX 4
#define SAMPLE_LENGTH 16
#define CHECK_RESULT(expr) \
if (!(expr)) { \
ERR_clear_error(); \
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
return NULL; \
}
#define CHECK_RESULT_CTOR(expr) \
if (!(expr)) { \
ERR_clear_error(); \
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
return -1; \
}
static PyObject *CryptoError;
/* AEAD */
typedef struct {
PyObject_HEAD
EVP_CIPHER_CTX *decrypt_ctx;
EVP_CIPHER_CTX *encrypt_ctx;
unsigned char buffer[PACKET_LENGTH_MAX];
unsigned char key[AEAD_KEY_LENGTH_MAX];
unsigned char iv[AEAD_NONCE_LENGTH];
unsigned char nonce[AEAD_NONCE_LENGTH];
} AEADObject;
static PyObject *AEADType;
static EVP_CIPHER_CTX *
create_ctx(const EVP_CIPHER *cipher, int key_length, int operation)
{
EVP_CIPHER_CTX *ctx;
int res;
ctx = EVP_CIPHER_CTX_new();
CHECK_RESULT(ctx != 0);
res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation);
CHECK_RESULT(res != 0);
res = EVP_CIPHER_CTX_set_key_length(ctx, key_length);
CHECK_RESULT(res != 0);
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL);
CHECK_RESULT(res != 0);
return ctx;
}
static int
AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs)
{
const char *cipher_name;
const unsigned char *key, *iv;
Py_ssize_t cipher_name_len, key_len, iv_len;
if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len))
return -1;
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
if (evp_cipher == 0) {
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
return -1;
}
if (key_len > AEAD_KEY_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid key length");
return -1;
}
if (iv_len > AEAD_NONCE_LENGTH) {
PyErr_SetString(CryptoError, "Invalid iv length");
return -1;
}
memcpy(self->key, key, key_len);
memcpy(self->iv, iv, iv_len);
self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0);
CHECK_RESULT_CTOR(self->decrypt_ctx != 0);
self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1);
CHECK_RESULT_CTOR(self->encrypt_ctx != 0);
return 0;
}
static void
AEAD_dealloc(AEADObject *self)
{
EVP_CIPHER_CTX_free(self->decrypt_ctx);
EVP_CIPHER_CTX_free(self->encrypt_ctx);
PyTypeObject *tp = Py_TYPE(self);
freefunc free = PyType_GetSlot(tp, Py_tp_free);
free(self);
Py_DECREF(tp);
}
static PyObject*
AEAD_decrypt(AEADObject *self, PyObject *args)
{
const unsigned char *data, *associated;
Py_ssize_t data_len, associated_len;
int outlen, outlen2, res;
uint64_t pn;
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
return NULL;
if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid payload length");
return NULL;
}
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
for (int i = 0; i < 8; ++i) {
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
}
res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH)));
CHECK_RESULT(res != 0);
res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH);
CHECK_RESULT(res != 0);
res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2);
if (res == 0) {
PyErr_SetString(CryptoError, "Payload decryption failed");
return NULL;
}
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen);
}
static PyObject*
AEAD_encrypt(AEADObject *self, PyObject *args)
{
const unsigned char *data, *associated;
Py_ssize_t data_len, associated_len;
int outlen, outlen2, res;
uint64_t pn;
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
return NULL;
if (data_len > PACKET_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid payload length");
return NULL;
}
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
for (int i = 0; i < 8; ++i) {
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
}
res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len);
CHECK_RESULT(res != 0);
res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2);
CHECK_RESULT(res != 0 && outlen2 == 0);
res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen);
CHECK_RESULT(res != 0);
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH);
}
static PyMethodDef AEAD_methods[] = {
{"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""},
{"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""},
{NULL}
};
static PyType_Slot AEADType_slots[] = {
{Py_tp_dealloc, AEAD_dealloc},
{Py_tp_methods, AEAD_methods},
{Py_tp_doc, "AEAD objects"},
{Py_tp_init, AEAD_init},
{0, 0},
};
static PyType_Spec AEADType_spec = {
MODULE_NAME ".AEADType",
sizeof(AEADObject),
0,
Py_TPFLAGS_DEFAULT,
AEADType_slots
};
/* HeaderProtection */
typedef struct {
PyObject_HEAD
EVP_CIPHER_CTX *ctx;
int is_chacha20;
unsigned char buffer[PACKET_LENGTH_MAX];
unsigned char mask[31];
unsigned char zero[5];
} HeaderProtectionObject;
static PyObject *HeaderProtectionType;
static int
HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs)
{
const char *cipher_name;
const unsigned char *key;
Py_ssize_t cipher_name_len, key_len;
int res;
if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len))
return -1;
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
if (evp_cipher == 0) {
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
return -1;
}
memset(self->mask, 0, sizeof(self->mask));
memset(self->zero, 0, sizeof(self->zero));
self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0;
self->ctx = EVP_CIPHER_CTX_new();
CHECK_RESULT_CTOR(self->ctx != 0);
res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1);
CHECK_RESULT_CTOR(res != 0);
res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len);
CHECK_RESULT_CTOR(res != 0);
res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1);
CHECK_RESULT_CTOR(res != 0);
return 0;
}
static void
HeaderProtection_dealloc(HeaderProtectionObject *self)
{
EVP_CIPHER_CTX_free(self->ctx);
PyTypeObject *tp = Py_TYPE(self);
freefunc free = PyType_GetSlot(tp, Py_tp_free);
free(self);
Py_DECREF(tp);
}
static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample)
{
int outlen;
if (self->is_chacha20) {
return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) &&
EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero));
} else {
return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH);
}
}
static PyObject*
HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args)
{
const unsigned char *header, *payload;
Py_ssize_t header_len, payload_len;
int res;
if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len))
return NULL;
int pn_length = (header[0] & 0x03) + 1;
int pn_offset = header_len - pn_length;
res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length);
CHECK_RESULT(res != 0);
memcpy(self->buffer, header, header_len);
memcpy(self->buffer + header_len, payload, payload_len);
if (self->buffer[0] & 0x80) {
self->buffer[0] ^= self->mask[0] & 0x0F;
} else {
self->buffer[0] ^= self->mask[0] & 0x1F;
}
for (int i = 0; i < pn_length; ++i) {
self->buffer[pn_offset + i] ^= self->mask[1 + i];
}
return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len);
}
static PyObject*
HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args)
{
const unsigned char *packet;
Py_ssize_t packet_len;
int pn_offset, res;
if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset))
return NULL;
res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX);
CHECK_RESULT(res != 0);
memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX);
if (self->buffer[0] & 0x80) {
self->buffer[0] ^= self->mask[0] & 0x0F;
} else {
self->buffer[0] ^= self->mask[0] & 0x1F;
}
int pn_length = (self->buffer[0] & 0x03) + 1;
uint32_t pn_truncated = 0;
for (int i = 0; i < pn_length; ++i) {
self->buffer[pn_offset + i] ^= self->mask[1 + i];
pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8);
}
return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated);
}
static PyMethodDef HeaderProtection_methods[] = {
{"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""},
{"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""},
{NULL}
};
static PyType_Slot HeaderProtectionType_slots[] = {
{Py_tp_dealloc, HeaderProtection_dealloc},
{Py_tp_methods, HeaderProtection_methods},
{Py_tp_doc, "HeaderProtection objects"},
{Py_tp_init, HeaderProtection_init},
{0, 0},
};
static PyType_Spec HeaderProtectionType_spec = {
MODULE_NAME ".HeaderProtectionType",
sizeof(HeaderProtectionObject),
0,
Py_TPFLAGS_DEFAULT,
HeaderProtectionType_slots
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
MODULE_NAME, /* m_name */
"Cryptography utilities.", /* m_doc */
-1, /* m_size */
NULL, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear */
NULL, /* m_free */
};
PyMODINIT_FUNC
PyInit__crypto(void)
{
PyObject* m;
m = PyModule_Create(&moduledef);
if (m == NULL)
return NULL;
CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL);
Py_INCREF(CryptoError);
PyModule_AddObject(m, "CryptoError", CryptoError);
AEADType = PyType_FromSpec(&AEADType_spec);
if (AEADType == NULL)
return NULL;
PyModule_AddObject(m, "AEAD", AEADType);
HeaderProtectionType = PyType_FromSpec(&HeaderProtectionType_spec);
if (HeaderProtectionType == NULL)
return NULL;
PyModule_AddObject(m, "HeaderProtection", HeaderProtectionType);
// ensure required ciphers are initialised
EVP_add_cipher(EVP_aes_128_ecb());
EVP_add_cipher(EVP_aes_128_gcm());
EVP_add_cipher(EVP_aes_256_ecb());
EVP_add_cipher(EVP_aes_256_gcm());
return m;
}

View File

@ -0,0 +1,17 @@
from typing import Tuple
class AEAD:
def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ...
def decrypt(
self, data: bytes, associated_data: bytes, packet_number: int
) -> bytes: ...
def encrypt(
self, data: bytes, associated_data: bytes, packet_number: int
) -> bytes: ...
class CryptoError(ValueError): ...
class HeaderProtection:
def __init__(self, cipher_name: bytes, key: bytes): ...
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ...

View File

@ -0,0 +1,3 @@
from .client import connect # noqa
from .protocol import QuicConnectionProtocol # noqa
from .server import serve # noqa

View File

@ -0,0 +1,98 @@
import asyncio
import socket
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable, Optional, cast
from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection, QuicTokenHandler
from ..tls import SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["connect"]
@asynccontextmanager
async def connect(
host: str,
port: int,
*,
configuration: Optional[QuicConfiguration] = None,
create_protocol: Optional[Callable] = QuicConnectionProtocol,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
"""
Connect to a QUIC server at the given `host` and `port`.
:meth:`connect()` returns an awaitable. Awaiting it yields a
:class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
create streams.
:func:`connect` also accepts the following optional arguments:
* ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
configuration object.
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is received.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
* ``wait_connected`` indicates whether the context manager should wait for the
connection to be established before yielding the
:class:`~aioquic.asyncio.QuicConnectionProtocol`. By default this is `True` but
you can set it to `False` if you want to immediately start sending data using
0-RTT.
* ``local_port`` is the UDP port number that this client wants to bind.
"""
loop = asyncio.get_event_loop()
local_host = "::"
# lookup remote address
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
addr = infos[0][4]
if len(addr) == 2:
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
# prepare QUIC connection
if configuration is None:
configuration = QuicConfiguration(is_client=True)
if configuration.server_name is None:
configuration.server_name = host
connection = QuicConnection(
configuration=configuration,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
# explicitly enable IPv4/IPv6 dual stack
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
completed = False
try:
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
sock.bind((local_host, local_port, 0, 0))
completed = True
finally:
if not completed:
sock.close()
# connect
transport, protocol = await loop.create_datagram_endpoint(
lambda: create_protocol(connection, stream_handler=stream_handler),
sock=sock,
)
protocol = cast(QuicConnectionProtocol, protocol)
try:
protocol.connect(addr, transmit=wait_connected)
if wait_connected:
await protocol.wait_connected()
yield protocol
finally:
protocol.close()
await protocol.wait_closed()
transport.close()

View File

@ -0,0 +1,272 @@
import asyncio
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import QuicErrorCode
QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
):
loop = asyncio.get_event_loop()
self._closed = asyncio.Event()
self._connected = False
self._connected_waiter: Optional[asyncio.Future[None]] = None
self._loop = loop
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
self._quic = quic
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
self._timer: Optional[asyncio.TimerHandle] = None
self._timer_at: Optional[float] = None
self._transmit_task: Optional[asyncio.Handle] = None
self._transport: Optional[asyncio.DatagramTransport] = None
# callbacks
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
self._connection_terminated_handler: Callable[[], None] = lambda: None
if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None
def change_connection_id(self) -> None:
"""
Change the connection ID used to communicate with the peer.
The previous connection ID will be retired.
"""
self._quic.change_connection_id()
self.transmit()
def close(
self,
error_code: int = QuicErrorCode.NO_ERROR,
reason_phrase: str = "",
) -> None:
"""
Close the connection.
:param error_code: An error code indicating why the connection is
being closed.
:param reason_phrase: A human-readable explanation of why the
connection is being closed.
"""
self._quic.close(
error_code=error_code,
reason_phrase=reason_phrase,
)
self.transmit()
def connect(self, addr: NetworkAddress, transmit=True) -> None:
"""
Initiate the TLS handshake.
This method can only be called for clients and a single time.
"""
self._quic.connect(addr, now=self._loop.time())
if transmit:
self.transmit()
async def create_stream(
self, is_unidirectional: bool = False
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
"""
stream_id = self._quic.get_next_available_stream_id(
is_unidirectional=is_unidirectional
)
return self._create_stream(stream_id)
def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._quic.request_key_update()
self.transmit()
async def ping(self) -> None:
"""
Ping the peer and wait for the response.
"""
waiter = self._loop.create_future()
uid = id(waiter)
self._ping_waiters[uid] = waiter
self._quic.send_ping(uid)
self.transmit()
await asyncio.shield(waiter)
def transmit(self) -> None:
"""
Send pending datagrams to the peer and arm the timer if needed.
This method is called automatically when data is received from the peer
or when a timer goes off. If you interact directly with the underlying
:class:`~aioquic.quic.connection.QuicConnection`, make sure you call this
method whenever data needs to be sent out to the network.
"""
self._transmit_task = None
# send datagrams
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
self._transport.sendto(data, addr)
# re-arm timer
timer_at = self._quic.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at
async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()
async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
assert self._connected_waiter is None, "already awaiting connected"
if not self._connected:
self._connected_waiter = self._loop.create_future()
await asyncio.shield(self._connected_waiter)
# asyncio.Transport
def connection_made(self, transport: asyncio.BaseTransport) -> None:
""":meta private:"""
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
""":meta private:"""
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
self._process_events()
self.transmit()
# overridable
def quic_event_received(self, event: events.QuicEvent) -> None:
"""
Called when a QUIC event is received.
Reimplement this in your subclass to handle the events.
"""
# FIXME: move this to a subclass
if isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()
# private
def _create_stream(
self, stream_id: int
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
protocol = asyncio.streams.StreamReaderProtocol(reader)
writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop)
self._stream_readers[stream_id] = reader
return reader, writer
def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())
self._timer = None
self._timer_at = None
self._quic.handle_timer(now=now)
self._process_events()
self.transmit()
def _process_events(self) -> None:
event = self._quic.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
self._connection_terminated_handler()
# abort connection waiter
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected_waiter = None
waiter.set_exception(ConnectionError)
# abort ping waiters
for waiter in self._ping_waiters.values():
waiter.set_exception(ConnectionError)
self._ping_waiters.clear()
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected = True
self._connected_waiter = None
waiter.set_result(None)
elif isinstance(event, events.PingAcknowledged):
waiter = self._ping_waiters.pop(event.uid, None)
if waiter is not None:
waiter.set_result(None)
self.quic_event_received(event)
event = self._quic.next_event()
def _transmit_soon(self) -> None:
if self._transmit_task is None:
self._transmit_task = self._loop.call_soon(self.transmit)
class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
self.protocol = protocol
self.stream_id = stream_id
self._closing = False
def can_write_eof(self) -> bool:
return True
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "stream_id":
return self.stream_id
def write(self, data):
self.protocol._quic.send_stream_data(self.stream_id, data)
self.protocol._transmit_soon()
def write_eof(self):
if self._closing:
return
self._closing = True
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._transmit_soon()
def close(self):
self.write_eof()
def is_closing(self) -> bool:
return self._closing

View File

@ -0,0 +1,215 @@
import asyncio
import os
from functools import partial
from typing import Callable, Dict, Optional, Text, Union, cast
from ..buffer import Buffer
from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
QuicPacketType,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from ..quic.retry import QuicRetryTokenHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["serve"]
class QuicServer(asyncio.DatagramProtocol):
def __init__(
self,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
retry: bool = False,
stream_handler: Optional[QuicStreamHandler] = None,
) -> None:
self._configuration = configuration
self._create_protocol = create_protocol
self._loop = asyncio.get_event_loop()
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: Optional[asyncio.DatagramTransport] = None
self._stream_handler = stream_handler
if retry:
self._retry = QuicRetryTokenHandler()
else:
self._retry = None
def close(self):
for protocol in set(self._protocols.values()):
protocol.close()
self._protocols.clear()
self._transport.close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
data = cast(bytes, data)
buf = Buffer(data=data)
try:
header = pull_quic_header(
buf, host_cid_length=self._configuration.connection_id_length
)
except ValueError:
return
# version negotiation
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
self._transport.sendto(
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self._configuration.supported_versions,
),
addr,
)
return
protocol = self._protocols.get(header.destination_cid, None)
original_destination_connection_id: Optional[bytes] = None
retry_source_connection_id: Optional[bytes] = None
if (
protocol is None
and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE
and header.packet_type == QuicPacketType.INITIAL
):
# retry
if self._retry is not None:
if not header.token:
# create a retry token
source_cid = os.urandom(8)
self._transport.sendto(
encode_quic_retry(
version=header.version,
source_cid=source_cid,
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self._retry.create_token(
addr, header.destination_cid, source_cid
),
),
addr,
)
return
else:
# validate retry token
try:
(
original_destination_connection_id,
retry_source_connection_id,
) = self._retry.validate_token(addr, header.token)
except ValueError:
return
else:
original_destination_connection_id = header.destination_cid
# create new connection
connection = QuicConnection(
configuration=self._configuration,
original_destination_connection_id=original_destination_connection_id,
retry_source_connection_id=retry_source_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
)
protocol = self._create_protocol(
connection, stream_handler=self._stream_handler
)
protocol.connection_made(self._transport)
# register callbacks
protocol._connection_id_issued_handler = partial(
self._connection_id_issued, protocol=protocol
)
protocol._connection_id_retired_handler = partial(
self._connection_id_retired, protocol=protocol
)
protocol._connection_terminated_handler = partial(
self._connection_terminated, protocol=protocol
)
self._protocols[header.destination_cid] = protocol
self._protocols[connection.host_cid] = protocol
if protocol is not None:
protocol.datagram_received(data, addr)
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
self._protocols[cid] = protocol
def _connection_id_retired(
self, cid: bytes, protocol: QuicConnectionProtocol
) -> None:
assert self._protocols[cid] == protocol
del self._protocols[cid]
def _connection_terminated(self, protocol: QuicConnectionProtocol):
for cid, proto in list(self._protocols.items()):
if proto == protocol:
del self._protocols[cid]
async def serve(
host: str,
port: int,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
retry: bool = False,
stream_handler: QuicStreamHandler = None,
) -> QuicServer:
"""
Start a QUIC server at the given `host` and `port`.
:func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
containing TLS certificate and private key as the ``configuration`` argument.
:func:`serve` also accepts the following optional arguments:
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
engine when a session ticket is presented by the peer. It should return
the session ticket with the specified ID or `None` if it is not found.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is issued. It should store the session
ticket for future lookup.
* ``retry`` specifies whether client addresses should be validated prior to
the cryptographic handshake using a retry packet.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
"""
loop = asyncio.get_event_loop()
_, protocol = await loop.create_datagram_endpoint(
lambda: QuicServer(
configuration=configuration,
create_protocol=create_protocol,
session_ticket_fetcher=session_ticket_fetcher,
session_ticket_handler=session_ticket_handler,
retry=retry,
stream_handler=stream_handler,
),
local_addr=(host, port),
)
return protocol

View File

@ -0,0 +1,30 @@
from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
UINT_VAR_MAX_SIZE = 8
def encode_uint_var(value: int) -> bytes:
"""
Encode a variable-length unsigned integer.
"""
buf = Buffer(capacity=UINT_VAR_MAX_SIZE)
buf.push_uint_var(value)
return buf.data
def size_uint_var(value: int) -> int:
"""
Return the number of bytes required to encode the given value
as a QUIC variable-length unsigned integer.
"""
if value <= 0x3F:
return 1
elif value <= 0x3FFF:
return 2
elif value <= 0x3FFFFFFF:
return 4
elif value <= 0x3FFFFFFFFFFFFFFF:
return 8
else:
raise ValueError("Integer is too big for a variable-length integer")

View File

@ -0,0 +1,68 @@
from typing import Dict, List
from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent, StreamDataReceived
H0_ALPN = ["hq-interop"]
class H0Connection:
"""
An HTTP/0.9 connection object.
"""
def __init__(self, quic: QuicConnection):
self._buffer: Dict[int, bytes] = {}
self._headers_received: Dict[int, bool] = {}
self._is_client = quic.configuration.is_client
self._quic = quic
def handle_event(self, event: QuicEvent) -> List[H3Event]:
http_events: List[H3Event] = []
if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0:
data = self._buffer.pop(event.stream_id, b"") + event.data
if not self._headers_received.get(event.stream_id, False):
if self._is_client:
http_events.append(
HeadersReceived(
headers=[], stream_ended=False, stream_id=event.stream_id
)
)
elif data.endswith(b"\r\n") or event.end_stream:
method, path = data.rstrip().split(b" ", 1)
http_events.append(
HeadersReceived(
headers=[(b":method", method), (b":path", path)],
stream_ended=False,
stream_id=event.stream_id,
)
)
data = b""
else:
# incomplete request, stash the data
self._buffer[event.stream_id] = data
return http_events
self._headers_received[event.stream_id] = True
http_events.append(
DataReceived(
data=data, stream_ended=event.end_stream, stream_id=event.stream_id
)
)
return http_events
def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None:
self._quic.send_stream_data(stream_id, data, end_stream)
def send_headers(
self, stream_id: int, headers: Headers, end_stream: bool = False
) -> None:
if self._is_client:
headers_dict = dict(headers)
data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n"
else:
data = b""
self._quic.send_stream_data(stream_id, data, end_stream)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,100 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
Headers = List[Tuple[bytes, bytes]]
class H3Event:
"""
Base class for HTTP/3 events.
"""
@dataclass
class DataReceived(H3Event):
"""
The DataReceived event is fired whenever data is received on a stream from
the remote peer.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: Optional[int] = None
"The Push ID or `None` if this is not a push."
@dataclass
class DatagramReceived(H3Event):
"""
The DatagramReceived is fired whenever a datagram is received from the
the remote peer.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
@dataclass
class HeadersReceived(H3Event):
"""
The HeadersReceived event is fired whenever headers are received.
"""
headers: Headers
"The headers."
stream_id: int
"The ID of the stream the headers were received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: Optional[int] = None
"The Push ID or `None` if this is not a push."
@dataclass
class PushPromiseReceived(H3Event):
"""
The PushedStreamReceived event is fired whenever a pushed stream has been
received from the remote peer.
"""
headers: Headers
"The request headers."
push_id: int
"The Push ID of the push promise."
stream_id: int
"The Stream ID of the stream that the push is related to."
@dataclass
class WebTransportStreamDataReceived(H3Event):
"""
The WebTransportStreamDataReceived is fired whenever data is received
for a WebTransport stream.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
session_id: int
"The ID of the session the data was received for."

View File

@ -0,0 +1,17 @@
class H3Error(Exception):
"""
Base class for HTTP/3 exceptions.
"""
class InvalidStreamTypeError(H3Error):
"""
An action was attempted on an invalid stream type.
"""
class NoAvailablePushIDError(H3Error):
"""
There are no available push IDs left, or push is not supported
by the remote party.
"""

View File

@ -0,0 +1 @@
Marker

View File

@ -0,0 +1,163 @@
from dataclasses import dataclass, field
from os import PathLike
from re import split
from typing import Any, List, Optional, TextIO, Union
from ..tls import (
CipherSuite,
SessionTicket,
load_pem_private_key,
load_pem_x509_certificates,
)
from .logger import QuicLogger
from .packet import QuicProtocolVersion
SMALLEST_MAX_DATAGRAM_SIZE = 1200
@dataclass
class QuicConfiguration:
"""
A QUIC configuration.
"""
alpn_protocols: Optional[List[str]] = None
"""
A list of supported ALPN protocols.
"""
congestion_control_algorithm: str = "reno"
"""
The name of the congestion control algorithm to use.
Currently supported algorithms: `"reno", `"cubic"`.
"""
connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
"""
idle_timeout: float = 60.0
"""
The idle timeout in seconds.
The connection is terminated if nothing is received for the given duration.
"""
is_client: bool = True
"""
Whether this is the client side of the QUIC connection.
"""
max_data: int = 1048576
"""
Connection-wide flow control limit.
"""
max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE
"""
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
"""
max_stream_data: int = 1048576
"""
Per-stream flow control limit.
"""
quic_logger: Optional[QuicLogger] = None
"""
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
"""
secrets_log_file: TextIO = None
"""
A file-like object in which to log traffic secrets.
This is useful to analyze traffic captures with Wireshark.
"""
server_name: Optional[str] = None
"""
The server name to use when verifying the server's TLS certificate, which
can either be a DNS name or an IP address.
If it is a DNS name, it is also sent during the TLS handshake in the
Server Name Indication (SNI) extension.
.. note:: This is only used by clients.
"""
session_ticket: Optional[SessionTicket] = None
"""
The TLS session ticket which should be used for session resumption.
"""
token: bytes = b""
"""
The address validation token that can be used to validate future connections.
.. note:: This is only used by clients.
"""
# For internal purposes, not guaranteed to be stable.
cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
certificate: Any = None
certificate_chain: List[Any] = field(default_factory=list)
cipher_suites: Optional[List[CipherSuite]] = None
initial_rtt: float = 0.1
max_datagram_frame_size: Optional[int] = None
original_version: Optional[int] = None
private_key: Any = None
quantum_readiness_test: bool = False
supported_versions: List[int] = field(
default_factory=lambda: [
QuicProtocolVersion.VERSION_1,
QuicProtocolVersion.VERSION_2,
]
)
verify_mode: Optional[int] = None
def load_cert_chain(
self,
certfile: PathLike,
keyfile: Optional[PathLike] = None,
password: Optional[Union[bytes, str]] = None,
) -> None:
"""
Load a private key and the corresponding certificate.
"""
with open(certfile, "rb") as fp:
boundary = b"-----BEGIN PRIVATE KEY-----\n"
chunks = split(b"\n" + boundary, fp.read())
certificates = load_pem_x509_certificates(chunks[0])
if len(chunks) == 2:
private_key = boundary + chunks[1]
self.private_key = load_pem_private_key(private_key)
self.certificate = certificates[0]
self.certificate_chain = certificates[1:]
if keyfile is not None:
with open(keyfile, "rb") as fp:
self.private_key = load_pem_private_key(
fp.read(),
password=password.encode("utf8")
if isinstance(password, str)
else password,
)
def load_verify_locations(
self,
cafile: Optional[str] = None,
capath: Optional[str] = None,
cadata: Optional[bytes] = None,
) -> None:
"""
Load a set of "certification authority" (CA) certificates used to
validate other peers' certificates.
"""
self.cafile = cafile
self.capath = capath
self.cadata = cadata

View File

@ -0,0 +1,128 @@
import abc
from typing import Any, Dict, Iterable, Optional, Protocol
from ..packet_builder import QuicSentPacket
K_GRANULARITY = 0.001 # seconds
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
class QuicCongestionControl(abc.ABC):
"""
Base class for congestion control implementations.
"""
bytes_in_flight: int = 0
congestion_window: int = 0
ssthresh: Optional[int] = None
def __init__(self, *, max_datagram_size: int) -> None:
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size
@abc.abstractmethod
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ...
@abc.abstractmethod
def on_packet_sent(self, *, packet: QuicSentPacket) -> None: ...
@abc.abstractmethod
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: ...
@abc.abstractmethod
def on_packets_lost(
self, *, now: float, packets: Iterable[QuicSentPacket]
) -> None: ...
@abc.abstractmethod
def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ...
def get_log_data(self) -> Dict[str, Any]:
data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
if self.ssthresh is not None:
data["ssthresh"] = self.ssthresh
return data
class QuicCongestionControlFactory(Protocol):
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""
def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5
self._filtered_min: Optional[float] = None
self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
self._samples = [0.0 for i in range(self._size)]
def add_rtt(self, *, rtt: float) -> None:
self._samples[self._sample_idx] = rtt
self._sample_idx += 1
if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True
if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample
def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt=rtt)
self._sample_time = now
if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max
delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False
_factories: Dict[str, QuicCongestionControlFactory] = {}
def create_congestion_control(
name: str, *, max_datagram_size: int
) -> QuicCongestionControl:
"""
Create an instance of the `name` congestion control algorithm.
"""
try:
factory = _factories[name]
except KeyError:
raise Exception(f"Unknown congestion control algorithm: {name}")
return factory(max_datagram_size=max_datagram_size)
def register_congestion_control(
name: str, factory: QuicCongestionControlFactory
) -> None:
"""
Register a congestion control algorithm named `name`.
"""
_factories[name] = factory

View File

@ -0,0 +1,212 @@
from typing import Any, Dict, Iterable
from ..packet_builder import QuicSentPacket
from .base import (
K_INITIAL_WINDOW,
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
def better_cube_root(x: float) -> float:
if x < 0:
# avoid precision errors that make the cube root returns an imaginary number
return -((-x) ** (1.0 / 3.0))
else:
return (x) ** (1.0 / 3.0)
class CubicCongestionControl(QuicCongestionControl):
"""
Cubic congestion control implementation for aioquic
"""
def __init__(self, max_datagram_size: int) -> None:
super().__init__(max_datagram_size=max_datagram_size)
# increase by one segment
self.additive_increase_factor: int = max_datagram_size
self._max_datagram_size: int = max_datagram_size
self._congestion_recovery_start_time = 0.0
self._rtt_monitor = QuicRttMonitor()
self.rtt = 0.02 # starting RTT is considered to be 20ms
self.reset()
self.last_ack = 0.0
def W_cubic(self, t) -> int:
W_max_segments = self._W_max / self._max_datagram_size
target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
return int(target_segments * self._max_datagram_size)
def is_reno_friendly(self, t) -> bool:
return self.W_cubic(t) < self._W_est
def is_concave(self) -> bool:
return self.congestion_window < self._W_max
def reset(self) -> None:
self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
self.ssthresh = None
self._first_slow_start = True
self._starting_congestion_avoidance = False
self.K: float = 0.0
self._W_est = 0
self._cwnd_epoch = 0
self._t_epoch = 0.0
self._W_max = self.congestion_window
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
self.last_ack = packet.sent_time
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
if self._first_slow_start and not self._starting_congestion_avoidance:
# exiting slow start without having a loss
self._first_slow_start = False
self._W_max = self.congestion_window
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
# calculate K
W_max_segments = self._W_max / self._max_datagram_size
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
self.K = better_cube_root(
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
)
# initialize the variables used at start of congestion avoidance
if self._starting_congestion_avoidance:
self._starting_congestion_avoidance = False
self._first_slow_start = False
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
# calculate K
W_max_segments = self._W_max / self._max_datagram_size
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
self.K = better_cube_root(
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
)
self._W_est = int(
self._W_est
+ self.additive_increase_factor
* (packet.sent_bytes / self.congestion_window)
)
t = now - self._t_epoch
target: int = 0
W_cubic = self.W_cubic(t + self.rtt)
if W_cubic < self.congestion_window:
target = self.congestion_window
elif W_cubic > 1.5 * self.congestion_window:
target = int(self.congestion_window * 1.5)
else:
target = W_cubic
if self.is_reno_friendly(t):
# reno friendly region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
self.congestion_window = self._W_est
elif self.is_concave():
# concave region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
self.congestion_window = int(
self.congestion_window
+ (
(target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
)
else:
# convex region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
self.congestion_window = int(
self.congestion_window
+ (
(target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
)
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
if self.last_ack == 0.0:
return
elapsed_idle = packet.sent_time - self.last_ack
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
self.reset()
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
# Normal congestion handle, can't be used in same time as fast convergence
# self._W_max = self.congestion_window
# fast convergence
if self._W_max is not None and self.congestion_window < self._W_max:
self._W_max = int(
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
)
else:
self._W_max = self.congestion_window
# normal congestion MD
flight_size = self.bytes_in_flight
new_ssthresh = max(
int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
K_MINIMUM_WINDOW * self._max_datagram_size,
)
self.ssthresh = new_ssthresh
self.congestion_window = max(
self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
)
# restart a new congestion avoidance phase
self._starting_congestion_avoidance = True
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
self.rtt = rtt
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
rtt=rtt, now=now
):
self.ssthresh = self.congestion_window
def get_log_data(self) -> Dict[str, Any]:
data = super().get_log_data()
data["cubic-wmax"] = int(self._W_max)
return data
register_congestion_control("cubic", CubicCongestionControl)

View File

@ -0,0 +1,77 @@
from typing import Iterable
from ..packet_builder import QuicSentPacket
from .base import (
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)
K_LOSS_REDUCTION_FACTOR = 0.5
class RenoCongestionControl(QuicCongestionControl):
"""
New Reno congestion control.
"""
def __init__(self, *, max_datagram_size: int) -> None:
super().__init__(max_datagram_size=max_datagram_size)
self._max_datagram_size = max_datagram_size
self._congestion_recovery_start_time = 0.0
self._congestion_stash = 0
self._rtt_monitor = QuicRttMonitor()
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
# don't increase window in congestion recovery
if packet.sent_time <= self._congestion_recovery_start_time:
return
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
self._congestion_stash += packet.sent_bytes
count = self._congestion_stash // self.congestion_window
if count:
self._congestion_stash -= count * self.congestion_window
self.congestion_window += count * self._max_datagram_size
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
self.congestion_window = max(
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
K_MINIMUM_WINDOW * self._max_datagram_size,
)
self.ssthresh = self.congestion_window
# TODO : collapse congestion window if persistent congestion
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
now=now, rtt=rtt
):
self.ssthresh = self.congestion_window
register_congestion_control("reno", RenoCongestionControl)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,246 @@
import binascii
from typing import Callable, Optional, Tuple
from .._crypto import AEAD, CryptoError, HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import (
QuicProtocolVersion,
decode_packet_number,
is_long_header,
)
CIPHER_SUITES = {
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
}
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
SAMPLE_SIZE = 16
Callback = Callable[[str], None]
def NoCallback(trigger: str) -> None:
pass
class KeyUnavailableError(CryptoError):
pass
def derive_key_iv_hp(
*, cipher_suite: CipherSuite, secret: bytes, version: int
) -> Tuple[bytes, bytes, bytes]:
algorithm = cipher_suite_hash(cipher_suite)
if cipher_suite in [
CipherSuite.AES_256_GCM_SHA384,
CipherSuite.CHACHA20_POLY1305_SHA256,
]:
key_size = 32
else:
key_size = 16
if version == QuicProtocolVersion.VERSION_2:
return (
hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size),
)
else:
return (
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
)
class CryptoContext:
def __init__(
self,
key_phase: int = 0,
setup_cb: Callback = NoCallback,
teardown_cb: Callback = NoCallback,
) -> None:
self.aead: Optional[AEAD] = None
self.cipher_suite: Optional[CipherSuite] = None
self.hp: Optional[HeaderProtection] = None
self.key_phase = key_phase
self.secret: Optional[bytes] = None
self.version: Optional[int] = None
self._setup_cb = setup_cb
self._teardown_cb = teardown_cb
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> Tuple[bytes, bytes, int, bool]:
if self.aead is None:
raise KeyUnavailableError("Decryption key is not available")
# header protection
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
first_byte = plain_header[0]
# packet number
pn_length = (first_byte & 0x03) + 1
packet_number = decode_packet_number(
packet_number, pn_length * 8, expected_packet_number
)
# detect key phase change
crypto = self
if not is_long_header(first_byte):
key_phase = (first_byte & 4) >> 2
if key_phase != self.key_phase:
crypto = next_key_phase(self)
# payload protection
payload = crypto.aead.decrypt(
packet[len(plain_header) :], plain_header, packet_number
)
return plain_header, payload, packet_number, crypto != self
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
assert self.is_valid(), "Encryption key is not available"
# payload protection
protected_payload = self.aead.encrypt(
plain_payload, plain_header, packet_number
)
# header protection
return self.hp.apply(plain_header, protected_payload)
def is_valid(self) -> bool:
return self.aead is not None
def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
key, iv, hp = derive_key_iv_hp(
cipher_suite=cipher_suite,
secret=secret,
version=version,
)
self.aead = AEAD(aead_cipher_name, key, iv)
self.cipher_suite = cipher_suite
self.hp = HeaderProtection(hp_cipher_name, hp)
self.secret = secret
self.version = version
# trigger callback
self._setup_cb("tls")
def teardown(self) -> None:
self.aead = None
self.cipher_suite = None
self.hp = None
self.secret = None
# trigger callback
self._teardown_cb("tls")
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
self.aead = crypto.aead
self.key_phase = crypto.key_phase
self.secret = crypto.secret
# trigger callback
self._setup_cb(trigger)
def next_key_phase(self: CryptoContext) -> CryptoContext:
algorithm = cipher_suite_hash(self.cipher_suite)
crypto = CryptoContext(key_phase=int(not self.key_phase))
crypto.setup(
cipher_suite=self.cipher_suite,
secret=hkdf_expand_label(
algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
),
version=self.version,
)
return crypto
class CryptoPair:
def __init__(
self,
recv_setup_cb: Callback = NoCallback,
recv_teardown_cb: Callback = NoCallback,
send_setup_cb: Callback = NoCallback,
send_teardown_cb: Callback = NoCallback,
) -> None:
self.aead_tag_size = 16
self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
self._update_key_requested = False
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> Tuple[bytes, bytes, int]:
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
packet, encrypted_offset, expected_packet_number
)
if update_key:
self._update_key("remote_update")
return plain_header, payload, packet_number
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
if self._update_key_requested:
self._update_key("local_update")
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
if is_client:
recv_label, send_label = b"server in", b"client in"
else:
recv_label, send_label = b"client in", b"server in"
if version == QuicProtocolVersion.VERSION_2:
initial_salt = INITIAL_SALT_VERSION_2
else:
initial_salt = INITIAL_SALT_VERSION_1
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
initial_secret = hkdf_extract(algorithm, initial_salt, cid)
self.recv.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, recv_label, b"", algorithm.digest_size
),
version=version,
)
self.send.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, send_label, b"", algorithm.digest_size
),
version=version,
)
def teardown(self) -> None:
self.recv.teardown()
self.send.teardown()
def update_key(self) -> None:
self._update_key_requested = True
@property
def key_phase(self) -> int:
if self._update_key_requested:
return int(not self.recv.key_phase)
else:
return self.recv.key_phase
def _update_key(self, trigger: str) -> None:
apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
self._update_key_requested = False

View File

@ -0,0 +1,126 @@
from dataclasses import dataclass
from typing import Optional
class QuicEvent:
"""
Base class for QUIC events.
"""
pass
@dataclass
class ConnectionIdIssued(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionIdRetired(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionTerminated(QuicEvent):
"""
The ConnectionTerminated event is fired when the QUIC connection is terminated.
"""
error_code: int
"The error code which was specified when closing the connection."
frame_type: Optional[int]
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
@dataclass
class DatagramFrameReceived(QuicEvent):
"""
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
"""
data: bytes
"The data which was received."
@dataclass
class HandshakeCompleted(QuicEvent):
"""
The HandshakeCompleted event is fired when the TLS handshake completes.
"""
alpn_protocol: Optional[str]
"The protocol which was negotiated using ALPN, or `None`."
early_data_accepted: bool
"Whether early (0-RTT) data was accepted by the remote peer."
session_resumed: bool
"Whether a TLS session was resumed."
@dataclass
class PingAcknowledged(QuicEvent):
"""
The PingAcknowledged event is fired when a PING frame is acknowledged.
"""
uid: int
"The unique ID of the PING."
@dataclass
class ProtocolNegotiated(QuicEvent):
"""
The ProtocolNegotiated event is fired when ALPN negotiation completes.
"""
alpn_protocol: Optional[str]
"The protocol which was negotiated using ALPN, or `None`."
@dataclass
class StopSendingReceived(QuicEvent):
"""
The StopSendingReceived event is fired when the remote peer requests
stopping data transmission on a stream.
"""
error_code: int
"The error code that was sent from the peer."
stream_id: int
"The ID of the stream that the peer requested stopping data transmission."
@dataclass
class StreamDataReceived(QuicEvent):
"""
The StreamDataReceived event is fired whenever data is received on a
stream.
"""
data: bytes
"The data which was received."
end_stream: bool
"Whether the STREAM frame had the FIN bit set."
stream_id: int
"The ID of the stream the data was received for."
@dataclass
class StreamReset(QuicEvent):
"""
The StreamReset event is fired when the remote peer resets a stream.
"""
error_code: int
"The error code that triggered the reset."
stream_id: int
"The ID of the stream that was reset."

View File

@ -0,0 +1,329 @@
import binascii
import json
import os
import time
from collections import deque
from typing import Any, Deque, Dict, List, Optional
from ..h3.events import Headers
from .packet import (
QuicFrameType,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
)
from .rangeset import RangeSet
PACKET_TYPE_NAMES = {
QuicPacketType.INITIAL: "initial",
QuicPacketType.HANDSHAKE: "handshake",
QuicPacketType.ZERO_RTT: "0RTT",
QuicPacketType.ONE_RTT: "1RTT",
QuicPacketType.RETRY: "retry",
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
QLOG_VERSION = "0.3"
def hexdump(data: bytes) -> str:
return binascii.hexlify(data).decode("ascii")
class QuicLoggerTrace:
"""
A QUIC event trace.
Events are logged in the format defined by qlog.
See:
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
"""
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
self._odcid = odcid
self._events: Deque[Dict[str, Any]] = deque()
self._vantage_point = {
"name": "aioquic",
"type": "client" if is_client else "server",
}
# QUIC
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
return {
"ack_delay": self.encode_time(delay),
"acked_ranges": [[x.start, x.stop - 1] for x in ranges],
"frame_type": "ack",
}
def encode_connection_close_frame(
self, error_code: int, frame_type: Optional[int], reason_phrase: str
) -> Dict:
attrs = {
"error_code": error_code,
"error_space": "application" if frame_type is None else "transport",
"frame_type": "connection_close",
"raw_error_code": error_code,
"reason": reason_phrase,
}
if frame_type is not None:
attrs["trigger_frame_type"] = frame_type
return attrs
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict:
if frame_type == QuicFrameType.MAX_DATA:
return {"frame_type": "max_data", "maximum": maximum}
else:
return {
"frame_type": "max_streams",
"maximum": maximum,
"stream_type": "unidirectional"
if frame_type == QuicFrameType.MAX_STREAMS_UNI
else "bidirectional",
}
def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
return {
"frame_type": "crypto",
"length": len(frame.data),
"offset": frame.offset,
}
def encode_data_blocked_frame(self, limit: int) -> Dict:
return {"frame_type": "data_blocked", "limit": limit}
def encode_datagram_frame(self, length: int) -> Dict:
return {"frame_type": "datagram", "length": length}
def encode_handshake_done_frame(self) -> Dict:
return {"frame_type": "handshake_done"}
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
return {
"frame_type": "max_stream_data",
"maximum": maximum,
"stream_id": stream_id,
}
def encode_new_connection_id_frame(
self,
connection_id: bytes,
retire_prior_to: int,
sequence_number: int,
stateless_reset_token: bytes,
) -> Dict:
return {
"connection_id": hexdump(connection_id),
"frame_type": "new_connection_id",
"length": len(connection_id),
"reset_token": hexdump(stateless_reset_token),
"retire_prior_to": retire_prior_to,
"sequence_number": sequence_number,
}
def encode_new_token_frame(self, token: bytes) -> Dict:
return {
"frame_type": "new_token",
"length": len(token),
"token": hexdump(token),
}
def encode_padding_frame(self) -> Dict:
return {"frame_type": "padding"}
def encode_path_challenge_frame(self, data: bytes) -> Dict:
return {"data": hexdump(data), "frame_type": "path_challenge"}
def encode_path_response_frame(self, data: bytes) -> Dict:
return {"data": hexdump(data), "frame_type": "path_response"}
def encode_ping_frame(self) -> Dict:
return {"frame_type": "ping"}
def encode_reset_stream_frame(
self, error_code: int, final_size: int, stream_id: int
) -> Dict:
return {
"error_code": error_code,
"final_size": final_size,
"frame_type": "reset_stream",
"stream_id": stream_id,
}
def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
return {
"frame_type": "retire_connection_id",
"sequence_number": sequence_number,
}
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
return {
"frame_type": "stream_data_blocked",
"limit": limit,
"stream_id": stream_id,
}
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
return {
"frame_type": "stop_sending",
"error_code": error_code,
"stream_id": stream_id,
}
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
return {
"fin": frame.fin,
"frame_type": "stream",
"length": len(frame.data),
"offset": frame.offset,
"stream_id": stream_id,
}
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
return {
"frame_type": "streams_blocked",
"limit": limit,
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
}
def encode_time(self, seconds: float) -> float:
"""
Convert a time to milliseconds.
"""
return seconds * 1000
def encode_transport_parameters(
self, owner: str, parameters: QuicTransportParameters
) -> Dict[str, Any]:
data: Dict[str, Any] = {"owner": owner}
for param_name, param_value in parameters.__dict__.items():
if isinstance(param_value, bool):
data[param_name] = param_value
elif isinstance(param_value, bytes):
data[param_name] = hexdump(param_value)
elif isinstance(param_value, int):
data[param_name] = param_value
return data
def packet_type(self, packet_type: QuicPacketType) -> str:
return PACKET_TYPE_NAMES[packet_type]
# HTTP/3
def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict:
return {
"frame": {"frame_type": "data"},
"length": length,
"stream_id": stream_id,
}
def encode_http3_headers_frame(
self, length: int, headers: Headers, stream_id: int
) -> Dict:
return {
"frame": {
"frame_type": "headers",
"headers": self._encode_http3_headers(headers),
},
"length": length,
"stream_id": stream_id,
}
def encode_http3_push_promise_frame(
self, length: int, headers: Headers, push_id: int, stream_id: int
) -> Dict:
return {
"frame": {
"frame_type": "push_promise",
"headers": self._encode_http3_headers(headers),
"push_id": push_id,
},
"length": length,
"stream_id": stream_id,
}
def _encode_http3_headers(self, headers: Headers) -> List[Dict]:
return [
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
]
# CORE
def log_event(self, *, category: str, event: str, data: Dict) -> None:
self._events.append(
{
"data": data,
"name": category + ":" + event,
"time": self.encode_time(time.time()),
}
)
def to_dict(self) -> Dict[str, Any]:
"""
Return the trace as a dictionary which can be written as JSON.
"""
return {
"common_fields": {
"ODCID": hexdump(self._odcid),
},
"events": list(self._events),
"vantage_point": self._vantage_point,
}
class QuicLogger:
"""
A QUIC event logger which stores traces in memory.
"""
def __init__(self) -> None:
self._traces: List[QuicLoggerTrace] = []
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
self._traces.append(trace)
return trace
def end_trace(self, trace: QuicLoggerTrace) -> None:
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
def to_dict(self) -> Dict[str, Any]:
"""
Return the traces as a dictionary which can be written as JSON.
"""
return {
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace.to_dict() for trace in self._traces],
}
class QuicFileLogger(QuicLogger):
"""
A QUIC event logger which writes one trace per file.
"""
def __init__(self, path: str) -> None:
if not os.path.isdir(path):
raise ValueError("QUIC log output directory '%s' does not exist" % path)
self.path = path
super().__init__()
def end_trace(self, trace: QuicLoggerTrace) -> None:
trace_dict = trace.to_dict()
trace_path = os.path.join(
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
)
with open(trace_path, "w") as logger_fp:
json.dump(
{
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace_dict],
},
logger_fp,
)
self._traces.remove(trace)

View File

@ -0,0 +1,640 @@
import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import List, Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ..buffer import Buffer
from .rangeset import RangeSet
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
class QuicErrorCode(IntEnum):
NO_ERROR = 0x0
INTERNAL_ERROR = 0x1
CONNECTION_REFUSED = 0x2
FLOW_CONTROL_ERROR = 0x3
STREAM_LIMIT_ERROR = 0x4
STREAM_STATE_ERROR = 0x5
FINAL_SIZE_ERROR = 0x6
FRAME_ENCODING_ERROR = 0x7
TRANSPORT_PARAMETER_ERROR = 0x8
CONNECTION_ID_LIMIT_ERROR = 0x9
PROTOCOL_VIOLATION = 0xA
INVALID_TOKEN = 0xB
APPLICATION_ERROR = 0xC
CRYPTO_BUFFER_EXCEEDED = 0xD
KEY_UPDATE_ERROR = 0xE
AEAD_LIMIT_REACHED = 0xF
VERSION_NEGOTIATION_ERROR = 0x11
CRYPTO_ERROR = 0x100
class QuicPacketType(Enum):
INITIAL = 0
ZERO_RTT = 1
HANDSHAKE = 2
RETRY = 3
VERSION_NEGOTIATION = 4
ONE_RTT = 5
# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
QuicPacketType.INITIAL: 0,
QuicPacketType.ZERO_RTT: 1,
QuicPacketType.HANDSHAKE: 2,
QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = dict(
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
)
# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
QuicPacketType.INITIAL: 1,
QuicPacketType.ZERO_RTT: 2,
QuicPacketType.HANDSHAKE: 3,
QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = dict(
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
)
class QuicProtocolVersion(IntEnum):
NEGOTIATION = 0
VERSION_1 = 0x00000001
VERSION_2 = 0x6B3343CF
@dataclass
class QuicHeader:
version: Optional[int]
"The protocol version. Only present in long header packets."
packet_type: QuicPacketType
"The type of the packet."
packet_length: int
"The total length of the packet, in bytes."
destination_cid: bytes
"The destination connection ID."
source_cid: bytes
"The destination connection ID."
token: bytes
"The address verification token. Only present in `INITIAL` and `RETRY` packets."
integrity_tag: bytes
"The retry integrity tag. Only present in `RETRY` packets."
supported_versions: List[int]
"Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
"""
Recover a packet number from a truncated packet number.
See: Appendix A - Sample Packet Number Decoding Algorithm
"""
window = 1 << num_bits
half_window = window // 2
candidate = (expected & ~(window - 1)) | truncated
if candidate <= expected - half_window and candidate < (1 << 62) - window:
return candidate + window
elif candidate > expected + half_window and candidate >= window:
return candidate - window
else:
return candidate
def get_retry_integrity_tag(
packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
"""
Calculate the integrity tag for a RETRY packet.
"""
# build Retry pseudo packet
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
buf.push_uint8(len(original_destination_cid))
buf.push_bytes(original_destination_cid)
buf.push_bytes(packet_without_tag)
assert buf.eof()
if version == QuicProtocolVersion.VERSION_2:
aead_key = RETRY_AEAD_KEY_VERSION_2
aead_nonce = RETRY_AEAD_NONCE_VERSION_2
else:
aead_key = RETRY_AEAD_KEY_VERSION_1
aead_nonce = RETRY_AEAD_NONCE_VERSION_1
# run AES-128-GCM
aead = AESGCM(aead_key)
integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
return integrity_tag
def get_spin_bit(first_byte: int) -> bool:
return bool(first_byte & PACKET_SPIN_BIT)
def is_long_header(first_byte: int) -> bool:
return bool(first_byte & PACKET_LONG_HEADER)
def pretty_protocol_version(version: int) -> str:
"""
Return a user-friendly representation of a protocol version.
"""
try:
version_name = QuicProtocolVersion(version).name
except ValueError:
version_name = "UNKNOWN"
return f"0x{version:08x} ({version_name})"
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
packet_start = buf.tell()
version = None
integrity_tag = b""
supported_versions = []
token = b""
first_byte = buf.pull_uint8()
if is_long_header(first_byte):
# Long Header Packets.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
version = buf.pull_uint32()
destination_cid_length = buf.pull_uint8()
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(
"Destination CID is too long (%d bytes)" % destination_cid_length
)
destination_cid = buf.pull_bytes(destination_cid_length)
source_cid_length = buf.pull_uint8()
if source_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
source_cid = buf.pull_bytes(source_cid_length)
if version == QuicProtocolVersion.NEGOTIATION:
# Version Negotiation Packet.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
packet_type = QuicPacketType.VERSION_NEGOTIATION
while not buf.eof():
supported_versions.append(buf.pull_uint32())
packet_end = buf.tell()
else:
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
if version == QuicProtocolVersion.VERSION_2:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
(first_byte & 0x30) >> 4
]
else:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
(first_byte & 0x30) >> 4
]
if packet_type == QuicPacketType.INITIAL:
token_length = buf.pull_uint_var()
token = buf.pull_bytes(token_length)
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.ZERO_RTT:
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.HANDSHAKE:
rest_length = buf.pull_uint_var()
else:
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
token = buf.pull_bytes(token_length)
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
rest_length = 0
# Check remainder length.
packet_end = buf.tell() + rest_length
if packet_end > buf.capacity:
raise ValueError("Packet payload is truncated")
else:
# Short Header Packets.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
version = None
packet_type = QuicPacketType.ONE_RTT
destination_cid = buf.pull_bytes(host_cid_length)
source_cid = b""
packet_end = buf.capacity
return QuicHeader(
version=version,
packet_type=packet_type,
packet_length=packet_end - packet_start,
destination_cid=destination_cid,
source_cid=source_cid,
token=token,
integrity_tag=integrity_tag,
supported_versions=supported_versions,
)
def encode_long_header_first_byte(
version: int, packet_type: QuicPacketType, bits: int
) -> int:
"""
Encode the first byte of a long header packet.
"""
if version == QuicProtocolVersion.VERSION_2:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
else:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
return (
PACKET_LONG_HEADER
| PACKET_FIXED_BIT
| long_type_encode[packet_type] << 4
| bits
)
def encode_quic_retry(
version: int,
source_cid: bytes,
destination_cid: bytes,
original_destination_cid: bytes,
retry_token: bytes,
unused: int = 0,
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ len(retry_token)
+ RETRY_INTEGRITY_TAG_SIZE
)
buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
buf.push_uint32(version)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
buf.push_bytes(retry_token)
buf.push_bytes(
get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
)
assert buf.eof()
return buf.data
def encode_quic_version_negotiation(
source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ 4 * len(supported_versions)
)
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
for version in supported_versions:
buf.push_uint32(version)
return buf.data
# TLS EXTENSION
@dataclass
class QuicPreferredAddress:
ipv4_address: Optional[Tuple[str, int]]
ipv6_address: Optional[Tuple[str, int]]
connection_id: bytes
stateless_reset_token: bytes
@dataclass
class QuicVersionInformation:
chosen_version: int
available_versions: List[int]
@dataclass
class QuicTransportParameters:
original_destination_connection_id: Optional[bytes] = None
max_idle_timeout: Optional[int] = None
stateless_reset_token: Optional[bytes] = None
max_udp_payload_size: Optional[int] = None
initial_max_data: Optional[int] = None
initial_max_stream_data_bidi_local: Optional[int] = None
initial_max_stream_data_bidi_remote: Optional[int] = None
initial_max_stream_data_uni: Optional[int] = None
initial_max_streams_bidi: Optional[int] = None
initial_max_streams_uni: Optional[int] = None
ack_delay_exponent: Optional[int] = None
max_ack_delay: Optional[int] = None
disable_active_migration: Optional[bool] = False
preferred_address: Optional[QuicPreferredAddress] = None
active_connection_id_limit: Optional[int] = None
initial_source_connection_id: Optional[bytes] = None
retry_source_connection_id: Optional[bytes] = None
version_information: Optional[QuicVersionInformation] = None
max_datagram_frame_size: Optional[int] = None
quantum_readiness: Optional[bytes] = None
PARAMS = {
0x00: ("original_destination_connection_id", bytes),
0x01: ("max_idle_timeout", int),
0x02: ("stateless_reset_token", bytes),
0x03: ("max_udp_payload_size", int),
0x04: ("initial_max_data", int),
0x05: ("initial_max_stream_data_bidi_local", int),
0x06: ("initial_max_stream_data_bidi_remote", int),
0x07: ("initial_max_stream_data_uni", int),
0x08: ("initial_max_streams_bidi", int),
0x09: ("initial_max_streams_uni", int),
0x0A: ("ack_delay_exponent", int),
0x0B: ("max_ack_delay", int),
0x0C: ("disable_active_migration", bool),
0x0D: ("preferred_address", QuicPreferredAddress),
0x0E: ("active_connection_id_limit", int),
0x0F: ("initial_source_connection_id", bytes),
0x10: ("retry_source_connection_id", bytes),
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
0x11: ("version_information", QuicVersionInformation),
# extensions
0x0020: ("max_datagram_frame_size", int),
0x0C37: ("quantum_readiness", bytes),
}
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
ipv4_address = None
ipv4_host = buf.pull_bytes(4)
ipv4_port = buf.pull_uint16()
if ipv4_host != bytes(4):
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
ipv6_address = None
ipv6_host = buf.pull_bytes(16)
ipv6_port = buf.pull_uint16()
if ipv6_host != bytes(16):
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
connection_id_length = buf.pull_uint8()
connection_id = buf.pull_bytes(connection_id_length)
stateless_reset_token = buf.pull_bytes(16)
return QuicPreferredAddress(
ipv4_address=ipv4_address,
ipv6_address=ipv6_address,
connection_id=connection_id,
stateless_reset_token=stateless_reset_token,
)
def push_quic_preferred_address(
buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
if preferred_address.ipv4_address is not None:
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
buf.push_uint16(preferred_address.ipv4_address[1])
else:
buf.push_bytes(bytes(6))
if preferred_address.ipv6_address is not None:
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
buf.push_uint16(preferred_address.ipv6_address[1])
else:
buf.push_bytes(bytes(18))
buf.push_uint8(len(preferred_address.connection_id))
buf.push_bytes(preferred_address.connection_id)
buf.push_bytes(preferred_address.stateless_reset_token)
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
chosen_version = buf.pull_uint32()
available_versions = []
for i in range(length // 4 - 1):
available_versions.append(buf.pull_uint32())
# If an endpoint receives a Chosen Version equal to zero, or any Available Version
# equal to zero, it MUST treat it as a parsing failure.
#
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
if chosen_version == 0 or 0 in available_versions:
raise ValueError("Version Information must not contain version 0")
return QuicVersionInformation(
chosen_version=chosen_version,
available_versions=available_versions,
)
def push_quic_version_information(
buf: Buffer, version_information: QuicVersionInformation
) -> None:
buf.push_uint32(version_information.chosen_version)
for version in version_information.available_versions:
buf.push_uint32(version)
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
params = QuicTransportParameters()
while not buf.eof():
param_id = buf.pull_uint_var()
param_len = buf.pull_uint_var()
param_start = buf.tell()
if param_id in PARAMS:
# Parse known parameter.
param_name, param_type = PARAMS[param_id]
if param_type is int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type is bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type is QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
elif param_type is QuicVersionInformation:
setattr(
params,
param_name,
pull_quic_version_information(buf, param_len),
)
else:
setattr(params, param_name, True)
else:
# Skip unknown parameter.
buf.pull_bytes(param_len)
if buf.tell() != param_start + param_len:
raise ValueError("Transport parameter length does not match")
return params
def push_quic_transport_parameters(
buf: Buffer, params: QuicTransportParameters
) -> None:
for param_id, (param_name, param_type) in PARAMS.items():
param_value = getattr(params, param_name)
if param_value is not None and param_value is not False:
param_buf = Buffer(capacity=65536)
if param_type is int:
param_buf.push_uint_var(param_value)
elif param_type is bytes:
param_buf.push_bytes(param_value)
elif param_type is QuicPreferredAddress:
push_quic_preferred_address(param_buf, param_value)
elif param_type is QuicVersionInformation:
push_quic_version_information(param_buf, param_value)
buf.push_uint_var(param_id)
buf.push_uint_var(param_buf.tell())
buf.push_bytes(param_buf.data)
# FRAMES
class QuicFrameType(IntEnum):
PADDING = 0x00
PING = 0x01
ACK = 0x02
ACK_ECN = 0x03
RESET_STREAM = 0x04
STOP_SENDING = 0x05
CRYPTO = 0x06
NEW_TOKEN = 0x07
STREAM_BASE = 0x08
MAX_DATA = 0x10
MAX_STREAM_DATA = 0x11
MAX_STREAMS_BIDI = 0x12
MAX_STREAMS_UNI = 0x13
DATA_BLOCKED = 0x14
STREAM_DATA_BLOCKED = 0x15
STREAMS_BLOCKED_BIDI = 0x16
STREAMS_BLOCKED_UNI = 0x17
NEW_CONNECTION_ID = 0x18
RETIRE_CONNECTION_ID = 0x19
PATH_CHALLENGE = 0x1A
PATH_RESPONSE = 0x1B
TRANSPORT_CLOSE = 0x1C
APPLICATION_CLOSE = 0x1D
HANDSHAKE_DONE = 0x1E
DATAGRAM = 0x30
DATAGRAM_WITH_LENGTH = 0x31
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.PADDING,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
PROBING_FRAME_TYPES = frozenset(
[
QuicFrameType.PATH_CHALLENGE,
QuicFrameType.PATH_RESPONSE,
QuicFrameType.PADDING,
QuicFrameType.NEW_CONNECTION_ID,
]
)
@dataclass
class QuicResetStreamFrame:
error_code: int
final_size: int
stream_id: int
@dataclass
class QuicStopSendingFrame:
error_code: int
stream_id: int
@dataclass
class QuicStreamFrame:
data: bytes = b""
fin: bool = False
offset: int = 0
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
rangeset = RangeSet()
end = buf.pull_uint_var() # largest acknowledged
delay = buf.pull_uint_var()
ack_range_count = buf.pull_uint_var()
ack_count = buf.pull_uint_var() # first ack range
rangeset.add(end - ack_count, end + 1)
end -= ack_count
for _ in range(ack_range_count):
end -= buf.pull_uint_var() + 2
ack_count = buf.pull_uint_var()
rangeset.add(end - ack_count, end + 1)
end -= ack_count
return rangeset, delay
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
ranges = len(rangeset)
index = ranges - 1
r = rangeset[index]
buf.push_uint_var(r.stop - 1)
buf.push_uint_var(delay)
buf.push_uint_var(index)
buf.push_uint_var(r.stop - 1 - r.start)
start = r.start
while index > 0:
index -= 1
r = rangeset[index]
buf.push_uint_var(start - r.stop - 1)
buf.push_uint_var(r.stop - r.start - 1)
start = r.start
return ranges

View File

@ -0,0 +1,384 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from ..buffer import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
NON_ACK_ELICITING_FRAME_TYPES,
NON_IN_FLIGHT_FRAME_TYPES,
PACKET_FIXED_BIT,
PACKET_NUMBER_MAX_SIZE,
QuicFrameType,
QuicPacketType,
encode_long_header_first_byte,
)
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2
QuicDeliveryHandler = Callable[..., None]
class QuicDeliveryState(Enum):
ACKED = 0
LOST = 1
@dataclass
class QuicSentPacket:
epoch: Epoch
in_flight: bool
is_ack_eliciting: bool
is_crypto_packet: bool
packet_number: int
packet_type: QuicPacketType
sent_time: Optional[float] = None
sent_bytes: int = 0
delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
default_factory=list
)
quic_logger_frames: List[Dict] = field(default_factory=list)
class QuicPacketBuilderStop(Exception):
pass
class QuicPacketBuilder:
"""
Helper for building QUIC packets.
"""
def __init__(
self,
*,
host_cid: bytes,
peer_cid: bytes,
version: int,
is_client: bool,
max_datagram_size: int,
packet_number: int = 0,
peer_token: bytes = b"",
quic_logger: Optional[QuicLoggerTrace] = None,
spin_bit: bool = False,
):
self.max_flight_bytes: Optional[int] = None
self.max_total_bytes: Optional[int] = None
self.quic_logger_frames: Optional[List[Dict]] = None
self._host_cid = host_cid
self._is_client = is_client
self._peer_cid = peer_cid
self._peer_token = peer_token
self._quic_logger = quic_logger
self._spin_bit = spin_bit
self._version = version
# assembled datagrams and packets
self._datagrams: List[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._datagram_needs_padding = False
self._packets: List[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
# current packet
self._header_size = 0
self._packet: Optional[QuicSentPacket] = None
self._packet_crypto: Optional[CryptoPair] = None
self._packet_number = packet_number
self._packet_start = 0
self._packet_type: Optional[QuicPacketType] = None
self._buffer = Buffer(max_datagram_size)
self._buffer_capacity = max_datagram_size
self._flight_capacity = max_datagram_size
@property
def packet_is_empty(self) -> bool:
"""
Returns `True` if the current packet is empty.
"""
assert self._packet is not None
packet_size = self._buffer.tell() - self._packet_start
return packet_size <= self._header_size
@property
def packet_number(self) -> int:
"""
Returns the packet number for the next packet.
"""
return self._packet_number
@property
def remaining_buffer_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._buffer_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
@property
def remaining_flight_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._flight_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
"""
Returns the assembled datagrams.
"""
if self._packet is not None:
self._end_packet()
self._flush_current_datagram()
datagrams = self._datagrams
packets = self._packets
self._datagrams = []
self._packets = []
return datagrams, packets
def start_frame(
self,
frame_type: int,
capacity: int = 1,
handler: Optional[QuicDeliveryHandler] = None,
handler_args: Sequence[Any] = [],
) -> Buffer:
"""
Starts a new frame.
"""
if self.remaining_buffer_space < capacity or (
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
and self.remaining_flight_space < capacity
):
raise QuicPacketBuilderStop
self._buffer.push_uint_var(frame_type)
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
self._packet.is_ack_eliciting = True
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
self._packet.in_flight = True
if frame_type == QuicFrameType.CRYPTO:
self._packet.is_crypto_packet = True
if handler is not None:
self._packet.delivery_handlers.append((handler, handler_args))
return self._buffer
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
"""
Starts a new packet.
"""
assert packet_type in (
QuicPacketType.INITIAL,
QuicPacketType.HANDSHAKE,
QuicPacketType.ZERO_RTT,
QuicPacketType.ONE_RTT,
), "Invalid packet type"
buf = self._buffer
# finish previous datagram
if self._packet is not None:
self._end_packet()
# if there is too little space remaining, start a new datagram
# FIXME: the limit is arbitrary!
packet_start = buf.tell()
if self._buffer_capacity - packet_start < 128:
self._flush_current_datagram()
packet_start = 0
# initialize datagram if needed
if self._datagram_init:
if self.max_total_bytes is not None:
remaining_total_bytes = self.max_total_bytes - self._total_bytes
if remaining_total_bytes < self._buffer_capacity:
self._buffer_capacity = remaining_total_bytes
self._flight_capacity = self._buffer_capacity
if self.max_flight_bytes is not None:
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
if remaining_flight_bytes < self._flight_capacity:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
self._datagram_needs_padding = False
# calculate header size
if packet_type != QuicPacketType.ONE_RTT:
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
if packet_type == QuicPacketType.INITIAL:
token_length = len(self._peer_token)
header_size += size_uint_var(token_length) + token_length
else:
header_size = 3 + len(self._peer_cid)
# check we have enough space
if packet_start + header_size >= self._buffer_capacity:
raise QuicPacketBuilderStop
# determine ack epoch
if packet_type == QuicPacketType.INITIAL:
epoch = Epoch.INITIAL
elif packet_type == QuicPacketType.HANDSHAKE:
epoch = Epoch.HANDSHAKE
else:
epoch = Epoch.ONE_RTT
self._header_size = header_size
self._packet = QuicSentPacket(
epoch=epoch,
in_flight=False,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=self._packet_number,
packet_type=packet_type,
)
self._packet_crypto = crypto
self._packet_start = packet_start
self._packet_type = packet_type
self.quic_logger_frames = self._packet.quic_logger_frames
buf.seek(self._packet_start + self._header_size)
def _end_packet(self) -> None:
"""
Ends the current packet.
"""
buf = self._buffer
packet_size = buf.tell() - self._packet_start
if packet_size > self._header_size:
# padding to ensure sufficient sample size
padding_size = (
PACKET_NUMBER_MAX_SIZE
- PACKET_NUMBER_SEND_SIZE
+ self._header_size
- packet_size
)
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if (
self._is_client or self._packet.is_ack_eliciting
) and self._packet_type == QuicPacketType.INITIAL:
self._datagram_needs_padding = True
# For datagrams containing 1-RTT data, we *must* apply the padding
# inside the packet, we cannot tack bytes onto the end of the
# datagram.
if (
self._datagram_needs_padding
and self._packet_type == QuicPacketType.ONE_RTT
):
if self.remaining_flight_space > padding_size:
padding_size = self.remaining_flight_space
self._datagram_needs_padding = False
# write padding
if padding_size > 0:
buf.push_bytes(bytes(padding_size))
packet_size += padding_size
self._packet.in_flight = True
# log frame
if self._quic_logger is not None:
self._packet.quic_logger_frames.append(
self._quic_logger.encode_padding_frame()
)
# write header
if self._packet_type != QuicPacketType.ONE_RTT:
length = (
packet_size
- self._header_size
+ PACKET_NUMBER_SEND_SIZE
+ self._packet_crypto.aead_tag_size
)
buf.seek(self._packet_start)
buf.push_uint8(
encode_long_header_first_byte(
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
)
)
buf.push_uint32(self._version)
buf.push_uint8(len(self._peer_cid))
buf.push_bytes(self._peer_cid)
buf.push_uint8(len(self._host_cid))
buf.push_bytes(self._host_cid)
if self._packet_type == QuicPacketType.INITIAL:
buf.push_uint_var(len(self._peer_token))
buf.push_bytes(self._peer_token)
buf.push_uint16(length | 0x4000)
buf.push_uint16(self._packet_number & 0xFFFF)
else:
buf.seek(self._packet_start)
buf.push_uint8(
PACKET_FIXED_BIT
| (self._spin_bit << 5)
| (self._packet_crypto.key_phase << 2)
| (PACKET_NUMBER_SEND_SIZE - 1)
)
buf.push_bytes(self._peer_cid)
buf.push_uint16(self._packet_number & 0xFFFF)
# encrypt in place
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
buf.seek(self._packet_start)
buf.push_bytes(
self._packet_crypto.encrypt_packet(
plain[0 : self._header_size],
plain[self._header_size : packet_size],
self._packet_number,
)
)
self._packet.sent_bytes = buf.tell() - self._packet_start
self._packets.append(self._packet)
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes
# Short header packets cannot be coalesced, we need a new datagram.
if self._packet_type == QuicPacketType.ONE_RTT:
self._flush_current_datagram()
self._packet_number += 1
else:
# "cancel" the packet
buf.seek(self._packet_start)
self._packet = None
self.quic_logger_frames = None
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if self._datagram_needs_padding:
extra_bytes = self._flight_capacity - self._buffer.tell()
if extra_bytes > 0:
self._buffer.push_bytes(bytes(extra_bytes))
self._datagram_flight_bytes += extra_bytes
datagram_bytes += extra_bytes
self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
self._datagram_init = True
self._buffer.seek(0)

View File

@ -0,0 +1,98 @@
from collections.abc import Sequence
from typing import Any, Iterable, List, Optional
class RangeSet(Sequence):
def __init__(self, ranges: Iterable[range] = []):
self.__ranges: List[range] = []
for r in ranges:
assert r.step == 1
self.add(r.start, r.stop)
def add(self, start: int, stop: Optional[int] = None) -> None:
if stop is None:
stop = start + 1
assert stop > start
for i, r in enumerate(self.__ranges):
# the added range is entirely before current item, insert here
if stop < r.start:
self.__ranges.insert(i, range(start, stop))
return
# the added range is entirely after current item, keep looking
if start > r.stop:
continue
# the added range touches the current item, merge it
start = min(start, r.start)
stop = max(stop, r.stop)
while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
stop = max(self.__ranges[i + 1].stop, stop)
self.__ranges.pop(i + 1)
self.__ranges[i] = range(start, stop)
return
# the added range is entirely after all existing items, append it
self.__ranges.append(range(start, stop))
def bounds(self) -> range:
return range(self.__ranges[0].start, self.__ranges[-1].stop)
def shift(self) -> range:
return self.__ranges.pop(0)
def subtract(self, start: int, stop: int) -> None:
assert stop > start
i = 0
while i < len(self.__ranges):
r = self.__ranges[i]
# the removed range is entirely before current item, stop here
if stop <= r.start:
return
# the removed range is entirely after current item, keep looking
if start >= r.stop:
i += 1
continue
# the removed range completely covers the current item, remove it
if start <= r.start and stop >= r.stop:
self.__ranges.pop(i)
continue
# the removed range touches the current item
if start > r.start:
self.__ranges[i] = range(r.start, start)
if stop < r.stop:
self.__ranges.insert(i + 1, range(stop, r.stop))
else:
self.__ranges[i] = range(stop, r.stop)
i += 1
def __bool__(self) -> bool:
raise NotImplementedError
def __contains__(self, val: Any) -> bool:
for r in self.__ranges:
if val in r:
return True
return False
def __eq__(self, other: object) -> bool:
if not isinstance(other, RangeSet):
return NotImplemented
return self.__ranges == other.__ranges
def __getitem__(self, key: Any) -> range:
return self.__ranges[key]
def __len__(self) -> int:
return len(self.__ranges)
def __repr__(self) -> str:
return "RangeSet({})".format(repr(self.__ranges))

View File

@ -0,0 +1,389 @@
import logging
import math
from typing import Any, Callable, Dict, Iterable, List, Optional
from .congestion import cubic, reno # noqa
from .congestion.base import K_GRANULARITY, create_congestion_control
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
from .rangeset import RangeSet
# loss detection
K_PACKET_THRESHOLD = 3
K_TIME_THRESHOLD = 9 / 8
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
class QuicPacketSpace:
def __init__(self) -> None:
self.ack_at: Optional[float] = None
self.ack_queue = RangeSet()
self.discarded = False
self.expected_packet_number = 0
self.largest_received_packet = -1
self.largest_received_time: Optional[float] = None
# sent packets and loss
self.ack_eliciting_in_flight = 0
self.largest_acked_packet = 0
self.loss_time: Optional[float] = None
self.sent_packets: Dict[int, QuicSentPacket] = {}
class QuicPacketPacer:
def __init__(self, *, max_datagram_size: int) -> None:
self._max_datagram_size = max_datagram_size
self.bucket_max: float = 0.0
self.bucket_time: float = 0.0
self.evaluation_time: float = 0.0
self.packet_time: Optional[float] = None
def next_send_time(self, now: float) -> float:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time <= 0:
return now + self.packet_time
return None
def update_after_send(self, now: float) -> None:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time < self.packet_time:
self.bucket_time = 0.0
else:
self.bucket_time -= self.packet_time
def update_bucket(self, now: float) -> None:
if now > self.evaluation_time:
self.bucket_time = min(
self.bucket_time + (now - self.evaluation_time), self.bucket_max
)
self.evaluation_time = now
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
self.packet_time = max(
K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND)
)
self.bucket_max = (
max(
2 * self._max_datagram_size,
min(congestion_window // 4, 16 * self._max_datagram_size),
)
/ pacing_rate
)
if self.bucket_time > self.bucket_max:
self.bucket_time = self.bucket_max
class QuicPacketRecovery:
"""
Packet loss and congestion controller.
"""
def __init__(
self,
*,
congestion_control_algorithm: str,
initial_rtt: float,
max_datagram_size: int,
peer_completed_address_validation: bool,
send_probe: Callable[[], None],
logger: Optional[logging.LoggerAdapter] = None,
quic_logger: Optional[QuicLoggerTrace] = None,
) -> None:
self.max_ack_delay = 0.025
self.peer_completed_address_validation = peer_completed_address_validation
self.spaces: List[QuicPacketSpace] = []
# callbacks
self._logger = logger
self._quic_logger = quic_logger
self._send_probe = send_probe
# loss detection
self._pto_count = 0
self._rtt_initial = initial_rtt
self._rtt_initialized = False
self._rtt_latest = 0.0
self._rtt_min = math.inf
self._rtt_smoothed = 0.0
self._rtt_variance = 0.0
self._time_of_last_sent_ack_eliciting_packet = 0.0
# congestion control
self._cc = create_congestion_control(
congestion_control_algorithm, max_datagram_size=max_datagram_size
)
self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size)
@property
def bytes_in_flight(self) -> int:
return self._cc.bytes_in_flight
@property
def congestion_window(self) -> int:
return self._cc.congestion_window
def discard_space(self, space: QuicPacketSpace) -> None:
assert space in self.spaces
self._cc.on_packets_expired(
packets=filter(lambda x: x.in_flight, space.sent_packets.values())
)
space.sent_packets.clear()
space.ack_at = None
space.ack_eliciting_in_flight = 0
space.loss_time = None
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated()
def get_loss_detection_time(self) -> float:
# loss timer
loss_space = self._get_loss_space()
if loss_space is not None:
return loss_space.loss_time
# packet timer
if (
not self.peer_completed_address_validation
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
):
timeout = self.get_probe_timeout() * (2**self._pto_count)
return self._time_of_last_sent_ack_eliciting_packet + timeout
return None
def get_probe_timeout(self) -> float:
if not self._rtt_initialized:
return 2 * self._rtt_initial
return (
self._rtt_smoothed
+ max(4 * self._rtt_variance, K_GRANULARITY)
+ self.max_ack_delay
)
def on_ack_received(
self,
*,
ack_rangeset: RangeSet,
ack_delay: float,
now: float,
space: QuicPacketSpace,
) -> None:
"""
Update metrics as the result of an ACK being received.
"""
is_ack_eliciting = False
largest_acked = ack_rangeset.bounds().stop - 1
largest_newly_acked = None
largest_sent_time = None
if largest_acked > space.largest_acked_packet:
space.largest_acked_packet = largest_acked
for packet_number in sorted(space.sent_packets.keys()):
if packet_number > largest_acked:
break
if packet_number in ack_rangeset:
# remove packet and update counters
packet = space.sent_packets.pop(packet_number)
if packet.is_ack_eliciting:
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet=packet, now=now)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.ACKED, *args)
# nothing to do if there are no newly acked packets
if largest_newly_acked is None:
return
if largest_acked == largest_newly_acked and is_ack_eliciting:
latest_rtt = now - largest_sent_time
log_rtt = True
# limit ACK delay to max_ack_delay
ack_delay = min(ack_delay, self.max_ack_delay)
# update RTT estimate, which cannot be < 1 ms
self._rtt_latest = max(latest_rtt, 0.001)
if self._rtt_latest < self._rtt_min:
self._rtt_min = self._rtt_latest
if self._rtt_latest > self._rtt_min + ack_delay:
self._rtt_latest -= ack_delay
if not self._rtt_initialized:
self._rtt_initialized = True
self._rtt_variance = latest_rtt / 2
self._rtt_smoothed = latest_rtt
else:
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
self._rtt_min - self._rtt_latest
)
self._rtt_smoothed = (
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
)
# inform congestion controller
self._cc.on_rtt_measurement(now=now, rtt=latest_rtt)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
else:
log_rtt = False
self._detect_loss(now=now, space=space)
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated(log_rtt=log_rtt)
def on_loss_detection_timeout(self, *, now: float) -> None:
loss_space = self._get_loss_space()
if loss_space is not None:
self._detect_loss(now=now, space=loss_space)
else:
self._pto_count += 1
self.reschedule_data(now=now)
def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
space.sent_packets[packet.packet_number] = packet
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight += 1
if packet.in_flight:
if packet.is_ack_eliciting:
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
# add packet to bytes in flight
self._cc.on_packet_sent(packet=packet)
if self._quic_logger is not None:
self._log_metrics_updated()
def reschedule_data(self, *, now: float) -> None:
"""
Schedule some data for retransmission.
"""
# if there is any outstanding CRYPTO, retransmit it
crypto_scheduled = False
for space in self.spaces:
packets = tuple(
filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
)
if packets:
self._on_packets_lost(now=now, packets=packets, space=space)
crypto_scheduled = True
if crypto_scheduled and self._logger is not None:
self._logger.debug("Scheduled CRYPTO data for retransmission")
# ensure an ACK-elliciting packet is sent
self._send_probe()
def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None:
"""
Check whether any packets should be declared lost.
"""
loss_delay = K_TIME_THRESHOLD * (
max(self._rtt_latest, self._rtt_smoothed)
if self._rtt_initialized
else self._rtt_initial
)
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
time_threshold = now - loss_delay
lost_packets = []
space.loss_time = None
for packet_number, packet in space.sent_packets.items():
if packet_number > space.largest_acked_packet:
break
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
lost_packets.append(packet)
else:
packet_loss_time = packet.sent_time + loss_delay
if space.loss_time is None or space.loss_time > packet_loss_time:
space.loss_time = packet_loss_time
self._on_packets_lost(now=now, packets=lost_packets, space=space)
def _get_loss_space(self) -> Optional[QuicPacketSpace]:
loss_space = None
for space in self.spaces:
if space.loss_time is not None and (
loss_space is None or space.loss_time < loss_space.loss_time
):
loss_space = space
return loss_space
def _log_metrics_updated(self, log_rtt=False) -> None:
data: Dict[str, Any] = self._cc.get_log_data()
if log_rtt:
data.update(
{
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
}
)
self._quic_logger.log_event(
category="recovery", event="metrics_updated", data=data
)
def _on_packets_lost(
self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace
) -> None:
lost_packets_cc = []
for packet in packets:
del space.sent_packets[packet.packet_number]
if packet.in_flight:
lost_packets_cc.append(packet)
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight -= 1
if self._quic_logger is not None:
self._quic_logger.log_event(
category="recovery",
event="packet_lost",
data={
"type": self._quic_logger.packet_type(packet.packet_type),
"packet_number": packet.packet_number,
},
)
self._log_metrics_updated()
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.LOST, *args)
# inform congestion controller
if lost_packets_cc:
self._cc.on_packets_lost(now=now, packets=lost_packets_cc)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
if self._quic_logger is not None:
self._log_metrics_updated()

View File

@ -0,0 +1,53 @@
import ipaddress
from typing import Tuple
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from ..buffer import Buffer
from ..tls import pull_opaque, push_opaque
from .connection import NetworkAddress
def encode_address(addr: NetworkAddress) -> bytes:
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
class QuicRetryTokenHandler:
def __init__(self) -> None:
self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
def create_token(
self,
addr: NetworkAddress,
original_destination_connection_id: bytes,
retry_source_connection_id: bytes,
) -> bytes:
buf = Buffer(capacity=512)
push_opaque(buf, 1, encode_address(addr))
push_opaque(buf, 1, original_destination_connection_id)
push_opaque(buf, 1, retry_source_connection_id)
return self._key.public_key().encrypt(
buf.data,
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
),
)
def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]:
buf = Buffer(
data=self._key.decrypt(
token,
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
),
)
)
encoded_addr = pull_opaque(buf, 1)
original_destination_connection_id = pull_opaque(buf, 1)
retry_source_connection_id = pull_opaque(buf, 1)
if encoded_addr != encode_address(addr):
raise ValueError("Remote address does not match.")
return original_destination_connection_id, retry_source_connection_id

View File

@ -0,0 +1,364 @@
from typing import Optional
from . import events
from .packet import (
QuicErrorCode,
QuicResetStreamFrame,
QuicStopSendingFrame,
QuicStreamFrame,
)
from .packet_builder import QuicDeliveryState
from .rangeset import RangeSet
class FinalSizeError(Exception):
pass
class StreamFinishedError(Exception):
pass
class QuicStreamReceiver:
"""
The receive part of a QUIC stream.
It finishes:
- immediately for a send-only stream
- upon reception of a STREAM_RESET frame
- upon reception of a data frame with the FIN bit set
"""
def __init__(self, stream_id: Optional[int], readable: bool) -> None:
self.highest_offset = 0 # the highest offset ever seen
self.is_finished = False
self.stop_pending = False
self._buffer = bytearray()
self._buffer_start = 0 # the offset for the start of the buffer
self._final_size: Optional[int] = None
self._ranges = RangeSet()
self._stream_id = stream_id
self._stop_error_code: Optional[int] = None
def get_stop_frame(self) -> QuicStopSendingFrame:
self.stop_pending = False
return QuicStopSendingFrame(
error_code=self._stop_error_code,
stream_id=self._stream_id,
)
def starting_offset(self) -> int:
return self._buffer_start
def handle_frame(
self, frame: QuicStreamFrame
) -> Optional[events.StreamDataReceived]:
"""
Handle a frame of received data.
"""
pos = frame.offset - self._buffer_start
count = len(frame.data)
frame_end = frame.offset + count
# we should receive no more data beyond FIN!
if self._final_size is not None:
if frame_end > self._final_size:
raise FinalSizeError("Data received beyond final size")
elif frame.fin and frame_end != self._final_size:
raise FinalSizeError("Cannot change final size")
if frame.fin:
self._final_size = frame_end
if frame_end > self.highest_offset:
self.highest_offset = frame_end
# fast path: new in-order chunk
if pos == 0 and count and not self._buffer:
self._buffer_start += count
if frame.fin:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
return events.StreamDataReceived(
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
)
# discard duplicate data
if pos < 0:
frame.data = frame.data[-pos:]
frame.offset -= pos
pos = 0
count = len(frame.data)
# marked received range
if frame_end > frame.offset:
self._ranges.add(frame.offset, frame_end)
# add new data
gap = pos - len(self._buffer)
if gap > 0:
self._buffer += bytearray(gap)
self._buffer[pos : pos + count] = frame.data
# return data from the front of the buffer
data = self._pull_data()
end_stream = self._buffer_start == self._final_size
if end_stream:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
if data or end_stream:
return events.StreamDataReceived(
data=data, end_stream=end_stream, stream_id=self._stream_id
)
else:
return None
def handle_reset(
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
) -> Optional[events.StreamReset]:
"""
Handle an abrupt termination of the receiving part of the QUIC stream.
"""
if self._final_size is not None and final_size != self._final_size:
raise FinalSizeError("Cannot change final size")
# we are done receiving
self._final_size = final_size
self.is_finished = True
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a STOP_SENDING is ACK'd.
"""
if delivery != QuicDeliveryState.ACKED:
self.stop_pending = True
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
"""
Request the peer stop sending data on the QUIC stream.
"""
self._stop_error_code = error_code
self.stop_pending = True
def _pull_data(self) -> bytes:
"""
Remove data from the front of the buffer.
"""
try:
has_data_to_read = self._ranges[0].start == self._buffer_start
except IndexError:
has_data_to_read = False
if not has_data_to_read:
return b""
r = self._ranges.shift()
pos = r.stop - r.start
data = bytes(self._buffer[:pos])
del self._buffer[:pos]
self._buffer_start = r.stop
return data
class QuicStreamSender:
"""
The send part of a QUIC stream.
It finishes:
- immediately for a receive-only stream
- upon acknowledgement of a STREAM_RESET frame
- upon acknowledgement of a data frame with the FIN bit set
"""
def __init__(self, stream_id: Optional[int], writable: bool) -> None:
self.buffer_is_empty = True
self.highest_offset = 0
self.is_finished = not writable
self.reset_pending = False
self._acked = RangeSet()
self._acked_fin = False
self._buffer = bytearray()
self._buffer_fin: Optional[int] = None
self._buffer_start = 0 # the offset for the start of the buffer
self._buffer_stop = 0 # the offset for the stop of the buffer
self._pending = RangeSet()
self._pending_eof = False
self._reset_error_code: Optional[int] = None
self._stream_id = stream_id
@property
def next_offset(self) -> int:
"""
The offset for the next frame to send.
This is used to determine the space needed for the frame's `offset` field.
"""
try:
return self._pending[0].start
except IndexError:
return self._buffer_stop
def get_frame(
self, max_size: int, max_offset: Optional[int] = None
) -> Optional[QuicStreamFrame]:
"""
Get a frame of data to send.
"""
assert self._reset_error_code is None, "cannot call get_frame() after reset()"
# get the first pending data range
try:
r = self._pending[0]
except IndexError:
if self._pending_eof:
# FIN only
self._pending_eof = False
return QuicStreamFrame(fin=True, offset=self._buffer_fin)
self.buffer_is_empty = True
return None
# apply flow control
start = r.start
stop = min(r.stop, start + max_size)
if max_offset is not None and stop > max_offset:
stop = max_offset
if stop <= start:
return None
# create frame
frame = QuicStreamFrame(
data=bytes(
self._buffer[start - self._buffer_start : stop - self._buffer_start]
),
offset=start,
)
self._pending.subtract(start, stop)
# track the highest offset ever sent
if stop > self.highest_offset:
self.highest_offset = stop
# if the buffer is empty and EOF was written, set the FIN bit
if self._buffer_fin == stop:
frame.fin = True
self._pending_eof = False
return frame
def get_reset_frame(self) -> QuicResetStreamFrame:
self.reset_pending = False
return QuicResetStreamFrame(
error_code=self._reset_error_code,
final_size=self.highest_offset,
stream_id=self._stream_id,
)
def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
) -> None:
"""
Callback when sent data is ACK'd.
"""
# If the frame had the FIN bit set, its end MUST match otherwise
# we have a programming error.
assert (
not fin or stop == self._buffer_fin
), "on_data_delivered() was called with inconsistent fin / stop"
# If a reset has been requested, stop processing data delivery.
# The transition to the finished state only depends on the reset
# being acknowledged.
if self._reset_error_code is not None:
return
if delivery == QuicDeliveryState.ACKED:
if stop > start:
# Some data has been ACK'd, discard it.
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range.start == self._buffer_start:
size = first_range.stop - first_range.start
self._acked.shift()
self._buffer_start += size
del self._buffer[:size]
if fin:
# The FIN has been ACK'd.
self._acked_fin = True
if self._buffer_start == self._buffer_fin and self._acked_fin:
# All data and the FIN have been ACK'd, we're done sending.
self.is_finished = True
else:
if stop > start:
# Some data has been lost, reschedule it.
self.buffer_is_empty = False
self._pending.add(start, stop)
if fin:
# The FIN has been lost, reschedule it.
self.buffer_is_empty = False
self._pending_eof = True
def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a reset is ACK'd.
"""
if delivery == QuicDeliveryState.ACKED:
# The reset has been ACK'd, we're done sending.
self.is_finished = True
else:
# The reset has been lost, reschedule it.
self.reset_pending = True
def reset(self, error_code: int) -> None:
"""
Abruptly terminate the sending part of the QUIC stream.
"""
assert self._reset_error_code is None, "cannot call reset() more than once"
self._reset_error_code = error_code
self.reset_pending = True
# Prevent any more data from being sent or re-sent.
self.buffer_is_empty = True
def write(self, data: bytes, end_stream: bool = False) -> None:
"""
Write some data bytes to the QUIC stream.
"""
assert self._buffer_fin is None, "cannot call write() after FIN"
assert self._reset_error_code is None, "cannot call write() after reset()"
size = len(data)
if size:
self.buffer_is_empty = False
self._pending.add(self._buffer_stop, self._buffer_stop + size)
self._buffer += data
self._buffer_stop += size
if end_stream:
self.buffer_is_empty = False
self._buffer_fin = self._buffer_stop
self._pending_eof = True
class QuicStream:
def __init__(
self,
stream_id: Optional[int] = None,
max_stream_data_local: int = 0,
max_stream_data_remote: int = 0,
readable: bool = True,
writable: bool = True,
) -> None:
self.is_blocked = False
self.max_stream_data_local = max_stream_data_local
self.max_stream_data_local_sent = max_stream_data_local
self.max_stream_data_remote = max_stream_data_remote
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
self.stream_id = stream_id
@property
def is_finished(self) -> bool:
return self.receiver.is_finished and self.sender.is_finished

File diff suppressed because it is too large Load Diff