Code
This commit is contained in:
@ -0,0 +1 @@
|
||||
__version__ = "1.2.0"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.abi3.so
Executable file
BIN
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.abi3.so
Executable file
Binary file not shown.
422
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.c
Normal file
422
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.c
Normal file
@ -0,0 +1,422 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
|
||||
#include <Python.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define MODULE_NAME "aioquic._buffer"
|
||||
|
||||
static PyObject *BufferReadError;
|
||||
static PyObject *BufferWriteError;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
uint8_t *base;
|
||||
uint8_t *end;
|
||||
uint8_t *pos;
|
||||
} BufferObject;
|
||||
|
||||
static PyObject *BufferType;
|
||||
|
||||
#define CHECK_READ_BOUNDS(self, len) \
|
||||
if (len < 0 || self->pos + len > self->end) { \
|
||||
PyErr_SetString(BufferReadError, "Read out of bounds"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
#define CHECK_WRITE_BOUNDS(self, len) \
|
||||
if (self->pos + len > self->end) { \
|
||||
PyErr_SetString(BufferWriteError, "Write out of bounds"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
static int
|
||||
Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *kwlist[] = {"capacity", "data", NULL};
|
||||
Py_ssize_t capacity = 0;
|
||||
const unsigned char *data = NULL;
|
||||
Py_ssize_t data_len = 0;
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ny#", (char**)kwlist, &capacity, &data, &data_len))
|
||||
return -1;
|
||||
|
||||
if (data != NULL) {
|
||||
self->base = malloc(data_len);
|
||||
self->end = self->base + data_len;
|
||||
memcpy(self->base, data, data_len);
|
||||
} else {
|
||||
self->base = malloc(capacity);
|
||||
self->end = self->base + capacity;
|
||||
}
|
||||
self->pos = self->base;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
Buffer_dealloc(BufferObject *self)
|
||||
{
|
||||
free(self->base);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_data_slice(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t start, stop;
|
||||
if (!PyArg_ParseTuple(args, "nn", &start, &stop))
|
||||
return NULL;
|
||||
|
||||
if (start < 0 || self->base + start > self->end ||
|
||||
stop < 0 || self->base + stop > self->end ||
|
||||
stop < start) {
|
||||
PyErr_SetString(BufferReadError, "Read out of bounds");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start));
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_eof(BufferObject *self, PyObject *args)
|
||||
{
|
||||
if (self->pos == self->end)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_bytes(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t len;
|
||||
if (!PyArg_ParseTuple(args, "n", &len))
|
||||
return NULL;
|
||||
|
||||
CHECK_READ_BOUNDS(self, len);
|
||||
|
||||
PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len);
|
||||
self->pos += len;
|
||||
return o;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint8(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 1)
|
||||
|
||||
return PyLong_FromUnsignedLong(
|
||||
(uint8_t)(*(self->pos++))
|
||||
);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint16(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 2)
|
||||
|
||||
uint16_t value = (uint16_t)(*(self->pos)) << 8 |
|
||||
(uint16_t)(*(self->pos + 1));
|
||||
self->pos += 2;
|
||||
return PyLong_FromUnsignedLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint32(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 4)
|
||||
|
||||
uint32_t value = (uint32_t)(*(self->pos)) << 24 |
|
||||
(uint32_t)(*(self->pos + 1)) << 16 |
|
||||
(uint32_t)(*(self->pos + 2)) << 8 |
|
||||
(uint32_t)(*(self->pos + 3));
|
||||
self->pos += 4;
|
||||
return PyLong_FromUnsignedLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint64(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 8)
|
||||
|
||||
uint64_t value = (uint64_t)(*(self->pos)) << 56 |
|
||||
(uint64_t)(*(self->pos + 1)) << 48 |
|
||||
(uint64_t)(*(self->pos + 2)) << 40 |
|
||||
(uint64_t)(*(self->pos + 3)) << 32 |
|
||||
(uint64_t)(*(self->pos + 4)) << 24 |
|
||||
(uint64_t)(*(self->pos + 5)) << 16 |
|
||||
(uint64_t)(*(self->pos + 6)) << 8 |
|
||||
(uint64_t)(*(self->pos + 7));
|
||||
self->pos += 8;
|
||||
return PyLong_FromUnsignedLongLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint_var(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
CHECK_READ_BOUNDS(self, 1)
|
||||
switch (*(self->pos) >> 6) {
|
||||
case 0:
|
||||
value = *(self->pos++) & 0x3F;
|
||||
break;
|
||||
case 1:
|
||||
CHECK_READ_BOUNDS(self, 2)
|
||||
value = (uint16_t)(*(self->pos) & 0x3F) << 8 |
|
||||
(uint16_t)(*(self->pos + 1));
|
||||
self->pos += 2;
|
||||
break;
|
||||
case 2:
|
||||
CHECK_READ_BOUNDS(self, 4)
|
||||
value = (uint32_t)(*(self->pos) & 0x3F) << 24 |
|
||||
(uint32_t)(*(self->pos + 1)) << 16 |
|
||||
(uint32_t)(*(self->pos + 2)) << 8 |
|
||||
(uint32_t)(*(self->pos + 3));
|
||||
self->pos += 4;
|
||||
break;
|
||||
default:
|
||||
CHECK_READ_BOUNDS(self, 8)
|
||||
value = (uint64_t)(*(self->pos) & 0x3F) << 56 |
|
||||
(uint64_t)(*(self->pos + 1)) << 48 |
|
||||
(uint64_t)(*(self->pos + 2)) << 40 |
|
||||
(uint64_t)(*(self->pos + 3)) << 32 |
|
||||
(uint64_t)(*(self->pos + 4)) << 24 |
|
||||
(uint64_t)(*(self->pos + 5)) << 16 |
|
||||
(uint64_t)(*(self->pos + 6)) << 8 |
|
||||
(uint64_t)(*(self->pos + 7));
|
||||
self->pos += 8;
|
||||
break;
|
||||
}
|
||||
return PyLong_FromUnsignedLongLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_bytes(BufferObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data;
|
||||
Py_ssize_t data_len;
|
||||
if (!PyArg_ParseTuple(args, "y#", &data, &data_len))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, data_len)
|
||||
|
||||
memcpy(self->pos, data, data_len);
|
||||
self->pos += data_len;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint8(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint8_t value;
|
||||
if (!PyArg_ParseTuple(args, "B", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 1)
|
||||
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint16(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint16_t value;
|
||||
if (!PyArg_ParseTuple(args, "H", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 2)
|
||||
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint32(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint32_t value;
|
||||
if (!PyArg_ParseTuple(args, "I", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 4)
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint64(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
if (!PyArg_ParseTuple(args, "K", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 8)
|
||||
*(self->pos++) = (value >> 56);
|
||||
*(self->pos++) = (value >> 48);
|
||||
*(self->pos++) = (value >> 40);
|
||||
*(self->pos++) = (value >> 32);
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint_var(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
if (!PyArg_ParseTuple(args, "K", &value))
|
||||
return NULL;
|
||||
|
||||
if (value <= 0x3F) {
|
||||
CHECK_WRITE_BOUNDS(self, 1)
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 2)
|
||||
*(self->pos++) = (value >> 8) | 0x40;
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFFFFFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 4)
|
||||
*(self->pos++) = (value >> 24) | 0x80;
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFFFFFFFFFFFFFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 8)
|
||||
*(self->pos++) = (value >> 56) | 0xC0;
|
||||
*(self->pos++) = (value >> 48);
|
||||
*(self->pos++) = (value >> 40);
|
||||
*(self->pos++) = (value >> 32);
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_seek(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t pos;
|
||||
if (!PyArg_ParseTuple(args, "n", &pos))
|
||||
return NULL;
|
||||
|
||||
if (pos < 0 || self->base + pos > self->end) {
|
||||
PyErr_SetString(BufferReadError, "Seek out of bounds");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
self->pos = self->base + pos;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_tell(BufferObject *self, PyObject *args)
|
||||
{
|
||||
return PyLong_FromSsize_t(self->pos - self->base);
|
||||
}
|
||||
|
||||
static PyMethodDef Buffer_methods[] = {
|
||||
{"data_slice", (PyCFunction)Buffer_data_slice, METH_VARARGS, ""},
|
||||
{"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""},
|
||||
{"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."},
|
||||
{"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."},
|
||||
{"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."},
|
||||
{"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."},
|
||||
{"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."},
|
||||
{"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."},
|
||||
{"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."},
|
||||
{"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."},
|
||||
{"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."},
|
||||
{"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."},
|
||||
{"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."},
|
||||
{"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."},
|
||||
{"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""},
|
||||
{"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""},
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyObject*
|
||||
Buffer_capacity_getter(BufferObject* self, void *closure) {
|
||||
return PyLong_FromSsize_t(self->end - self->base);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
Buffer_data_getter(BufferObject* self, void *closure) {
|
||||
return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base);
|
||||
}
|
||||
|
||||
static PyGetSetDef Buffer_getset[] = {
|
||||
{"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL },
|
||||
{"data", (getter) Buffer_data_getter, NULL, "", NULL },
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot BufferType_slots[] = {
|
||||
{Py_tp_dealloc, Buffer_dealloc},
|
||||
{Py_tp_methods, Buffer_methods},
|
||||
{Py_tp_doc, "Buffer objects"},
|
||||
{Py_tp_getset, Buffer_getset},
|
||||
{Py_tp_init, Buffer_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec BufferType_spec = {
|
||||
MODULE_NAME ".Buffer",
|
||||
sizeof(BufferObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
BufferType_slots
|
||||
};
|
||||
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME, /* m_name */
|
||||
"Serialization utilities.", /* m_doc */
|
||||
-1, /* m_size */
|
||||
NULL, /* m_methods */
|
||||
NULL, /* m_reload */
|
||||
NULL, /* m_traverse */
|
||||
NULL, /* m_clear */
|
||||
NULL, /* m_free */
|
||||
};
|
||||
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit__buffer(void)
|
||||
{
|
||||
PyObject* m;
|
||||
|
||||
m = PyModule_Create(&moduledef);
|
||||
if (m == NULL)
|
||||
return NULL;
|
||||
|
||||
BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(BufferReadError);
|
||||
PyModule_AddObject(m, "BufferReadError", BufferReadError);
|
||||
|
||||
BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(BufferWriteError);
|
||||
PyModule_AddObject(m, "BufferWriteError", BufferWriteError);
|
||||
|
||||
BufferType = PyType_FromSpec(&BufferType_spec);
|
||||
if (BufferType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "Buffer", BufferType);
|
||||
|
||||
return m;
|
||||
}
|
||||
27
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.pyi
Normal file
27
Code/venv/lib/python3.13/site-packages/aioquic/_buffer.pyi
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
class BufferReadError(ValueError): ...
|
||||
class BufferWriteError(ValueError): ...
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, capacity: Optional[int] = 0, data: Optional[bytes] = None): ...
|
||||
@property
|
||||
def capacity(self) -> int: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def data_slice(self, start: int, end: int) -> bytes: ...
|
||||
def eof(self) -> bool: ...
|
||||
def seek(self, pos: int) -> None: ...
|
||||
def tell(self) -> int: ...
|
||||
def pull_bytes(self, length: int) -> bytes: ...
|
||||
def pull_uint8(self) -> int: ...
|
||||
def pull_uint16(self) -> int: ...
|
||||
def pull_uint32(self) -> int: ...
|
||||
def pull_uint64(self) -> int: ...
|
||||
def pull_uint_var(self) -> int: ...
|
||||
def push_bytes(self, value: bytes) -> None: ...
|
||||
def push_uint8(self, value: int) -> None: ...
|
||||
def push_uint16(self, value: int) -> None: ...
|
||||
def push_uint32(self, v: int) -> None: ...
|
||||
def push_uint64(self, v: int) -> None: ...
|
||||
def push_uint_var(self, value: int) -> None: ...
|
||||
BIN
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.abi3.so
Executable file
BIN
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.abi3.so
Executable file
Binary file not shown.
416
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.c
Normal file
416
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.c
Normal file
@ -0,0 +1,416 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
|
||||
#include <Python.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
|
||||
#define MODULE_NAME "aioquic._crypto"
|
||||
|
||||
#define AEAD_KEY_LENGTH_MAX 32
|
||||
#define AEAD_NONCE_LENGTH 12
|
||||
#define AEAD_TAG_LENGTH 16
|
||||
|
||||
#define PACKET_LENGTH_MAX 1500
|
||||
#define PACKET_NUMBER_LENGTH_MAX 4
|
||||
#define SAMPLE_LENGTH 16
|
||||
|
||||
#define CHECK_RESULT(expr) \
|
||||
if (!(expr)) { \
|
||||
ERR_clear_error(); \
|
||||
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
#define CHECK_RESULT_CTOR(expr) \
|
||||
if (!(expr)) { \
|
||||
ERR_clear_error(); \
|
||||
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
|
||||
return -1; \
|
||||
}
|
||||
|
||||
static PyObject *CryptoError;
|
||||
|
||||
/* AEAD */
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
EVP_CIPHER_CTX *decrypt_ctx;
|
||||
EVP_CIPHER_CTX *encrypt_ctx;
|
||||
unsigned char buffer[PACKET_LENGTH_MAX];
|
||||
unsigned char key[AEAD_KEY_LENGTH_MAX];
|
||||
unsigned char iv[AEAD_NONCE_LENGTH];
|
||||
unsigned char nonce[AEAD_NONCE_LENGTH];
|
||||
} AEADObject;
|
||||
|
||||
static PyObject *AEADType;
|
||||
|
||||
static EVP_CIPHER_CTX *
|
||||
create_ctx(const EVP_CIPHER *cipher, int key_length, int operation)
|
||||
{
|
||||
EVP_CIPHER_CTX *ctx;
|
||||
int res;
|
||||
|
||||
ctx = EVP_CIPHER_CTX_new();
|
||||
CHECK_RESULT(ctx != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_set_key_length(ctx, key_length);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static int
|
||||
AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *cipher_name;
|
||||
const unsigned char *key, *iv;
|
||||
Py_ssize_t cipher_name_len, key_len, iv_len;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len))
|
||||
return -1;
|
||||
|
||||
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
|
||||
if (evp_cipher == 0) {
|
||||
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
|
||||
return -1;
|
||||
}
|
||||
if (key_len > AEAD_KEY_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid key length");
|
||||
return -1;
|
||||
}
|
||||
if (iv_len > AEAD_NONCE_LENGTH) {
|
||||
PyErr_SetString(CryptoError, "Invalid iv length");
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(self->key, key, key_len);
|
||||
memcpy(self->iv, iv, iv_len);
|
||||
|
||||
self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0);
|
||||
CHECK_RESULT_CTOR(self->decrypt_ctx != 0);
|
||||
|
||||
self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1);
|
||||
CHECK_RESULT_CTOR(self->encrypt_ctx != 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
AEAD_dealloc(AEADObject *self)
|
||||
{
|
||||
EVP_CIPHER_CTX_free(self->decrypt_ctx);
|
||||
EVP_CIPHER_CTX_free(self->encrypt_ctx);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
AEAD_decrypt(AEADObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data, *associated;
|
||||
Py_ssize_t data_len, associated_len;
|
||||
int outlen, outlen2, res;
|
||||
uint64_t pn;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
|
||||
return NULL;
|
||||
|
||||
if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid payload length");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
|
||||
}
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH)));
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2);
|
||||
if (res == 0) {
|
||||
PyErr_SetString(CryptoError, "Payload decryption failed");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
AEAD_encrypt(AEADObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data, *associated;
|
||||
Py_ssize_t data_len, associated_len;
|
||||
int outlen, outlen2, res;
|
||||
uint64_t pn;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
|
||||
return NULL;
|
||||
|
||||
if (data_len > PACKET_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid payload length");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
|
||||
}
|
||||
|
||||
res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2);
|
||||
CHECK_RESULT(res != 0 && outlen2 == 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH);
|
||||
}
|
||||
|
||||
static PyMethodDef AEAD_methods[] = {
|
||||
{"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""},
|
||||
{"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""},
|
||||
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot AEADType_slots[] = {
|
||||
{Py_tp_dealloc, AEAD_dealloc},
|
||||
{Py_tp_methods, AEAD_methods},
|
||||
{Py_tp_doc, "AEAD objects"},
|
||||
{Py_tp_init, AEAD_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec AEADType_spec = {
|
||||
MODULE_NAME ".AEADType",
|
||||
sizeof(AEADObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
AEADType_slots
|
||||
};
|
||||
|
||||
/* HeaderProtection */
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
EVP_CIPHER_CTX *ctx;
|
||||
int is_chacha20;
|
||||
unsigned char buffer[PACKET_LENGTH_MAX];
|
||||
unsigned char mask[31];
|
||||
unsigned char zero[5];
|
||||
} HeaderProtectionObject;
|
||||
|
||||
static PyObject *HeaderProtectionType;
|
||||
|
||||
static int
|
||||
HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *cipher_name;
|
||||
const unsigned char *key;
|
||||
Py_ssize_t cipher_name_len, key_len;
|
||||
int res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len))
|
||||
return -1;
|
||||
|
||||
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
|
||||
if (evp_cipher == 0) {
|
||||
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
|
||||
return -1;
|
||||
}
|
||||
|
||||
memset(self->mask, 0, sizeof(self->mask));
|
||||
memset(self->zero, 0, sizeof(self->zero));
|
||||
self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0;
|
||||
|
||||
self->ctx = EVP_CIPHER_CTX_new();
|
||||
CHECK_RESULT_CTOR(self->ctx != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
HeaderProtection_dealloc(HeaderProtectionObject *self)
|
||||
{
|
||||
EVP_CIPHER_CTX_free(self->ctx);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample)
|
||||
{
|
||||
int outlen;
|
||||
if (self->is_chacha20) {
|
||||
return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) &&
|
||||
EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero));
|
||||
} else {
|
||||
return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH);
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *header, *payload;
|
||||
Py_ssize_t header_len, payload_len;
|
||||
int res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len))
|
||||
return NULL;
|
||||
|
||||
int pn_length = (header[0] & 0x03) + 1;
|
||||
int pn_offset = header_len - pn_length;
|
||||
|
||||
res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
memcpy(self->buffer, header, header_len);
|
||||
memcpy(self->buffer + header_len, payload, payload_len);
|
||||
|
||||
if (self->buffer[0] & 0x80) {
|
||||
self->buffer[0] ^= self->mask[0] & 0x0F;
|
||||
} else {
|
||||
self->buffer[0] ^= self->mask[0] & 0x1F;
|
||||
}
|
||||
|
||||
for (int i = 0; i < pn_length; ++i) {
|
||||
self->buffer[pn_offset + i] ^= self->mask[1 + i];
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *packet;
|
||||
Py_ssize_t packet_len;
|
||||
int pn_offset, res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset))
|
||||
return NULL;
|
||||
|
||||
res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX);
|
||||
|
||||
if (self->buffer[0] & 0x80) {
|
||||
self->buffer[0] ^= self->mask[0] & 0x0F;
|
||||
} else {
|
||||
self->buffer[0] ^= self->mask[0] & 0x1F;
|
||||
}
|
||||
|
||||
int pn_length = (self->buffer[0] & 0x03) + 1;
|
||||
uint32_t pn_truncated = 0;
|
||||
for (int i = 0; i < pn_length; ++i) {
|
||||
self->buffer[pn_offset + i] ^= self->mask[1 + i];
|
||||
pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8);
|
||||
}
|
||||
|
||||
return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated);
|
||||
}
|
||||
|
||||
static PyMethodDef HeaderProtection_methods[] = {
|
||||
{"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""},
|
||||
{"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""},
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot HeaderProtectionType_slots[] = {
|
||||
{Py_tp_dealloc, HeaderProtection_dealloc},
|
||||
{Py_tp_methods, HeaderProtection_methods},
|
||||
{Py_tp_doc, "HeaderProtection objects"},
|
||||
{Py_tp_init, HeaderProtection_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec HeaderProtectionType_spec = {
|
||||
MODULE_NAME ".HeaderProtectionType",
|
||||
sizeof(HeaderProtectionObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
HeaderProtectionType_slots
|
||||
};
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME, /* m_name */
|
||||
"Cryptography utilities.", /* m_doc */
|
||||
-1, /* m_size */
|
||||
NULL, /* m_methods */
|
||||
NULL, /* m_reload */
|
||||
NULL, /* m_traverse */
|
||||
NULL, /* m_clear */
|
||||
NULL, /* m_free */
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit__crypto(void)
|
||||
{
|
||||
PyObject* m;
|
||||
|
||||
m = PyModule_Create(&moduledef);
|
||||
if (m == NULL)
|
||||
return NULL;
|
||||
|
||||
CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(CryptoError);
|
||||
PyModule_AddObject(m, "CryptoError", CryptoError);
|
||||
|
||||
AEADType = PyType_FromSpec(&AEADType_spec);
|
||||
if (AEADType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "AEAD", AEADType);
|
||||
|
||||
HeaderProtectionType = PyType_FromSpec(&HeaderProtectionType_spec);
|
||||
if (HeaderProtectionType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "HeaderProtection", HeaderProtectionType);
|
||||
|
||||
// ensure required ciphers are initialised
|
||||
EVP_add_cipher(EVP_aes_128_ecb());
|
||||
EVP_add_cipher(EVP_aes_128_gcm());
|
||||
EVP_add_cipher(EVP_aes_256_ecb());
|
||||
EVP_add_cipher(EVP_aes_256_gcm());
|
||||
|
||||
return m;
|
||||
}
|
||||
17
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.pyi
Normal file
17
Code/venv/lib/python3.13/site-packages/aioquic/_crypto.pyi
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import Tuple
|
||||
|
||||
class AEAD:
|
||||
def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ...
|
||||
def decrypt(
|
||||
self, data: bytes, associated_data: bytes, packet_number: int
|
||||
) -> bytes: ...
|
||||
def encrypt(
|
||||
self, data: bytes, associated_data: bytes, packet_number: int
|
||||
) -> bytes: ...
|
||||
|
||||
class CryptoError(ValueError): ...
|
||||
|
||||
class HeaderProtection:
|
||||
def __init__(self, cipher_name: bytes, key: bytes): ...
|
||||
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
|
||||
def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ...
|
||||
@ -0,0 +1,3 @@
|
||||
from .client import connect # noqa
|
||||
from .protocol import QuicConnectionProtocol # noqa
|
||||
from .server import serve # noqa
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Callable, Optional, cast
|
||||
|
||||
from ..quic.configuration import QuicConfiguration
|
||||
from ..quic.connection import QuicConnection, QuicTokenHandler
|
||||
from ..tls import SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["connect"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
configuration: Optional[QuicConfiguration] = None,
|
||||
create_protocol: Optional[Callable] = QuicConnectionProtocol,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
stream_handler: Optional[QuicStreamHandler] = None,
|
||||
token_handler: Optional[QuicTokenHandler] = None,
|
||||
wait_connected: bool = True,
|
||||
local_port: int = 0,
|
||||
) -> AsyncGenerator[QuicConnectionProtocol, None]:
|
||||
"""
|
||||
Connect to a QUIC server at the given `host` and `port`.
|
||||
|
||||
:meth:`connect()` returns an awaitable. Awaiting it yields a
|
||||
:class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
|
||||
create streams.
|
||||
|
||||
:func:`connect` also accepts the following optional arguments:
|
||||
|
||||
* ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
|
||||
configuration object.
|
||||
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
|
||||
manages the connection. It should be a callable or class accepting the same
|
||||
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
|
||||
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
|
||||
* ``session_ticket_handler`` is a callback which is invoked by the TLS
|
||||
engine when a new session ticket is received.
|
||||
* ``stream_handler`` is a callback which is invoked whenever a stream is
|
||||
created. It must accept two arguments: a :class:`asyncio.StreamReader`
|
||||
and a :class:`asyncio.StreamWriter`.
|
||||
* ``wait_connected`` indicates whether the context manager should wait for the
|
||||
connection to be established before yielding the
|
||||
:class:`~aioquic.asyncio.QuicConnectionProtocol`. By default this is `True` but
|
||||
you can set it to `False` if you want to immediately start sending data using
|
||||
0-RTT.
|
||||
* ``local_port`` is the UDP port number that this client wants to bind.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
local_host = "::"
|
||||
|
||||
# lookup remote address
|
||||
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
|
||||
addr = infos[0][4]
|
||||
if len(addr) == 2:
|
||||
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
|
||||
|
||||
# prepare QUIC connection
|
||||
if configuration is None:
|
||||
configuration = QuicConfiguration(is_client=True)
|
||||
if configuration.server_name is None:
|
||||
configuration.server_name = host
|
||||
connection = QuicConnection(
|
||||
configuration=configuration,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
token_handler=token_handler,
|
||||
)
|
||||
|
||||
# explicitly enable IPv4/IPv6 dual stack
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||
completed = False
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||
sock.bind((local_host, local_port, 0, 0))
|
||||
completed = True
|
||||
finally:
|
||||
if not completed:
|
||||
sock.close()
|
||||
# connect
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: create_protocol(connection, stream_handler=stream_handler),
|
||||
sock=sock,
|
||||
)
|
||||
protocol = cast(QuicConnectionProtocol, protocol)
|
||||
try:
|
||||
protocol.connect(addr, transmit=wait_connected)
|
||||
if wait_connected:
|
||||
await protocol.wait_connected()
|
||||
yield protocol
|
||||
finally:
|
||||
protocol.close()
|
||||
await protocol.wait_closed()
|
||||
transport.close()
|
||||
@ -0,0 +1,272 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
|
||||
|
||||
from ..quic import events
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import QuicErrorCode
|
||||
|
||||
QuicConnectionIdHandler = Callable[[bytes], None]
|
||||
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
|
||||
|
||||
|
||||
class QuicConnectionProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._closed = asyncio.Event()
|
||||
self._connected = False
|
||||
self._connected_waiter: Optional[asyncio.Future[None]] = None
|
||||
self._loop = loop
|
||||
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
|
||||
self._quic = quic
|
||||
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
|
||||
self._timer: Optional[asyncio.TimerHandle] = None
|
||||
self._timer_at: Optional[float] = None
|
||||
self._transmit_task: Optional[asyncio.Handle] = None
|
||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||
|
||||
# callbacks
|
||||
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_terminated_handler: Callable[[], None] = lambda: None
|
||||
if stream_handler is not None:
|
||||
self._stream_handler = stream_handler
|
||||
else:
|
||||
self._stream_handler = lambda r, w: None
|
||||
|
||||
def change_connection_id(self) -> None:
|
||||
"""
|
||||
Change the connection ID used to communicate with the peer.
|
||||
|
||||
The previous connection ID will be retired.
|
||||
"""
|
||||
self._quic.change_connection_id()
|
||||
self.transmit()
|
||||
|
||||
def close(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Close the connection.
|
||||
|
||||
:param error_code: An error code indicating why the connection is
|
||||
being closed.
|
||||
:param reason_phrase: A human-readable explanation of why the
|
||||
connection is being closed.
|
||||
"""
|
||||
self._quic.close(
|
||||
error_code=error_code,
|
||||
reason_phrase=reason_phrase,
|
||||
)
|
||||
self.transmit()
|
||||
|
||||
def connect(self, addr: NetworkAddress, transmit=True) -> None:
|
||||
"""
|
||||
Initiate the TLS handshake.
|
||||
|
||||
This method can only be called for clients and a single time.
|
||||
"""
|
||||
self._quic.connect(addr, now=self._loop.time())
|
||||
if transmit:
|
||||
self.transmit()
|
||||
|
||||
async def create_stream(
|
||||
self, is_unidirectional: bool = False
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
"""
|
||||
Create a QUIC stream and return a pair of (reader, writer) objects.
|
||||
|
||||
The returned reader and writer objects are instances of
|
||||
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
|
||||
"""
|
||||
stream_id = self._quic.get_next_available_stream_id(
|
||||
is_unidirectional=is_unidirectional
|
||||
)
|
||||
return self._create_stream(stream_id)
|
||||
|
||||
def request_key_update(self) -> None:
|
||||
"""
|
||||
Request an update of the encryption keys.
|
||||
"""
|
||||
self._quic.request_key_update()
|
||||
self.transmit()
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""
|
||||
Ping the peer and wait for the response.
|
||||
"""
|
||||
waiter = self._loop.create_future()
|
||||
uid = id(waiter)
|
||||
self._ping_waiters[uid] = waiter
|
||||
self._quic.send_ping(uid)
|
||||
self.transmit()
|
||||
await asyncio.shield(waiter)
|
||||
|
||||
def transmit(self) -> None:
|
||||
"""
|
||||
Send pending datagrams to the peer and arm the timer if needed.
|
||||
|
||||
This method is called automatically when data is received from the peer
|
||||
or when a timer goes off. If you interact directly with the underlying
|
||||
:class:`~aioquic.quic.connection.QuicConnection`, make sure you call this
|
||||
method whenever data needs to be sent out to the network.
|
||||
"""
|
||||
self._transmit_task = None
|
||||
|
||||
# send datagrams
|
||||
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
|
||||
self._transport.sendto(data, addr)
|
||||
|
||||
# re-arm timer
|
||||
timer_at = self._quic.get_timer()
|
||||
if self._timer is not None and self._timer_at != timer_at:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
if self._timer is None and timer_at is not None:
|
||||
self._timer = self._loop.call_at(timer_at, self._handle_timer)
|
||||
self._timer_at = timer_at
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""
|
||||
Wait for the connection to be closed.
|
||||
"""
|
||||
await self._closed.wait()
|
||||
|
||||
async def wait_connected(self) -> None:
|
||||
"""
|
||||
Wait for the TLS handshake to complete.
|
||||
"""
|
||||
assert self._connected_waiter is None, "already awaiting connected"
|
||||
if not self._connected:
|
||||
self._connected_waiter = self._loop.create_future()
|
||||
await asyncio.shield(self._connected_waiter)
|
||||
|
||||
# asyncio.Transport
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
""":meta private:"""
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
|
||||
""":meta private:"""
|
||||
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
# overridable
|
||||
|
||||
def quic_event_received(self, event: events.QuicEvent) -> None:
|
||||
"""
|
||||
Called when a QUIC event is received.
|
||||
|
||||
Reimplement this in your subclass to handle the events.
|
||||
"""
|
||||
# FIXME: move this to a subclass
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
for reader in self._stream_readers.values():
|
||||
reader.feed_eof()
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
reader = self._stream_readers.get(event.stream_id, None)
|
||||
if reader is None:
|
||||
reader, writer = self._create_stream(event.stream_id)
|
||||
self._stream_handler(reader, writer)
|
||||
reader.feed_data(event.data)
|
||||
if event.end_stream:
|
||||
reader.feed_eof()
|
||||
|
||||
# private
|
||||
|
||||
def _create_stream(
|
||||
self, stream_id: int
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
adapter = QuicStreamAdapter(self, stream_id)
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.streams.StreamReaderProtocol(reader)
|
||||
writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop)
|
||||
self._stream_readers[stream_id] = reader
|
||||
return reader, writer
|
||||
|
||||
def _handle_timer(self) -> None:
|
||||
now = max(self._timer_at, self._loop.time())
|
||||
self._timer = None
|
||||
self._timer_at = None
|
||||
self._quic.handle_timer(now=now)
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
def _process_events(self) -> None:
|
||||
event = self._quic.next_event()
|
||||
while event is not None:
|
||||
if isinstance(event, events.ConnectionIdIssued):
|
||||
self._connection_id_issued_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
self._connection_id_retired_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionTerminated):
|
||||
self._connection_terminated_handler()
|
||||
|
||||
# abort connection waiter
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected_waiter = None
|
||||
waiter.set_exception(ConnectionError)
|
||||
|
||||
# abort ping waiters
|
||||
for waiter in self._ping_waiters.values():
|
||||
waiter.set_exception(ConnectionError)
|
||||
self._ping_waiters.clear()
|
||||
|
||||
self._closed.set()
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected = True
|
||||
self._connected_waiter = None
|
||||
waiter.set_result(None)
|
||||
elif isinstance(event, events.PingAcknowledged):
|
||||
waiter = self._ping_waiters.pop(event.uid, None)
|
||||
if waiter is not None:
|
||||
waiter.set_result(None)
|
||||
self.quic_event_received(event)
|
||||
event = self._quic.next_event()
|
||||
|
||||
def _transmit_soon(self) -> None:
|
||||
if self._transmit_task is None:
|
||||
self._transmit_task = self._loop.call_soon(self.transmit)
|
||||
|
||||
|
||||
class QuicStreamAdapter(asyncio.Transport):
|
||||
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
|
||||
self.protocol = protocol
|
||||
self.stream_id = stream_id
|
||||
self._closing = False
|
||||
|
||||
def can_write_eof(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_extra_info(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get information about the underlying QUIC stream.
|
||||
"""
|
||||
if name == "stream_id":
|
||||
return self.stream_id
|
||||
|
||||
def write(self, data):
|
||||
self.protocol._quic.send_stream_data(self.stream_id, data)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def write_eof(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def close(self):
|
||||
self.write_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
215
Code/venv/lib/python3.13/site-packages/aioquic/asyncio/server.py
Normal file
215
Code/venv/lib/python3.13/site-packages/aioquic/asyncio/server.py
Normal file
@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional, Text, Union, cast
|
||||
|
||||
from ..buffer import Buffer
|
||||
from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import (
|
||||
QuicPacketType,
|
||||
encode_quic_retry,
|
||||
encode_quic_version_negotiation,
|
||||
pull_quic_header,
|
||||
)
|
||||
from ..quic.retry import QuicRetryTokenHandler
|
||||
from ..tls import SessionTicketFetcher, SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["serve"]
|
||||
|
||||
|
||||
class QuicServer(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
configuration: QuicConfiguration,
|
||||
create_protocol: Callable = QuicConnectionProtocol,
|
||||
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
retry: bool = False,
|
||||
stream_handler: Optional[QuicStreamHandler] = None,
|
||||
) -> None:
|
||||
self._configuration = configuration
|
||||
self._create_protocol = create_protocol
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
|
||||
self._session_ticket_fetcher = session_ticket_fetcher
|
||||
self._session_ticket_handler = session_ticket_handler
|
||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
if retry:
|
||||
self._retry = QuicRetryTokenHandler()
|
||||
else:
|
||||
self._retry = None
|
||||
|
||||
def close(self):
|
||||
for protocol in set(self._protocols.values()):
|
||||
protocol.close()
|
||||
self._protocols.clear()
|
||||
self._transport.close()
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
|
||||
data = cast(bytes, data)
|
||||
buf = Buffer(data=data)
|
||||
|
||||
try:
|
||||
header = pull_quic_header(
|
||||
buf, host_cid_length=self._configuration.connection_id_length
|
||||
)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
# version negotiation
|
||||
if (
|
||||
header.version is not None
|
||||
and header.version not in self._configuration.supported_versions
|
||||
):
|
||||
self._transport.sendto(
|
||||
encode_quic_version_negotiation(
|
||||
source_cid=header.destination_cid,
|
||||
destination_cid=header.source_cid,
|
||||
supported_versions=self._configuration.supported_versions,
|
||||
),
|
||||
addr,
|
||||
)
|
||||
return
|
||||
|
||||
protocol = self._protocols.get(header.destination_cid, None)
|
||||
original_destination_connection_id: Optional[bytes] = None
|
||||
retry_source_connection_id: Optional[bytes] = None
|
||||
if (
|
||||
protocol is None
|
||||
and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE
|
||||
and header.packet_type == QuicPacketType.INITIAL
|
||||
):
|
||||
# retry
|
||||
if self._retry is not None:
|
||||
if not header.token:
|
||||
# create a retry token
|
||||
source_cid = os.urandom(8)
|
||||
self._transport.sendto(
|
||||
encode_quic_retry(
|
||||
version=header.version,
|
||||
source_cid=source_cid,
|
||||
destination_cid=header.source_cid,
|
||||
original_destination_cid=header.destination_cid,
|
||||
retry_token=self._retry.create_token(
|
||||
addr, header.destination_cid, source_cid
|
||||
),
|
||||
),
|
||||
addr,
|
||||
)
|
||||
return
|
||||
else:
|
||||
# validate retry token
|
||||
try:
|
||||
(
|
||||
original_destination_connection_id,
|
||||
retry_source_connection_id,
|
||||
) = self._retry.validate_token(addr, header.token)
|
||||
except ValueError:
|
||||
return
|
||||
else:
|
||||
original_destination_connection_id = header.destination_cid
|
||||
|
||||
# create new connection
|
||||
connection = QuicConnection(
|
||||
configuration=self._configuration,
|
||||
original_destination_connection_id=original_destination_connection_id,
|
||||
retry_source_connection_id=retry_source_connection_id,
|
||||
session_ticket_fetcher=self._session_ticket_fetcher,
|
||||
session_ticket_handler=self._session_ticket_handler,
|
||||
)
|
||||
protocol = self._create_protocol(
|
||||
connection, stream_handler=self._stream_handler
|
||||
)
|
||||
protocol.connection_made(self._transport)
|
||||
|
||||
# register callbacks
|
||||
protocol._connection_id_issued_handler = partial(
|
||||
self._connection_id_issued, protocol=protocol
|
||||
)
|
||||
protocol._connection_id_retired_handler = partial(
|
||||
self._connection_id_retired, protocol=protocol
|
||||
)
|
||||
protocol._connection_terminated_handler = partial(
|
||||
self._connection_terminated, protocol=protocol
|
||||
)
|
||||
|
||||
self._protocols[header.destination_cid] = protocol
|
||||
self._protocols[connection.host_cid] = protocol
|
||||
|
||||
if protocol is not None:
|
||||
protocol.datagram_received(data, addr)
|
||||
|
||||
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
|
||||
self._protocols[cid] = protocol
|
||||
|
||||
def _connection_id_retired(
|
||||
self, cid: bytes, protocol: QuicConnectionProtocol
|
||||
) -> None:
|
||||
assert self._protocols[cid] == protocol
|
||||
del self._protocols[cid]
|
||||
|
||||
def _connection_terminated(self, protocol: QuicConnectionProtocol):
|
||||
for cid, proto in list(self._protocols.items()):
|
||||
if proto == protocol:
|
||||
del self._protocols[cid]
|
||||
|
||||
|
||||
async def serve(
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
configuration: QuicConfiguration,
|
||||
create_protocol: Callable = QuicConnectionProtocol,
|
||||
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
retry: bool = False,
|
||||
stream_handler: QuicStreamHandler = None,
|
||||
) -> QuicServer:
|
||||
"""
|
||||
Start a QUIC server at the given `host` and `port`.
|
||||
|
||||
:func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
|
||||
containing TLS certificate and private key as the ``configuration`` argument.
|
||||
|
||||
:func:`serve` also accepts the following optional arguments:
|
||||
|
||||
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
|
||||
manages the connection. It should be a callable or class accepting the same
|
||||
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
|
||||
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
|
||||
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
|
||||
engine when a session ticket is presented by the peer. It should return
|
||||
the session ticket with the specified ID or `None` if it is not found.
|
||||
* ``session_ticket_handler`` is a callback which is invoked by the TLS
|
||||
engine when a new session ticket is issued. It should store the session
|
||||
ticket for future lookup.
|
||||
* ``retry`` specifies whether client addresses should be validated prior to
|
||||
the cryptographic handshake using a retry packet.
|
||||
* ``stream_handler`` is a callback which is invoked whenever a stream is
|
||||
created. It must accept two arguments: a :class:`asyncio.StreamReader`
|
||||
and a :class:`asyncio.StreamWriter`.
|
||||
"""
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
_, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: QuicServer(
|
||||
configuration=configuration,
|
||||
create_protocol=create_protocol,
|
||||
session_ticket_fetcher=session_ticket_fetcher,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
retry=retry,
|
||||
stream_handler=stream_handler,
|
||||
),
|
||||
local_addr=(host, port),
|
||||
)
|
||||
return protocol
|
||||
30
Code/venv/lib/python3.13/site-packages/aioquic/buffer.py
Normal file
30
Code/venv/lib/python3.13/site-packages/aioquic/buffer.py
Normal file
@ -0,0 +1,30 @@
|
||||
from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa
|
||||
|
||||
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
|
||||
UINT_VAR_MAX_SIZE = 8
|
||||
|
||||
|
||||
def encode_uint_var(value: int) -> bytes:
|
||||
"""
|
||||
Encode a variable-length unsigned integer.
|
||||
"""
|
||||
buf = Buffer(capacity=UINT_VAR_MAX_SIZE)
|
||||
buf.push_uint_var(value)
|
||||
return buf.data
|
||||
|
||||
|
||||
def size_uint_var(value: int) -> int:
|
||||
"""
|
||||
Return the number of bytes required to encode the given value
|
||||
as a QUIC variable-length unsigned integer.
|
||||
"""
|
||||
if value <= 0x3F:
|
||||
return 1
|
||||
elif value <= 0x3FFF:
|
||||
return 2
|
||||
elif value <= 0x3FFFFFFF:
|
||||
return 4
|
||||
elif value <= 0x3FFFFFFFFFFFFFFF:
|
||||
return 8
|
||||
else:
|
||||
raise ValueError("Integer is too big for a variable-length integer")
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,68 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import QuicEvent, StreamDataReceived
|
||||
|
||||
H0_ALPN = ["hq-interop"]
|
||||
|
||||
|
||||
class H0Connection:
|
||||
"""
|
||||
An HTTP/0.9 connection object.
|
||||
"""
|
||||
|
||||
def __init__(self, quic: QuicConnection):
|
||||
self._buffer: Dict[int, bytes] = {}
|
||||
self._headers_received: Dict[int, bool] = {}
|
||||
self._is_client = quic.configuration.is_client
|
||||
self._quic = quic
|
||||
|
||||
def handle_event(self, event: QuicEvent) -> List[H3Event]:
|
||||
http_events: List[H3Event] = []
|
||||
|
||||
if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0:
|
||||
data = self._buffer.pop(event.stream_id, b"") + event.data
|
||||
if not self._headers_received.get(event.stream_id, False):
|
||||
if self._is_client:
|
||||
http_events.append(
|
||||
HeadersReceived(
|
||||
headers=[], stream_ended=False, stream_id=event.stream_id
|
||||
)
|
||||
)
|
||||
elif data.endswith(b"\r\n") or event.end_stream:
|
||||
method, path = data.rstrip().split(b" ", 1)
|
||||
http_events.append(
|
||||
HeadersReceived(
|
||||
headers=[(b":method", method), (b":path", path)],
|
||||
stream_ended=False,
|
||||
stream_id=event.stream_id,
|
||||
)
|
||||
)
|
||||
data = b""
|
||||
else:
|
||||
# incomplete request, stash the data
|
||||
self._buffer[event.stream_id] = data
|
||||
return http_events
|
||||
self._headers_received[event.stream_id] = True
|
||||
|
||||
http_events.append(
|
||||
DataReceived(
|
||||
data=data, stream_ended=event.end_stream, stream_id=event.stream_id
|
||||
)
|
||||
)
|
||||
|
||||
return http_events
|
||||
|
||||
def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None:
|
||||
self._quic.send_stream_data(stream_id, data, end_stream)
|
||||
|
||||
def send_headers(
|
||||
self, stream_id: int, headers: Headers, end_stream: bool = False
|
||||
) -> None:
|
||||
if self._is_client:
|
||||
headers_dict = dict(headers)
|
||||
data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n"
|
||||
else:
|
||||
data = b""
|
||||
self._quic.send_stream_data(stream_id, data, end_stream)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1218
Code/venv/lib/python3.13/site-packages/aioquic/h3/connection.py
Normal file
1218
Code/venv/lib/python3.13/site-packages/aioquic/h3/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
100
Code/venv/lib/python3.13/site-packages/aioquic/h3/events.py
Normal file
100
Code/venv/lib/python3.13/site-packages/aioquic/h3/events.py
Normal file
@ -0,0 +1,100 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
Headers = List[Tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class H3Event:
|
||||
"""
|
||||
Base class for HTTP/3 events.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataReceived(H3Event):
|
||||
"""
|
||||
The DataReceived event is fired whenever data is received on a stream from
|
||||
the remote peer.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatagramReceived(H3Event):
|
||||
"""
|
||||
The DatagramReceived is fired whenever a datagram is received from the
|
||||
the remote peer.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeadersReceived(H3Event):
|
||||
"""
|
||||
The HeadersReceived event is fired whenever headers are received.
|
||||
"""
|
||||
|
||||
headers: Headers
|
||||
"The headers."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the headers were received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushPromiseReceived(H3Event):
|
||||
"""
|
||||
The PushedStreamReceived event is fired whenever a pushed stream has been
|
||||
received from the remote peer.
|
||||
"""
|
||||
|
||||
headers: Headers
|
||||
"The request headers."
|
||||
|
||||
push_id: int
|
||||
"The Push ID of the push promise."
|
||||
|
||||
stream_id: int
|
||||
"The Stream ID of the stream that the push is related to."
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebTransportStreamDataReceived(H3Event):
|
||||
"""
|
||||
The WebTransportStreamDataReceived is fired whenever data is received
|
||||
for a WebTransport stream.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
session_id: int
|
||||
"The ID of the session the data was received for."
|
||||
@ -0,0 +1,17 @@
|
||||
class H3Error(Exception):
|
||||
"""
|
||||
Base class for HTTP/3 exceptions.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStreamTypeError(H3Error):
|
||||
"""
|
||||
An action was attempted on an invalid stream type.
|
||||
"""
|
||||
|
||||
|
||||
class NoAvailablePushIDError(H3Error):
|
||||
"""
|
||||
There are no available push IDs left, or push is not supported
|
||||
by the remote party.
|
||||
"""
|
||||
1
Code/venv/lib/python3.13/site-packages/aioquic/py.typed
Normal file
1
Code/venv/lib/python3.13/site-packages/aioquic/py.typed
Normal file
@ -0,0 +1 @@
|
||||
Marker
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,163 @@
|
||||
from dataclasses import dataclass, field
|
||||
from os import PathLike
|
||||
from re import split
|
||||
from typing import Any, List, Optional, TextIO, Union
|
||||
|
||||
from ..tls import (
|
||||
CipherSuite,
|
||||
SessionTicket,
|
||||
load_pem_private_key,
|
||||
load_pem_x509_certificates,
|
||||
)
|
||||
from .logger import QuicLogger
|
||||
from .packet import QuicProtocolVersion
|
||||
|
||||
SMALLEST_MAX_DATAGRAM_SIZE = 1200
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicConfiguration:
|
||||
"""
|
||||
A QUIC configuration.
|
||||
"""
|
||||
|
||||
alpn_protocols: Optional[List[str]] = None
|
||||
"""
|
||||
A list of supported ALPN protocols.
|
||||
"""
|
||||
|
||||
congestion_control_algorithm: str = "reno"
|
||||
"""
|
||||
The name of the congestion control algorithm to use.
|
||||
|
||||
Currently supported algorithms: `"reno", `"cubic"`.
|
||||
"""
|
||||
|
||||
connection_id_length: int = 8
|
||||
"""
|
||||
The length in bytes of local connection IDs.
|
||||
"""
|
||||
|
||||
idle_timeout: float = 60.0
|
||||
"""
|
||||
The idle timeout in seconds.
|
||||
|
||||
The connection is terminated if nothing is received for the given duration.
|
||||
"""
|
||||
|
||||
is_client: bool = True
|
||||
"""
|
||||
Whether this is the client side of the QUIC connection.
|
||||
"""
|
||||
|
||||
max_data: int = 1048576
|
||||
"""
|
||||
Connection-wide flow control limit.
|
||||
"""
|
||||
|
||||
max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE
|
||||
"""
|
||||
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
|
||||
"""
|
||||
|
||||
max_stream_data: int = 1048576
|
||||
"""
|
||||
Per-stream flow control limit.
|
||||
"""
|
||||
|
||||
quic_logger: Optional[QuicLogger] = None
|
||||
"""
|
||||
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
|
||||
"""
|
||||
|
||||
secrets_log_file: TextIO = None
|
||||
"""
|
||||
A file-like object in which to log traffic secrets.
|
||||
|
||||
This is useful to analyze traffic captures with Wireshark.
|
||||
"""
|
||||
|
||||
server_name: Optional[str] = None
|
||||
"""
|
||||
The server name to use when verifying the server's TLS certificate, which
|
||||
can either be a DNS name or an IP address.
|
||||
|
||||
If it is a DNS name, it is also sent during the TLS handshake in the
|
||||
Server Name Indication (SNI) extension.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
session_ticket: Optional[SessionTicket] = None
|
||||
"""
|
||||
The TLS session ticket which should be used for session resumption.
|
||||
"""
|
||||
|
||||
token: bytes = b""
|
||||
"""
|
||||
The address validation token that can be used to validate future connections.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
# For internal purposes, not guaranteed to be stable.
|
||||
cadata: Optional[bytes] = None
|
||||
cafile: Optional[str] = None
|
||||
capath: Optional[str] = None
|
||||
certificate: Any = None
|
||||
certificate_chain: List[Any] = field(default_factory=list)
|
||||
cipher_suites: Optional[List[CipherSuite]] = None
|
||||
initial_rtt: float = 0.1
|
||||
max_datagram_frame_size: Optional[int] = None
|
||||
original_version: Optional[int] = None
|
||||
private_key: Any = None
|
||||
quantum_readiness_test: bool = False
|
||||
supported_versions: List[int] = field(
|
||||
default_factory=lambda: [
|
||||
QuicProtocolVersion.VERSION_1,
|
||||
QuicProtocolVersion.VERSION_2,
|
||||
]
|
||||
)
|
||||
verify_mode: Optional[int] = None
|
||||
|
||||
def load_cert_chain(
|
||||
self,
|
||||
certfile: PathLike,
|
||||
keyfile: Optional[PathLike] = None,
|
||||
password: Optional[Union[bytes, str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load a private key and the corresponding certificate.
|
||||
"""
|
||||
with open(certfile, "rb") as fp:
|
||||
boundary = b"-----BEGIN PRIVATE KEY-----\n"
|
||||
chunks = split(b"\n" + boundary, fp.read())
|
||||
certificates = load_pem_x509_certificates(chunks[0])
|
||||
if len(chunks) == 2:
|
||||
private_key = boundary + chunks[1]
|
||||
self.private_key = load_pem_private_key(private_key)
|
||||
self.certificate = certificates[0]
|
||||
self.certificate_chain = certificates[1:]
|
||||
|
||||
if keyfile is not None:
|
||||
with open(keyfile, "rb") as fp:
|
||||
self.private_key = load_pem_private_key(
|
||||
fp.read(),
|
||||
password=password.encode("utf8")
|
||||
if isinstance(password, str)
|
||||
else password,
|
||||
)
|
||||
|
||||
def load_verify_locations(
|
||||
self,
|
||||
cafile: Optional[str] = None,
|
||||
capath: Optional[str] = None,
|
||||
cadata: Optional[bytes] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load a set of "certification authority" (CA) certificates used to
|
||||
validate other peers' certificates.
|
||||
"""
|
||||
self.cafile = cafile
|
||||
self.capath = capath
|
||||
self.cadata = cadata
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,128 @@
|
||||
import abc
|
||||
from typing import Any, Dict, Iterable, Optional, Protocol
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
|
||||
K_GRANULARITY = 0.001 # seconds
|
||||
K_INITIAL_WINDOW = 10
|
||||
K_MINIMUM_WINDOW = 2
|
||||
|
||||
|
||||
class QuicCongestionControl(abc.ABC):
|
||||
"""
|
||||
Base class for congestion control implementations.
|
||||
"""
|
||||
|
||||
bytes_in_flight: int = 0
|
||||
congestion_window: int = 0
|
||||
ssthresh: Optional[int] = None
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_packets_lost(
|
||||
self, *, now: float, packets: Iterable[QuicSentPacket]
|
||||
) -> None: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ...
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
|
||||
if self.ssthresh is not None:
|
||||
data["ssthresh"] = self.ssthresh
|
||||
return data
|
||||
|
||||
|
||||
class QuicCongestionControlFactory(Protocol):
|
||||
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
|
||||
|
||||
|
||||
class QuicRttMonitor:
|
||||
"""
|
||||
Roundtrip time monitor for HyStart.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._increases = 0
|
||||
self._last_time = None
|
||||
self._ready = False
|
||||
self._size = 5
|
||||
|
||||
self._filtered_min: Optional[float] = None
|
||||
|
||||
self._sample_idx = 0
|
||||
self._sample_max: Optional[float] = None
|
||||
self._sample_min: Optional[float] = None
|
||||
self._sample_time = 0.0
|
||||
self._samples = [0.0 for i in range(self._size)]
|
||||
|
||||
def add_rtt(self, *, rtt: float) -> None:
|
||||
self._samples[self._sample_idx] = rtt
|
||||
self._sample_idx += 1
|
||||
|
||||
if self._sample_idx >= self._size:
|
||||
self._sample_idx = 0
|
||||
self._ready = True
|
||||
|
||||
if self._ready:
|
||||
self._sample_max = self._samples[0]
|
||||
self._sample_min = self._samples[0]
|
||||
for sample in self._samples[1:]:
|
||||
if sample < self._sample_min:
|
||||
self._sample_min = sample
|
||||
elif sample > self._sample_max:
|
||||
self._sample_max = sample
|
||||
|
||||
def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
|
||||
if now > self._sample_time + K_GRANULARITY:
|
||||
self.add_rtt(rtt=rtt)
|
||||
self._sample_time = now
|
||||
|
||||
if self._ready:
|
||||
if self._filtered_min is None or self._filtered_min > self._sample_max:
|
||||
self._filtered_min = self._sample_max
|
||||
|
||||
delta = self._sample_min - self._filtered_min
|
||||
if delta * 4 >= self._filtered_min:
|
||||
self._increases += 1
|
||||
if self._increases >= self._size:
|
||||
return True
|
||||
elif delta > 0:
|
||||
self._increases = 0
|
||||
return False
|
||||
|
||||
|
||||
_factories: Dict[str, QuicCongestionControlFactory] = {}
|
||||
|
||||
|
||||
def create_congestion_control(
|
||||
name: str, *, max_datagram_size: int
|
||||
) -> QuicCongestionControl:
|
||||
"""
|
||||
Create an instance of the `name` congestion control algorithm.
|
||||
"""
|
||||
try:
|
||||
factory = _factories[name]
|
||||
except KeyError:
|
||||
raise Exception(f"Unknown congestion control algorithm: {name}")
|
||||
return factory(max_datagram_size=max_datagram_size)
|
||||
|
||||
|
||||
def register_congestion_control(
|
||||
name: str, factory: QuicCongestionControlFactory
|
||||
) -> None:
|
||||
"""
|
||||
Register a congestion control algorithm named `name`.
|
||||
"""
|
||||
_factories[name] = factory
|
||||
@ -0,0 +1,212 @@
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_INITIAL_WINDOW,
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
|
||||
K_CUBIC_C = 0.4
|
||||
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
|
||||
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
|
||||
|
||||
|
||||
def better_cube_root(x: float) -> float:
|
||||
if x < 0:
|
||||
# avoid precision errors that make the cube root returns an imaginary number
|
||||
return -((-x) ** (1.0 / 3.0))
|
||||
else:
|
||||
return (x) ** (1.0 / 3.0)
|
||||
|
||||
|
||||
class CubicCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
Cubic congestion control implementation for aioquic
|
||||
"""
|
||||
|
||||
def __init__(self, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
# increase by one segment
|
||||
self.additive_increase_factor: int = max_datagram_size
|
||||
self._max_datagram_size: int = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
self.rtt = 0.02 # starting RTT is considered to be 20ms
|
||||
|
||||
self.reset()
|
||||
|
||||
self.last_ack = 0.0
|
||||
|
||||
def W_cubic(self, t) -> int:
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
|
||||
return int(target_segments * self._max_datagram_size)
|
||||
|
||||
def is_reno_friendly(self, t) -> bool:
|
||||
return self.W_cubic(t) < self._W_est
|
||||
|
||||
def is_concave(self) -> bool:
|
||||
return self.congestion_window < self._W_max
|
||||
|
||||
def reset(self) -> None:
|
||||
self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
|
||||
self.ssthresh = None
|
||||
|
||||
self._first_slow_start = True
|
||||
self._starting_congestion_avoidance = False
|
||||
self.K: float = 0.0
|
||||
self._W_est = 0
|
||||
self._cwnd_epoch = 0
|
||||
self._t_epoch = 0.0
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
self.last_ack = packet.sent_time
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
if self._first_slow_start and not self._starting_congestion_avoidance:
|
||||
# exiting slow start without having a loss
|
||||
self._first_slow_start = False
|
||||
self._W_max = self.congestion_window
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
# initialize the variables used at start of congestion avoidance
|
||||
if self._starting_congestion_avoidance:
|
||||
self._starting_congestion_avoidance = False
|
||||
self._first_slow_start = False
|
||||
self._t_epoch = now
|
||||
self._cwnd_epoch = self.congestion_window
|
||||
self._W_est = self._cwnd_epoch
|
||||
# calculate K
|
||||
W_max_segments = self._W_max / self._max_datagram_size
|
||||
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
|
||||
self.K = better_cube_root(
|
||||
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
|
||||
)
|
||||
|
||||
self._W_est = int(
|
||||
self._W_est
|
||||
+ self.additive_increase_factor
|
||||
* (packet.sent_bytes / self.congestion_window)
|
||||
)
|
||||
|
||||
t = now - self._t_epoch
|
||||
|
||||
target: int = 0
|
||||
W_cubic = self.W_cubic(t + self.rtt)
|
||||
if W_cubic < self.congestion_window:
|
||||
target = self.congestion_window
|
||||
elif W_cubic > 1.5 * self.congestion_window:
|
||||
target = int(self.congestion_window * 1.5)
|
||||
else:
|
||||
target = W_cubic
|
||||
|
||||
if self.is_reno_friendly(t):
|
||||
# reno friendly region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
|
||||
self.congestion_window = self._W_est
|
||||
elif self.is_concave():
|
||||
# concave region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# convex region of cubic
|
||||
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
|
||||
self.congestion_window = int(
|
||||
self.congestion_window
|
||||
+ (
|
||||
(target - self.congestion_window)
|
||||
* (self._max_datagram_size / self.congestion_window)
|
||||
)
|
||||
)
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
if self.last_ack == 0.0:
|
||||
return
|
||||
elapsed_idle = packet.sent_time - self.last_ack
|
||||
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
|
||||
self.reset()
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
|
||||
# Normal congestion handle, can't be used in same time as fast convergence
|
||||
# self._W_max = self.congestion_window
|
||||
|
||||
# fast convergence
|
||||
if self._W_max is not None and self.congestion_window < self._W_max:
|
||||
self._W_max = int(
|
||||
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
|
||||
)
|
||||
else:
|
||||
self._W_max = self.congestion_window
|
||||
|
||||
# normal congestion MD
|
||||
flight_size = self.bytes_in_flight
|
||||
new_ssthresh = max(
|
||||
int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = new_ssthresh
|
||||
self.congestion_window = max(
|
||||
self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
|
||||
)
|
||||
|
||||
# restart a new congestion avoidance phase
|
||||
self._starting_congestion_avoidance = True
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
self.rtt = rtt
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
rtt=rtt, now=now
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
def get_log_data(self) -> Dict[str, Any]:
|
||||
data = super().get_log_data()
|
||||
|
||||
data["cubic-wmax"] = int(self._W_max)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
register_congestion_control("cubic", CubicCongestionControl)
|
||||
@ -0,0 +1,77 @@
|
||||
from typing import Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
K_LOSS_REDUCTION_FACTOR = 0.5
|
||||
|
||||
|
||||
class RenoCongestionControl(QuicCongestionControl):
|
||||
"""
|
||||
New Reno congestion control.
|
||||
"""
|
||||
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
super().__init__(max_datagram_size=max_datagram_size)
|
||||
self._max_datagram_size = max_datagram_size
|
||||
self._congestion_recovery_start_time = 0.0
|
||||
self._congestion_stash = 0
|
||||
self._rtt_monitor = QuicRttMonitor()
|
||||
|
||||
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
# don't increase window in congestion recovery
|
||||
if packet.sent_time <= self._congestion_recovery_start_time:
|
||||
return
|
||||
|
||||
if self.ssthresh is None or self.congestion_window < self.ssthresh:
|
||||
# slow start
|
||||
self.congestion_window += packet.sent_bytes
|
||||
else:
|
||||
# congestion avoidance
|
||||
self._congestion_stash += packet.sent_bytes
|
||||
count = self._congestion_stash // self.congestion_window
|
||||
if count:
|
||||
self._congestion_stash -= count * self.congestion_window
|
||||
self.congestion_window += count * self._max_datagram_size
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
|
||||
self.bytes_in_flight += packet.sent_bytes
|
||||
|
||||
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
|
||||
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
|
||||
lost_largest_time = 0.0
|
||||
for packet in packets:
|
||||
self.bytes_in_flight -= packet.sent_bytes
|
||||
lost_largest_time = packet.sent_time
|
||||
|
||||
# start a new congestion event if packet was sent after the
|
||||
# start of the previous congestion recovery period.
|
||||
if lost_largest_time > self._congestion_recovery_start_time:
|
||||
self._congestion_recovery_start_time = now
|
||||
self.congestion_window = max(
|
||||
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
|
||||
K_MINIMUM_WINDOW * self._max_datagram_size,
|
||||
)
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
# TODO : collapse congestion window if persistent congestion
|
||||
|
||||
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
|
||||
# check whether we should exit slow start
|
||||
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
|
||||
now=now, rtt=rtt
|
||||
):
|
||||
self.ssthresh = self.congestion_window
|
||||
|
||||
|
||||
register_congestion_control("reno", RenoCongestionControl)
|
||||
3623
Code/venv/lib/python3.13/site-packages/aioquic/quic/connection.py
Normal file
3623
Code/venv/lib/python3.13/site-packages/aioquic/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
246
Code/venv/lib/python3.13/site-packages/aioquic/quic/crypto.py
Normal file
246
Code/venv/lib/python3.13/site-packages/aioquic/quic/crypto.py
Normal file
@ -0,0 +1,246 @@
|
||||
import binascii
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from .._crypto import AEAD, CryptoError, HeaderProtection
|
||||
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
|
||||
from .packet import (
|
||||
QuicProtocolVersion,
|
||||
decode_packet_number,
|
||||
is_long_header,
|
||||
)
|
||||
|
||||
CIPHER_SUITES = {
|
||||
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
|
||||
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
|
||||
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
|
||||
}
|
||||
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
|
||||
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
|
||||
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
|
||||
SAMPLE_SIZE = 16
|
||||
|
||||
|
||||
Callback = Callable[[str], None]
|
||||
|
||||
|
||||
def NoCallback(trigger: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class KeyUnavailableError(CryptoError):
|
||||
pass
|
||||
|
||||
|
||||
def derive_key_iv_hp(
|
||||
*, cipher_suite: CipherSuite, secret: bytes, version: int
|
||||
) -> Tuple[bytes, bytes, bytes]:
|
||||
algorithm = cipher_suite_hash(cipher_suite)
|
||||
if cipher_suite in [
|
||||
CipherSuite.AES_256_GCM_SHA384,
|
||||
CipherSuite.CHACHA20_POLY1305_SHA256,
|
||||
]:
|
||||
key_size = 32
|
||||
else:
|
||||
key_size = 16
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
return (
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size),
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12),
|
||||
hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
|
||||
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
|
||||
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
|
||||
)
|
||||
|
||||
|
||||
class CryptoContext:
|
||||
def __init__(
|
||||
self,
|
||||
key_phase: int = 0,
|
||||
setup_cb: Callback = NoCallback,
|
||||
teardown_cb: Callback = NoCallback,
|
||||
) -> None:
|
||||
self.aead: Optional[AEAD] = None
|
||||
self.cipher_suite: Optional[CipherSuite] = None
|
||||
self.hp: Optional[HeaderProtection] = None
|
||||
self.key_phase = key_phase
|
||||
self.secret: Optional[bytes] = None
|
||||
self.version: Optional[int] = None
|
||||
self._setup_cb = setup_cb
|
||||
self._teardown_cb = teardown_cb
|
||||
|
||||
def decrypt_packet(
|
||||
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
|
||||
) -> Tuple[bytes, bytes, int, bool]:
|
||||
if self.aead is None:
|
||||
raise KeyUnavailableError("Decryption key is not available")
|
||||
|
||||
# header protection
|
||||
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
|
||||
first_byte = plain_header[0]
|
||||
|
||||
# packet number
|
||||
pn_length = (first_byte & 0x03) + 1
|
||||
packet_number = decode_packet_number(
|
||||
packet_number, pn_length * 8, expected_packet_number
|
||||
)
|
||||
|
||||
# detect key phase change
|
||||
crypto = self
|
||||
if not is_long_header(first_byte):
|
||||
key_phase = (first_byte & 4) >> 2
|
||||
if key_phase != self.key_phase:
|
||||
crypto = next_key_phase(self)
|
||||
|
||||
# payload protection
|
||||
payload = crypto.aead.decrypt(
|
||||
packet[len(plain_header) :], plain_header, packet_number
|
||||
)
|
||||
|
||||
return plain_header, payload, packet_number, crypto != self
|
||||
|
||||
def encrypt_packet(
|
||||
self, plain_header: bytes, plain_payload: bytes, packet_number: int
|
||||
) -> bytes:
|
||||
assert self.is_valid(), "Encryption key is not available"
|
||||
|
||||
# payload protection
|
||||
protected_payload = self.aead.encrypt(
|
||||
plain_payload, plain_header, packet_number
|
||||
)
|
||||
|
||||
# header protection
|
||||
return self.hp.apply(plain_header, protected_payload)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self.aead is not None
|
||||
|
||||
def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
|
||||
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
|
||||
|
||||
key, iv, hp = derive_key_iv_hp(
|
||||
cipher_suite=cipher_suite,
|
||||
secret=secret,
|
||||
version=version,
|
||||
)
|
||||
self.aead = AEAD(aead_cipher_name, key, iv)
|
||||
self.cipher_suite = cipher_suite
|
||||
self.hp = HeaderProtection(hp_cipher_name, hp)
|
||||
self.secret = secret
|
||||
self.version = version
|
||||
|
||||
# trigger callback
|
||||
self._setup_cb("tls")
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.aead = None
|
||||
self.cipher_suite = None
|
||||
self.hp = None
|
||||
self.secret = None
|
||||
|
||||
# trigger callback
|
||||
self._teardown_cb("tls")
|
||||
|
||||
|
||||
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
|
||||
self.aead = crypto.aead
|
||||
self.key_phase = crypto.key_phase
|
||||
self.secret = crypto.secret
|
||||
|
||||
# trigger callback
|
||||
self._setup_cb(trigger)
|
||||
|
||||
|
||||
def next_key_phase(self: CryptoContext) -> CryptoContext:
|
||||
algorithm = cipher_suite_hash(self.cipher_suite)
|
||||
|
||||
crypto = CryptoContext(key_phase=int(not self.key_phase))
|
||||
crypto.setup(
|
||||
cipher_suite=self.cipher_suite,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
|
||||
),
|
||||
version=self.version,
|
||||
)
|
||||
return crypto
|
||||
|
||||
|
||||
class CryptoPair:
|
||||
def __init__(
|
||||
self,
|
||||
recv_setup_cb: Callback = NoCallback,
|
||||
recv_teardown_cb: Callback = NoCallback,
|
||||
send_setup_cb: Callback = NoCallback,
|
||||
send_teardown_cb: Callback = NoCallback,
|
||||
) -> None:
|
||||
self.aead_tag_size = 16
|
||||
self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
|
||||
self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
|
||||
self._update_key_requested = False
|
||||
|
||||
def decrypt_packet(
|
||||
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
|
||||
) -> Tuple[bytes, bytes, int]:
|
||||
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
|
||||
packet, encrypted_offset, expected_packet_number
|
||||
)
|
||||
if update_key:
|
||||
self._update_key("remote_update")
|
||||
return plain_header, payload, packet_number
|
||||
|
||||
def encrypt_packet(
|
||||
self, plain_header: bytes, plain_payload: bytes, packet_number: int
|
||||
) -> bytes:
|
||||
if self._update_key_requested:
|
||||
self._update_key("local_update")
|
||||
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
|
||||
|
||||
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
|
||||
if is_client:
|
||||
recv_label, send_label = b"server in", b"client in"
|
||||
else:
|
||||
recv_label, send_label = b"client in", b"server in"
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
initial_salt = INITIAL_SALT_VERSION_2
|
||||
else:
|
||||
initial_salt = INITIAL_SALT_VERSION_1
|
||||
|
||||
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
|
||||
initial_secret = hkdf_extract(algorithm, initial_salt, cid)
|
||||
self.recv.setup(
|
||||
cipher_suite=INITIAL_CIPHER_SUITE,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, initial_secret, recv_label, b"", algorithm.digest_size
|
||||
),
|
||||
version=version,
|
||||
)
|
||||
self.send.setup(
|
||||
cipher_suite=INITIAL_CIPHER_SUITE,
|
||||
secret=hkdf_expand_label(
|
||||
algorithm, initial_secret, send_label, b"", algorithm.digest_size
|
||||
),
|
||||
version=version,
|
||||
)
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.recv.teardown()
|
||||
self.send.teardown()
|
||||
|
||||
def update_key(self) -> None:
|
||||
self._update_key_requested = True
|
||||
|
||||
@property
|
||||
def key_phase(self) -> int:
|
||||
if self._update_key_requested:
|
||||
return int(not self.recv.key_phase)
|
||||
else:
|
||||
return self.recv.key_phase
|
||||
|
||||
def _update_key(self, trigger: str) -> None:
|
||||
apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
|
||||
apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
|
||||
self._update_key_requested = False
|
||||
126
Code/venv/lib/python3.13/site-packages/aioquic/quic/events.py
Normal file
126
Code/venv/lib/python3.13/site-packages/aioquic/quic/events.py
Normal file
@ -0,0 +1,126 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuicEvent:
|
||||
"""
|
||||
Base class for QUIC events.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionIdIssued(QuicEvent):
|
||||
connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionIdRetired(QuicEvent):
|
||||
connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionTerminated(QuicEvent):
|
||||
"""
|
||||
The ConnectionTerminated event is fired when the QUIC connection is terminated.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code which was specified when closing the connection."
|
||||
|
||||
frame_type: Optional[int]
|
||||
"The frame type which caused the connection to be closed, or `None`."
|
||||
|
||||
reason_phrase: str
|
||||
"The human-readable reason for which the connection was closed."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatagramFrameReceived(QuicEvent):
|
||||
"""
|
||||
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandshakeCompleted(QuicEvent):
|
||||
"""
|
||||
The HandshakeCompleted event is fired when the TLS handshake completes.
|
||||
"""
|
||||
|
||||
alpn_protocol: Optional[str]
|
||||
"The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
early_data_accepted: bool
|
||||
"Whether early (0-RTT) data was accepted by the remote peer."
|
||||
|
||||
session_resumed: bool
|
||||
"Whether a TLS session was resumed."
|
||||
|
||||
|
||||
@dataclass
|
||||
class PingAcknowledged(QuicEvent):
|
||||
"""
|
||||
The PingAcknowledged event is fired when a PING frame is acknowledged.
|
||||
"""
|
||||
|
||||
uid: int
|
||||
"The unique ID of the PING."
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtocolNegotiated(QuicEvent):
|
||||
"""
|
||||
The ProtocolNegotiated event is fired when ALPN negotiation completes.
|
||||
"""
|
||||
|
||||
alpn_protocol: Optional[str]
|
||||
"The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopSendingReceived(QuicEvent):
|
||||
"""
|
||||
The StopSendingReceived event is fired when the remote peer requests
|
||||
stopping data transmission on a stream.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code that was sent from the peer."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that the peer requested stopping data transmission."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamDataReceived(QuicEvent):
|
||||
"""
|
||||
The StreamDataReceived event is fired whenever data is received on a
|
||||
stream.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
end_stream: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamReset(QuicEvent):
|
||||
"""
|
||||
The StreamReset event is fired when the remote peer resets a stream.
|
||||
"""
|
||||
|
||||
error_code: int
|
||||
"The error code that triggered the reset."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that was reset."
|
||||
329
Code/venv/lib/python3.13/site-packages/aioquic/quic/logger.py
Normal file
329
Code/venv/lib/python3.13/site-packages/aioquic/quic/logger.py
Normal file
@ -0,0 +1,329 @@
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
from ..h3.events import Headers
|
||||
from .packet import (
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
QuicStreamFrame,
|
||||
QuicTransportParameters,
|
||||
)
|
||||
from .rangeset import RangeSet
|
||||
|
||||
PACKET_TYPE_NAMES = {
|
||||
QuicPacketType.INITIAL: "initial",
|
||||
QuicPacketType.HANDSHAKE: "handshake",
|
||||
QuicPacketType.ZERO_RTT: "0RTT",
|
||||
QuicPacketType.ONE_RTT: "1RTT",
|
||||
QuicPacketType.RETRY: "retry",
|
||||
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
|
||||
}
|
||||
QLOG_VERSION = "0.3"
|
||||
|
||||
|
||||
def hexdump(data: bytes) -> str:
|
||||
return binascii.hexlify(data).decode("ascii")
|
||||
|
||||
|
||||
class QuicLoggerTrace:
|
||||
"""
|
||||
A QUIC event trace.
|
||||
|
||||
Events are logged in the format defined by qlog.
|
||||
|
||||
See:
|
||||
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
|
||||
"""
|
||||
|
||||
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
|
||||
self._odcid = odcid
|
||||
self._events: Deque[Dict[str, Any]] = deque()
|
||||
self._vantage_point = {
|
||||
"name": "aioquic",
|
||||
"type": "client" if is_client else "server",
|
||||
}
|
||||
|
||||
# QUIC
|
||||
|
||||
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
|
||||
return {
|
||||
"ack_delay": self.encode_time(delay),
|
||||
"acked_ranges": [[x.start, x.stop - 1] for x in ranges],
|
||||
"frame_type": "ack",
|
||||
}
|
||||
|
||||
def encode_connection_close_frame(
|
||||
self, error_code: int, frame_type: Optional[int], reason_phrase: str
|
||||
) -> Dict:
|
||||
attrs = {
|
||||
"error_code": error_code,
|
||||
"error_space": "application" if frame_type is None else "transport",
|
||||
"frame_type": "connection_close",
|
||||
"raw_error_code": error_code,
|
||||
"reason": reason_phrase,
|
||||
}
|
||||
if frame_type is not None:
|
||||
attrs["trigger_frame_type"] = frame_type
|
||||
|
||||
return attrs
|
||||
|
||||
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict:
|
||||
if frame_type == QuicFrameType.MAX_DATA:
|
||||
return {"frame_type": "max_data", "maximum": maximum}
|
||||
else:
|
||||
return {
|
||||
"frame_type": "max_streams",
|
||||
"maximum": maximum,
|
||||
"stream_type": "unidirectional"
|
||||
if frame_type == QuicFrameType.MAX_STREAMS_UNI
|
||||
else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
|
||||
return {
|
||||
"frame_type": "crypto",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
}
|
||||
|
||||
def encode_data_blocked_frame(self, limit: int) -> Dict:
|
||||
return {"frame_type": "data_blocked", "limit": limit}
|
||||
|
||||
def encode_datagram_frame(self, length: int) -> Dict:
|
||||
return {"frame_type": "datagram", "length": length}
|
||||
|
||||
def encode_handshake_done_frame(self) -> Dict:
|
||||
return {"frame_type": "handshake_done"}
|
||||
|
||||
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "max_stream_data",
|
||||
"maximum": maximum,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_new_connection_id_frame(
|
||||
self,
|
||||
connection_id: bytes,
|
||||
retire_prior_to: int,
|
||||
sequence_number: int,
|
||||
stateless_reset_token: bytes,
|
||||
) -> Dict:
|
||||
return {
|
||||
"connection_id": hexdump(connection_id),
|
||||
"frame_type": "new_connection_id",
|
||||
"length": len(connection_id),
|
||||
"reset_token": hexdump(stateless_reset_token),
|
||||
"retire_prior_to": retire_prior_to,
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_new_token_frame(self, token: bytes) -> Dict:
|
||||
return {
|
||||
"frame_type": "new_token",
|
||||
"length": len(token),
|
||||
"token": hexdump(token),
|
||||
}
|
||||
|
||||
def encode_padding_frame(self) -> Dict:
|
||||
return {"frame_type": "padding"}
|
||||
|
||||
def encode_path_challenge_frame(self, data: bytes) -> Dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_challenge"}
|
||||
|
||||
def encode_path_response_frame(self, data: bytes) -> Dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_response"}
|
||||
|
||||
def encode_ping_frame(self) -> Dict:
|
||||
return {"frame_type": "ping"}
|
||||
|
||||
def encode_reset_stream_frame(
|
||||
self, error_code: int, final_size: int, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"error_code": error_code,
|
||||
"final_size": final_size,
|
||||
"frame_type": "reset_stream",
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "retire_connection_id",
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "stream_data_blocked",
|
||||
"limit": limit,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "stop_sending",
|
||||
"error_code": error_code,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
|
||||
return {
|
||||
"fin": frame.fin,
|
||||
"frame_type": "stream",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
|
||||
return {
|
||||
"frame_type": "streams_blocked",
|
||||
"limit": limit,
|
||||
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_time(self, seconds: float) -> float:
|
||||
"""
|
||||
Convert a time to milliseconds.
|
||||
"""
|
||||
return seconds * 1000
|
||||
|
||||
def encode_transport_parameters(
|
||||
self, owner: str, parameters: QuicTransportParameters
|
||||
) -> Dict[str, Any]:
|
||||
data: Dict[str, Any] = {"owner": owner}
|
||||
for param_name, param_value in parameters.__dict__.items():
|
||||
if isinstance(param_value, bool):
|
||||
data[param_name] = param_value
|
||||
elif isinstance(param_value, bytes):
|
||||
data[param_name] = hexdump(param_value)
|
||||
elif isinstance(param_value, int):
|
||||
data[param_name] = param_value
|
||||
return data
|
||||
|
||||
def packet_type(self, packet_type: QuicPacketType) -> str:
|
||||
return PACKET_TYPE_NAMES[packet_type]
|
||||
|
||||
# HTTP/3
|
||||
|
||||
def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict:
|
||||
return {
|
||||
"frame": {"frame_type": "data"},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_headers_frame(
|
||||
self, length: int, headers: Headers, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "headers",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_push_promise_frame(
|
||||
self, length: int, headers: Headers, push_id: int, stream_id: int
|
||||
) -> Dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "push_promise",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
"push_id": push_id,
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def _encode_http3_headers(self, headers: Headers) -> List[Dict]:
|
||||
return [
|
||||
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
|
||||
]
|
||||
|
||||
# CORE
|
||||
|
||||
def log_event(self, *, category: str, event: str, data: Dict) -> None:
|
||||
self._events.append(
|
||||
{
|
||||
"data": data,
|
||||
"name": category + ":" + event,
|
||||
"time": self.encode_time(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the trace as a dictionary which can be written as JSON.
|
||||
"""
|
||||
return {
|
||||
"common_fields": {
|
||||
"ODCID": hexdump(self._odcid),
|
||||
},
|
||||
"events": list(self._events),
|
||||
"vantage_point": self._vantage_point,
|
||||
}
|
||||
|
||||
|
||||
class QuicLogger:
|
||||
"""
|
||||
A QUIC event logger which stores traces in memory.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._traces: List[QuicLoggerTrace] = []
|
||||
|
||||
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
|
||||
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
|
||||
self._traces.append(trace)
|
||||
return trace
|
||||
|
||||
def end_trace(self, trace: QuicLoggerTrace) -> None:
|
||||
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the traces as a dictionary which can be written as JSON.
|
||||
"""
|
||||
return {
|
||||
"qlog_format": "JSON",
|
||||
"qlog_version": QLOG_VERSION,
|
||||
"traces": [trace.to_dict() for trace in self._traces],
|
||||
}
|
||||
|
||||
|
||||
class QuicFileLogger(QuicLogger):
|
||||
"""
|
||||
A QUIC event logger which writes one trace per file.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
if not os.path.isdir(path):
|
||||
raise ValueError("QUIC log output directory '%s' does not exist" % path)
|
||||
self.path = path
|
||||
super().__init__()
|
||||
|
||||
def end_trace(self, trace: QuicLoggerTrace) -> None:
|
||||
trace_dict = trace.to_dict()
|
||||
trace_path = os.path.join(
|
||||
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
|
||||
)
|
||||
with open(trace_path, "w") as logger_fp:
|
||||
json.dump(
|
||||
{
|
||||
"qlog_format": "JSON",
|
||||
"qlog_version": QLOG_VERSION,
|
||||
"traces": [trace_dict],
|
||||
},
|
||||
logger_fp,
|
||||
)
|
||||
self._traces.remove(trace)
|
||||
640
Code/venv/lib/python3.13/site-packages/aioquic/quic/packet.py
Normal file
640
Code/venv/lib/python3.13/site-packages/aioquic/quic/packet.py
Normal file
@ -0,0 +1,640 @@
|
||||
import binascii
|
||||
import ipaddress
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ..buffer import Buffer
|
||||
from .rangeset import RangeSet
|
||||
|
||||
PACKET_LONG_HEADER = 0x80
|
||||
PACKET_FIXED_BIT = 0x40
|
||||
PACKET_SPIN_BIT = 0x20
|
||||
|
||||
CONNECTION_ID_MAX_SIZE = 20
|
||||
PACKET_NUMBER_MAX_SIZE = 4
|
||||
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
|
||||
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
|
||||
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
|
||||
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
|
||||
RETRY_INTEGRITY_TAG_SIZE = 16
|
||||
STATELESS_RESET_TOKEN_SIZE = 16
|
||||
|
||||
|
||||
class QuicErrorCode(IntEnum):
|
||||
NO_ERROR = 0x0
|
||||
INTERNAL_ERROR = 0x1
|
||||
CONNECTION_REFUSED = 0x2
|
||||
FLOW_CONTROL_ERROR = 0x3
|
||||
STREAM_LIMIT_ERROR = 0x4
|
||||
STREAM_STATE_ERROR = 0x5
|
||||
FINAL_SIZE_ERROR = 0x6
|
||||
FRAME_ENCODING_ERROR = 0x7
|
||||
TRANSPORT_PARAMETER_ERROR = 0x8
|
||||
CONNECTION_ID_LIMIT_ERROR = 0x9
|
||||
PROTOCOL_VIOLATION = 0xA
|
||||
INVALID_TOKEN = 0xB
|
||||
APPLICATION_ERROR = 0xC
|
||||
CRYPTO_BUFFER_EXCEEDED = 0xD
|
||||
KEY_UPDATE_ERROR = 0xE
|
||||
AEAD_LIMIT_REACHED = 0xF
|
||||
VERSION_NEGOTIATION_ERROR = 0x11
|
||||
CRYPTO_ERROR = 0x100
|
||||
|
||||
|
||||
class QuicPacketType(Enum):
|
||||
INITIAL = 0
|
||||
ZERO_RTT = 1
|
||||
HANDSHAKE = 2
|
||||
RETRY = 3
|
||||
VERSION_NEGOTIATION = 4
|
||||
ONE_RTT = 5
|
||||
|
||||
|
||||
# For backwards compatibility only, use `QuicPacketType` in new code.
|
||||
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
|
||||
|
||||
# QUIC version 1
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
|
||||
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
|
||||
QuicPacketType.INITIAL: 0,
|
||||
QuicPacketType.ZERO_RTT: 1,
|
||||
QuicPacketType.HANDSHAKE: 2,
|
||||
QuicPacketType.RETRY: 3,
|
||||
}
|
||||
PACKET_LONG_TYPE_DECODE_VERSION_1 = dict(
|
||||
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
|
||||
)
|
||||
|
||||
# QUIC version 2
|
||||
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
|
||||
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
|
||||
QuicPacketType.INITIAL: 1,
|
||||
QuicPacketType.ZERO_RTT: 2,
|
||||
QuicPacketType.HANDSHAKE: 3,
|
||||
QuicPacketType.RETRY: 0,
|
||||
}
|
||||
PACKET_LONG_TYPE_DECODE_VERSION_2 = dict(
|
||||
(v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
|
||||
)
|
||||
|
||||
|
||||
class QuicProtocolVersion(IntEnum):
|
||||
NEGOTIATION = 0
|
||||
VERSION_1 = 0x00000001
|
||||
VERSION_2 = 0x6B3343CF
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicHeader:
|
||||
version: Optional[int]
|
||||
"The protocol version. Only present in long header packets."
|
||||
|
||||
packet_type: QuicPacketType
|
||||
"The type of the packet."
|
||||
|
||||
packet_length: int
|
||||
"The total length of the packet, in bytes."
|
||||
|
||||
destination_cid: bytes
|
||||
"The destination connection ID."
|
||||
|
||||
source_cid: bytes
|
||||
"The destination connection ID."
|
||||
|
||||
token: bytes
|
||||
"The address verification token. Only present in `INITIAL` and `RETRY` packets."
|
||||
|
||||
integrity_tag: bytes
|
||||
"The retry integrity tag. Only present in `RETRY` packets."
|
||||
|
||||
supported_versions: List[int]
|
||||
"Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
|
||||
|
||||
|
||||
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
|
||||
"""
|
||||
Recover a packet number from a truncated packet number.
|
||||
|
||||
See: Appendix A - Sample Packet Number Decoding Algorithm
|
||||
"""
|
||||
window = 1 << num_bits
|
||||
half_window = window // 2
|
||||
candidate = (expected & ~(window - 1)) | truncated
|
||||
if candidate <= expected - half_window and candidate < (1 << 62) - window:
|
||||
return candidate + window
|
||||
elif candidate > expected + half_window and candidate >= window:
|
||||
return candidate - window
|
||||
else:
|
||||
return candidate
|
||||
|
||||
|
||||
def get_retry_integrity_tag(
|
||||
packet_without_tag: bytes, original_destination_cid: bytes, version: int
|
||||
) -> bytes:
|
||||
"""
|
||||
Calculate the integrity tag for a RETRY packet.
|
||||
"""
|
||||
# build Retry pseudo packet
|
||||
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
|
||||
buf.push_uint8(len(original_destination_cid))
|
||||
buf.push_bytes(original_destination_cid)
|
||||
buf.push_bytes(packet_without_tag)
|
||||
assert buf.eof()
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
aead_key = RETRY_AEAD_KEY_VERSION_2
|
||||
aead_nonce = RETRY_AEAD_NONCE_VERSION_2
|
||||
else:
|
||||
aead_key = RETRY_AEAD_KEY_VERSION_1
|
||||
aead_nonce = RETRY_AEAD_NONCE_VERSION_1
|
||||
|
||||
# run AES-128-GCM
|
||||
aead = AESGCM(aead_key)
|
||||
integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
|
||||
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
|
||||
return integrity_tag
|
||||
|
||||
|
||||
def get_spin_bit(first_byte: int) -> bool:
|
||||
return bool(first_byte & PACKET_SPIN_BIT)
|
||||
|
||||
|
||||
def is_long_header(first_byte: int) -> bool:
|
||||
return bool(first_byte & PACKET_LONG_HEADER)
|
||||
|
||||
|
||||
def pretty_protocol_version(version: int) -> str:
|
||||
"""
|
||||
Return a user-friendly representation of a protocol version.
|
||||
"""
|
||||
try:
|
||||
version_name = QuicProtocolVersion(version).name
|
||||
except ValueError:
|
||||
version_name = "UNKNOWN"
|
||||
return f"0x{version:08x} ({version_name})"
|
||||
|
||||
|
||||
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
|
||||
packet_start = buf.tell()
|
||||
|
||||
version = None
|
||||
integrity_tag = b""
|
||||
supported_versions = []
|
||||
token = b""
|
||||
|
||||
first_byte = buf.pull_uint8()
|
||||
if is_long_header(first_byte):
|
||||
# Long Header Packets.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
|
||||
version = buf.pull_uint32()
|
||||
|
||||
destination_cid_length = buf.pull_uint8()
|
||||
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
|
||||
raise ValueError(
|
||||
"Destination CID is too long (%d bytes)" % destination_cid_length
|
||||
)
|
||||
destination_cid = buf.pull_bytes(destination_cid_length)
|
||||
|
||||
source_cid_length = buf.pull_uint8()
|
||||
if source_cid_length > CONNECTION_ID_MAX_SIZE:
|
||||
raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
|
||||
source_cid = buf.pull_bytes(source_cid_length)
|
||||
|
||||
if version == QuicProtocolVersion.NEGOTIATION:
|
||||
# Version Negotiation Packet.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
|
||||
packet_type = QuicPacketType.VERSION_NEGOTIATION
|
||||
while not buf.eof():
|
||||
supported_versions.append(buf.pull_uint32())
|
||||
packet_end = buf.tell()
|
||||
else:
|
||||
if not (first_byte & PACKET_FIXED_BIT):
|
||||
raise ValueError("Packet fixed bit is zero")
|
||||
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
|
||||
(first_byte & 0x30) >> 4
|
||||
]
|
||||
else:
|
||||
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
|
||||
(first_byte & 0x30) >> 4
|
||||
]
|
||||
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
token_length = buf.pull_uint_var()
|
||||
token = buf.pull_bytes(token_length)
|
||||
rest_length = buf.pull_uint_var()
|
||||
elif packet_type == QuicPacketType.ZERO_RTT:
|
||||
rest_length = buf.pull_uint_var()
|
||||
elif packet_type == QuicPacketType.HANDSHAKE:
|
||||
rest_length = buf.pull_uint_var()
|
||||
else:
|
||||
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
|
||||
token = buf.pull_bytes(token_length)
|
||||
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
|
||||
rest_length = 0
|
||||
|
||||
# Check remainder length.
|
||||
packet_end = buf.tell() + rest_length
|
||||
if packet_end > buf.capacity:
|
||||
raise ValueError("Packet payload is truncated")
|
||||
|
||||
else:
|
||||
# Short Header Packets.
|
||||
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
|
||||
if not (first_byte & PACKET_FIXED_BIT):
|
||||
raise ValueError("Packet fixed bit is zero")
|
||||
|
||||
version = None
|
||||
packet_type = QuicPacketType.ONE_RTT
|
||||
destination_cid = buf.pull_bytes(host_cid_length)
|
||||
source_cid = b""
|
||||
packet_end = buf.capacity
|
||||
|
||||
return QuicHeader(
|
||||
version=version,
|
||||
packet_type=packet_type,
|
||||
packet_length=packet_end - packet_start,
|
||||
destination_cid=destination_cid,
|
||||
source_cid=source_cid,
|
||||
token=token,
|
||||
integrity_tag=integrity_tag,
|
||||
supported_versions=supported_versions,
|
||||
)
|
||||
|
||||
|
||||
def encode_long_header_first_byte(
|
||||
version: int, packet_type: QuicPacketType, bits: int
|
||||
) -> int:
|
||||
"""
|
||||
Encode the first byte of a long header packet.
|
||||
"""
|
||||
if version == QuicProtocolVersion.VERSION_2:
|
||||
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
|
||||
else:
|
||||
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
|
||||
return (
|
||||
PACKET_LONG_HEADER
|
||||
| PACKET_FIXED_BIT
|
||||
| long_type_encode[packet_type] << 4
|
||||
| bits
|
||||
)
|
||||
|
||||
|
||||
def encode_quic_retry(
|
||||
version: int,
|
||||
source_cid: bytes,
|
||||
destination_cid: bytes,
|
||||
original_destination_cid: bytes,
|
||||
retry_token: bytes,
|
||||
unused: int = 0,
|
||||
) -> bytes:
|
||||
buf = Buffer(
|
||||
capacity=7
|
||||
+ len(destination_cid)
|
||||
+ len(source_cid)
|
||||
+ len(retry_token)
|
||||
+ RETRY_INTEGRITY_TAG_SIZE
|
||||
)
|
||||
buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
|
||||
buf.push_uint32(version)
|
||||
buf.push_uint8(len(destination_cid))
|
||||
buf.push_bytes(destination_cid)
|
||||
buf.push_uint8(len(source_cid))
|
||||
buf.push_bytes(source_cid)
|
||||
buf.push_bytes(retry_token)
|
||||
buf.push_bytes(
|
||||
get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
|
||||
)
|
||||
assert buf.eof()
|
||||
return buf.data
|
||||
|
||||
|
||||
def encode_quic_version_negotiation(
|
||||
source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
|
||||
) -> bytes:
|
||||
buf = Buffer(
|
||||
capacity=7
|
||||
+ len(destination_cid)
|
||||
+ len(source_cid)
|
||||
+ 4 * len(supported_versions)
|
||||
)
|
||||
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
|
||||
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
|
||||
buf.push_uint8(len(destination_cid))
|
||||
buf.push_bytes(destination_cid)
|
||||
buf.push_uint8(len(source_cid))
|
||||
buf.push_bytes(source_cid)
|
||||
for version in supported_versions:
|
||||
buf.push_uint32(version)
|
||||
return buf.data
|
||||
|
||||
|
||||
# TLS EXTENSION
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicPreferredAddress:
|
||||
ipv4_address: Optional[Tuple[str, int]]
|
||||
ipv6_address: Optional[Tuple[str, int]]
|
||||
connection_id: bytes
|
||||
stateless_reset_token: bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicVersionInformation:
|
||||
chosen_version: int
|
||||
available_versions: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicTransportParameters:
|
||||
original_destination_connection_id: Optional[bytes] = None
|
||||
max_idle_timeout: Optional[int] = None
|
||||
stateless_reset_token: Optional[bytes] = None
|
||||
max_udp_payload_size: Optional[int] = None
|
||||
initial_max_data: Optional[int] = None
|
||||
initial_max_stream_data_bidi_local: Optional[int] = None
|
||||
initial_max_stream_data_bidi_remote: Optional[int] = None
|
||||
initial_max_stream_data_uni: Optional[int] = None
|
||||
initial_max_streams_bidi: Optional[int] = None
|
||||
initial_max_streams_uni: Optional[int] = None
|
||||
ack_delay_exponent: Optional[int] = None
|
||||
max_ack_delay: Optional[int] = None
|
||||
disable_active_migration: Optional[bool] = False
|
||||
preferred_address: Optional[QuicPreferredAddress] = None
|
||||
active_connection_id_limit: Optional[int] = None
|
||||
initial_source_connection_id: Optional[bytes] = None
|
||||
retry_source_connection_id: Optional[bytes] = None
|
||||
version_information: Optional[QuicVersionInformation] = None
|
||||
max_datagram_frame_size: Optional[int] = None
|
||||
quantum_readiness: Optional[bytes] = None
|
||||
|
||||
|
||||
PARAMS = {
|
||||
0x00: ("original_destination_connection_id", bytes),
|
||||
0x01: ("max_idle_timeout", int),
|
||||
0x02: ("stateless_reset_token", bytes),
|
||||
0x03: ("max_udp_payload_size", int),
|
||||
0x04: ("initial_max_data", int),
|
||||
0x05: ("initial_max_stream_data_bidi_local", int),
|
||||
0x06: ("initial_max_stream_data_bidi_remote", int),
|
||||
0x07: ("initial_max_stream_data_uni", int),
|
||||
0x08: ("initial_max_streams_bidi", int),
|
||||
0x09: ("initial_max_streams_uni", int),
|
||||
0x0A: ("ack_delay_exponent", int),
|
||||
0x0B: ("max_ack_delay", int),
|
||||
0x0C: ("disable_active_migration", bool),
|
||||
0x0D: ("preferred_address", QuicPreferredAddress),
|
||||
0x0E: ("active_connection_id_limit", int),
|
||||
0x0F: ("initial_source_connection_id", bytes),
|
||||
0x10: ("retry_source_connection_id", bytes),
|
||||
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
|
||||
0x11: ("version_information", QuicVersionInformation),
|
||||
# extensions
|
||||
0x0020: ("max_datagram_frame_size", int),
|
||||
0x0C37: ("quantum_readiness", bytes),
|
||||
}
|
||||
|
||||
|
||||
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
|
||||
ipv4_address = None
|
||||
ipv4_host = buf.pull_bytes(4)
|
||||
ipv4_port = buf.pull_uint16()
|
||||
if ipv4_host != bytes(4):
|
||||
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
|
||||
|
||||
ipv6_address = None
|
||||
ipv6_host = buf.pull_bytes(16)
|
||||
ipv6_port = buf.pull_uint16()
|
||||
if ipv6_host != bytes(16):
|
||||
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
|
||||
|
||||
connection_id_length = buf.pull_uint8()
|
||||
connection_id = buf.pull_bytes(connection_id_length)
|
||||
stateless_reset_token = buf.pull_bytes(16)
|
||||
|
||||
return QuicPreferredAddress(
|
||||
ipv4_address=ipv4_address,
|
||||
ipv6_address=ipv6_address,
|
||||
connection_id=connection_id,
|
||||
stateless_reset_token=stateless_reset_token,
|
||||
)
|
||||
|
||||
|
||||
def push_quic_preferred_address(
|
||||
buf: Buffer, preferred_address: QuicPreferredAddress
|
||||
) -> None:
|
||||
if preferred_address.ipv4_address is not None:
|
||||
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
|
||||
buf.push_uint16(preferred_address.ipv4_address[1])
|
||||
else:
|
||||
buf.push_bytes(bytes(6))
|
||||
|
||||
if preferred_address.ipv6_address is not None:
|
||||
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
|
||||
buf.push_uint16(preferred_address.ipv6_address[1])
|
||||
else:
|
||||
buf.push_bytes(bytes(18))
|
||||
|
||||
buf.push_uint8(len(preferred_address.connection_id))
|
||||
buf.push_bytes(preferred_address.connection_id)
|
||||
buf.push_bytes(preferred_address.stateless_reset_token)
|
||||
|
||||
|
||||
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
|
||||
chosen_version = buf.pull_uint32()
|
||||
available_versions = []
|
||||
for i in range(length // 4 - 1):
|
||||
available_versions.append(buf.pull_uint32())
|
||||
|
||||
# If an endpoint receives a Chosen Version equal to zero, or any Available Version
|
||||
# equal to zero, it MUST treat it as a parsing failure.
|
||||
#
|
||||
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
|
||||
if chosen_version == 0 or 0 in available_versions:
|
||||
raise ValueError("Version Information must not contain version 0")
|
||||
|
||||
return QuicVersionInformation(
|
||||
chosen_version=chosen_version,
|
||||
available_versions=available_versions,
|
||||
)
|
||||
|
||||
|
||||
def push_quic_version_information(
|
||||
buf: Buffer, version_information: QuicVersionInformation
|
||||
) -> None:
|
||||
buf.push_uint32(version_information.chosen_version)
|
||||
for version in version_information.available_versions:
|
||||
buf.push_uint32(version)
|
||||
|
||||
|
||||
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
|
||||
params = QuicTransportParameters()
|
||||
while not buf.eof():
|
||||
param_id = buf.pull_uint_var()
|
||||
param_len = buf.pull_uint_var()
|
||||
param_start = buf.tell()
|
||||
if param_id in PARAMS:
|
||||
# Parse known parameter.
|
||||
param_name, param_type = PARAMS[param_id]
|
||||
if param_type is int:
|
||||
setattr(params, param_name, buf.pull_uint_var())
|
||||
elif param_type is bytes:
|
||||
setattr(params, param_name, buf.pull_bytes(param_len))
|
||||
elif param_type is QuicPreferredAddress:
|
||||
setattr(params, param_name, pull_quic_preferred_address(buf))
|
||||
elif param_type is QuicVersionInformation:
|
||||
setattr(
|
||||
params,
|
||||
param_name,
|
||||
pull_quic_version_information(buf, param_len),
|
||||
)
|
||||
else:
|
||||
setattr(params, param_name, True)
|
||||
else:
|
||||
# Skip unknown parameter.
|
||||
buf.pull_bytes(param_len)
|
||||
|
||||
if buf.tell() != param_start + param_len:
|
||||
raise ValueError("Transport parameter length does not match")
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def push_quic_transport_parameters(
|
||||
buf: Buffer, params: QuicTransportParameters
|
||||
) -> None:
|
||||
for param_id, (param_name, param_type) in PARAMS.items():
|
||||
param_value = getattr(params, param_name)
|
||||
if param_value is not None and param_value is not False:
|
||||
param_buf = Buffer(capacity=65536)
|
||||
if param_type is int:
|
||||
param_buf.push_uint_var(param_value)
|
||||
elif param_type is bytes:
|
||||
param_buf.push_bytes(param_value)
|
||||
elif param_type is QuicPreferredAddress:
|
||||
push_quic_preferred_address(param_buf, param_value)
|
||||
elif param_type is QuicVersionInformation:
|
||||
push_quic_version_information(param_buf, param_value)
|
||||
buf.push_uint_var(param_id)
|
||||
buf.push_uint_var(param_buf.tell())
|
||||
buf.push_bytes(param_buf.data)
|
||||
|
||||
|
||||
# FRAMES
|
||||
|
||||
|
||||
class QuicFrameType(IntEnum):
|
||||
PADDING = 0x00
|
||||
PING = 0x01
|
||||
ACK = 0x02
|
||||
ACK_ECN = 0x03
|
||||
RESET_STREAM = 0x04
|
||||
STOP_SENDING = 0x05
|
||||
CRYPTO = 0x06
|
||||
NEW_TOKEN = 0x07
|
||||
STREAM_BASE = 0x08
|
||||
MAX_DATA = 0x10
|
||||
MAX_STREAM_DATA = 0x11
|
||||
MAX_STREAMS_BIDI = 0x12
|
||||
MAX_STREAMS_UNI = 0x13
|
||||
DATA_BLOCKED = 0x14
|
||||
STREAM_DATA_BLOCKED = 0x15
|
||||
STREAMS_BLOCKED_BIDI = 0x16
|
||||
STREAMS_BLOCKED_UNI = 0x17
|
||||
NEW_CONNECTION_ID = 0x18
|
||||
RETIRE_CONNECTION_ID = 0x19
|
||||
PATH_CHALLENGE = 0x1A
|
||||
PATH_RESPONSE = 0x1B
|
||||
TRANSPORT_CLOSE = 0x1C
|
||||
APPLICATION_CLOSE = 0x1D
|
||||
HANDSHAKE_DONE = 0x1E
|
||||
DATAGRAM = 0x30
|
||||
DATAGRAM_WITH_LENGTH = 0x31
|
||||
|
||||
|
||||
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.ACK,
|
||||
QuicFrameType.ACK_ECN,
|
||||
QuicFrameType.PADDING,
|
||||
QuicFrameType.TRANSPORT_CLOSE,
|
||||
QuicFrameType.APPLICATION_CLOSE,
|
||||
]
|
||||
)
|
||||
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.ACK,
|
||||
QuicFrameType.ACK_ECN,
|
||||
QuicFrameType.TRANSPORT_CLOSE,
|
||||
QuicFrameType.APPLICATION_CLOSE,
|
||||
]
|
||||
)
|
||||
|
||||
PROBING_FRAME_TYPES = frozenset(
|
||||
[
|
||||
QuicFrameType.PATH_CHALLENGE,
|
||||
QuicFrameType.PATH_RESPONSE,
|
||||
QuicFrameType.PADDING,
|
||||
QuicFrameType.NEW_CONNECTION_ID,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicResetStreamFrame:
|
||||
error_code: int
|
||||
final_size: int
|
||||
stream_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStopSendingFrame:
|
||||
error_code: int
|
||||
stream_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamFrame:
|
||||
data: bytes = b""
|
||||
fin: bool = False
|
||||
offset: int = 0
|
||||
|
||||
|
||||
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
|
||||
rangeset = RangeSet()
|
||||
end = buf.pull_uint_var() # largest acknowledged
|
||||
delay = buf.pull_uint_var()
|
||||
ack_range_count = buf.pull_uint_var()
|
||||
ack_count = buf.pull_uint_var() # first ack range
|
||||
rangeset.add(end - ack_count, end + 1)
|
||||
end -= ack_count
|
||||
for _ in range(ack_range_count):
|
||||
end -= buf.pull_uint_var() + 2
|
||||
ack_count = buf.pull_uint_var()
|
||||
rangeset.add(end - ack_count, end + 1)
|
||||
end -= ack_count
|
||||
return rangeset, delay
|
||||
|
||||
|
||||
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
|
||||
ranges = len(rangeset)
|
||||
index = ranges - 1
|
||||
r = rangeset[index]
|
||||
buf.push_uint_var(r.stop - 1)
|
||||
buf.push_uint_var(delay)
|
||||
buf.push_uint_var(index)
|
||||
buf.push_uint_var(r.stop - 1 - r.start)
|
||||
start = r.start
|
||||
while index > 0:
|
||||
index -= 1
|
||||
r = rangeset[index]
|
||||
buf.push_uint_var(start - r.stop - 1)
|
||||
buf.push_uint_var(r.stop - r.start - 1)
|
||||
start = r.start
|
||||
return ranges
|
||||
@ -0,0 +1,384 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..buffer import Buffer, size_uint_var
|
||||
from ..tls import Epoch
|
||||
from .crypto import CryptoPair
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet import (
|
||||
NON_ACK_ELICITING_FRAME_TYPES,
|
||||
NON_IN_FLIGHT_FRAME_TYPES,
|
||||
PACKET_FIXED_BIT,
|
||||
PACKET_NUMBER_MAX_SIZE,
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
encode_long_header_first_byte,
|
||||
)
|
||||
|
||||
PACKET_LENGTH_SEND_SIZE = 2
|
||||
PACKET_NUMBER_SEND_SIZE = 2
|
||||
|
||||
|
||||
QuicDeliveryHandler = Callable[..., None]
|
||||
|
||||
|
||||
class QuicDeliveryState(Enum):
|
||||
ACKED = 0
|
||||
LOST = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicSentPacket:
|
||||
epoch: Epoch
|
||||
in_flight: bool
|
||||
is_ack_eliciting: bool
|
||||
is_crypto_packet: bool
|
||||
packet_number: int
|
||||
packet_type: QuicPacketType
|
||||
sent_time: Optional[float] = None
|
||||
sent_bytes: int = 0
|
||||
|
||||
delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
quic_logger_frames: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
class QuicPacketBuilderStop(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QuicPacketBuilder:
|
||||
"""
|
||||
Helper for building QUIC packets.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_cid: bytes,
|
||||
peer_cid: bytes,
|
||||
version: int,
|
||||
is_client: bool,
|
||||
max_datagram_size: int,
|
||||
packet_number: int = 0,
|
||||
peer_token: bytes = b"",
|
||||
quic_logger: Optional[QuicLoggerTrace] = None,
|
||||
spin_bit: bool = False,
|
||||
):
|
||||
self.max_flight_bytes: Optional[int] = None
|
||||
self.max_total_bytes: Optional[int] = None
|
||||
self.quic_logger_frames: Optional[List[Dict]] = None
|
||||
|
||||
self._host_cid = host_cid
|
||||
self._is_client = is_client
|
||||
self._peer_cid = peer_cid
|
||||
self._peer_token = peer_token
|
||||
self._quic_logger = quic_logger
|
||||
self._spin_bit = spin_bit
|
||||
self._version = version
|
||||
|
||||
# assembled datagrams and packets
|
||||
self._datagrams: List[bytes] = []
|
||||
self._datagram_flight_bytes = 0
|
||||
self._datagram_init = True
|
||||
self._datagram_needs_padding = False
|
||||
self._packets: List[QuicSentPacket] = []
|
||||
self._flight_bytes = 0
|
||||
self._total_bytes = 0
|
||||
|
||||
# current packet
|
||||
self._header_size = 0
|
||||
self._packet: Optional[QuicSentPacket] = None
|
||||
self._packet_crypto: Optional[CryptoPair] = None
|
||||
self._packet_number = packet_number
|
||||
self._packet_start = 0
|
||||
self._packet_type: Optional[QuicPacketType] = None
|
||||
|
||||
self._buffer = Buffer(max_datagram_size)
|
||||
self._buffer_capacity = max_datagram_size
|
||||
self._flight_capacity = max_datagram_size
|
||||
|
||||
@property
|
||||
def packet_is_empty(self) -> bool:
|
||||
"""
|
||||
Returns `True` if the current packet is empty.
|
||||
"""
|
||||
assert self._packet is not None
|
||||
packet_size = self._buffer.tell() - self._packet_start
|
||||
return packet_size <= self._header_size
|
||||
|
||||
@property
|
||||
def packet_number(self) -> int:
|
||||
"""
|
||||
Returns the packet number for the next packet.
|
||||
"""
|
||||
return self._packet_number
|
||||
|
||||
@property
|
||||
def remaining_buffer_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._buffer_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_flight_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._flight_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
|
||||
"""
|
||||
Returns the assembled datagrams.
|
||||
"""
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
self._flush_current_datagram()
|
||||
|
||||
datagrams = self._datagrams
|
||||
packets = self._packets
|
||||
self._datagrams = []
|
||||
self._packets = []
|
||||
return datagrams, packets
|
||||
|
||||
def start_frame(
|
||||
self,
|
||||
frame_type: int,
|
||||
capacity: int = 1,
|
||||
handler: Optional[QuicDeliveryHandler] = None,
|
||||
handler_args: Sequence[Any] = [],
|
||||
) -> Buffer:
|
||||
"""
|
||||
Starts a new frame.
|
||||
"""
|
||||
if self.remaining_buffer_space < capacity or (
|
||||
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
|
||||
and self.remaining_flight_space < capacity
|
||||
):
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
self._buffer.push_uint_var(frame_type)
|
||||
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
|
||||
self._packet.is_ack_eliciting = True
|
||||
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
|
||||
self._packet.in_flight = True
|
||||
if frame_type == QuicFrameType.CRYPTO:
|
||||
self._packet.is_crypto_packet = True
|
||||
if handler is not None:
|
||||
self._packet.delivery_handlers.append((handler, handler_args))
|
||||
return self._buffer
|
||||
|
||||
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
|
||||
"""
|
||||
Starts a new packet.
|
||||
"""
|
||||
assert packet_type in (
|
||||
QuicPacketType.INITIAL,
|
||||
QuicPacketType.HANDSHAKE,
|
||||
QuicPacketType.ZERO_RTT,
|
||||
QuicPacketType.ONE_RTT,
|
||||
), "Invalid packet type"
|
||||
buf = self._buffer
|
||||
|
||||
# finish previous datagram
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
|
||||
# if there is too little space remaining, start a new datagram
|
||||
# FIXME: the limit is arbitrary!
|
||||
packet_start = buf.tell()
|
||||
if self._buffer_capacity - packet_start < 128:
|
||||
self._flush_current_datagram()
|
||||
packet_start = 0
|
||||
|
||||
# initialize datagram if needed
|
||||
if self._datagram_init:
|
||||
if self.max_total_bytes is not None:
|
||||
remaining_total_bytes = self.max_total_bytes - self._total_bytes
|
||||
if remaining_total_bytes < self._buffer_capacity:
|
||||
self._buffer_capacity = remaining_total_bytes
|
||||
|
||||
self._flight_capacity = self._buffer_capacity
|
||||
if self.max_flight_bytes is not None:
|
||||
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
|
||||
if remaining_flight_bytes < self._flight_capacity:
|
||||
self._flight_capacity = remaining_flight_bytes
|
||||
self._datagram_flight_bytes = 0
|
||||
self._datagram_init = False
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# calculate header size
|
||||
if packet_type != QuicPacketType.ONE_RTT:
|
||||
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
token_length = len(self._peer_token)
|
||||
header_size += size_uint_var(token_length) + token_length
|
||||
else:
|
||||
header_size = 3 + len(self._peer_cid)
|
||||
|
||||
# check we have enough space
|
||||
if packet_start + header_size >= self._buffer_capacity:
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
# determine ack epoch
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
epoch = Epoch.INITIAL
|
||||
elif packet_type == QuicPacketType.HANDSHAKE:
|
||||
epoch = Epoch.HANDSHAKE
|
||||
else:
|
||||
epoch = Epoch.ONE_RTT
|
||||
|
||||
self._header_size = header_size
|
||||
self._packet = QuicSentPacket(
|
||||
epoch=epoch,
|
||||
in_flight=False,
|
||||
is_ack_eliciting=False,
|
||||
is_crypto_packet=False,
|
||||
packet_number=self._packet_number,
|
||||
packet_type=packet_type,
|
||||
)
|
||||
self._packet_crypto = crypto
|
||||
self._packet_start = packet_start
|
||||
self._packet_type = packet_type
|
||||
self.quic_logger_frames = self._packet.quic_logger_frames
|
||||
|
||||
buf.seek(self._packet_start + self._header_size)
|
||||
|
||||
def _end_packet(self) -> None:
|
||||
"""
|
||||
Ends the current packet.
|
||||
"""
|
||||
buf = self._buffer
|
||||
packet_size = buf.tell() - self._packet_start
|
||||
if packet_size > self._header_size:
|
||||
# padding to ensure sufficient sample size
|
||||
padding_size = (
|
||||
PACKET_NUMBER_MAX_SIZE
|
||||
- PACKET_NUMBER_SEND_SIZE
|
||||
+ self._header_size
|
||||
- packet_size
|
||||
)
|
||||
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if (
|
||||
self._is_client or self._packet.is_ack_eliciting
|
||||
) and self._packet_type == QuicPacketType.INITIAL:
|
||||
self._datagram_needs_padding = True
|
||||
|
||||
# For datagrams containing 1-RTT data, we *must* apply the padding
|
||||
# inside the packet, we cannot tack bytes onto the end of the
|
||||
# datagram.
|
||||
if (
|
||||
self._datagram_needs_padding
|
||||
and self._packet_type == QuicPacketType.ONE_RTT
|
||||
):
|
||||
if self.remaining_flight_space > padding_size:
|
||||
padding_size = self.remaining_flight_space
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# write padding
|
||||
if padding_size > 0:
|
||||
buf.push_bytes(bytes(padding_size))
|
||||
packet_size += padding_size
|
||||
self._packet.in_flight = True
|
||||
|
||||
# log frame
|
||||
if self._quic_logger is not None:
|
||||
self._packet.quic_logger_frames.append(
|
||||
self._quic_logger.encode_padding_frame()
|
||||
)
|
||||
|
||||
# write header
|
||||
if self._packet_type != QuicPacketType.ONE_RTT:
|
||||
length = (
|
||||
packet_size
|
||||
- self._header_size
|
||||
+ PACKET_NUMBER_SEND_SIZE
|
||||
+ self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
encode_long_header_first_byte(
|
||||
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
|
||||
)
|
||||
)
|
||||
buf.push_uint32(self._version)
|
||||
buf.push_uint8(len(self._peer_cid))
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint8(len(self._host_cid))
|
||||
buf.push_bytes(self._host_cid)
|
||||
if self._packet_type == QuicPacketType.INITIAL:
|
||||
buf.push_uint_var(len(self._peer_token))
|
||||
buf.push_bytes(self._peer_token)
|
||||
buf.push_uint16(length | 0x4000)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
else:
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
PACKET_FIXED_BIT
|
||||
| (self._spin_bit << 5)
|
||||
| (self._packet_crypto.key_phase << 2)
|
||||
| (PACKET_NUMBER_SEND_SIZE - 1)
|
||||
)
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
|
||||
# encrypt in place
|
||||
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_bytes(
|
||||
self._packet_crypto.encrypt_packet(
|
||||
plain[0 : self._header_size],
|
||||
plain[self._header_size : packet_size],
|
||||
self._packet_number,
|
||||
)
|
||||
)
|
||||
self._packet.sent_bytes = buf.tell() - self._packet_start
|
||||
self._packets.append(self._packet)
|
||||
if self._packet.in_flight:
|
||||
self._datagram_flight_bytes += self._packet.sent_bytes
|
||||
|
||||
# Short header packets cannot be coalesced, we need a new datagram.
|
||||
if self._packet_type == QuicPacketType.ONE_RTT:
|
||||
self._flush_current_datagram()
|
||||
|
||||
self._packet_number += 1
|
||||
else:
|
||||
# "cancel" the packet
|
||||
buf.seek(self._packet_start)
|
||||
|
||||
self._packet = None
|
||||
self.quic_logger_frames = None
|
||||
|
||||
def _flush_current_datagram(self) -> None:
|
||||
datagram_bytes = self._buffer.tell()
|
||||
if datagram_bytes:
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if self._datagram_needs_padding:
|
||||
extra_bytes = self._flight_capacity - self._buffer.tell()
|
||||
if extra_bytes > 0:
|
||||
self._buffer.push_bytes(bytes(extra_bytes))
|
||||
self._datagram_flight_bytes += extra_bytes
|
||||
datagram_bytes += extra_bytes
|
||||
|
||||
self._datagrams.append(self._buffer.data)
|
||||
self._flight_bytes += self._datagram_flight_bytes
|
||||
self._total_bytes += datagram_bytes
|
||||
self._datagram_init = True
|
||||
self._buffer.seek(0)
|
||||
@ -0,0 +1,98 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
class RangeSet(Sequence):
|
||||
def __init__(self, ranges: Iterable[range] = []):
|
||||
self.__ranges: List[range] = []
|
||||
for r in ranges:
|
||||
assert r.step == 1
|
||||
self.add(r.start, r.stop)
|
||||
|
||||
def add(self, start: int, stop: Optional[int] = None) -> None:
|
||||
if stop is None:
|
||||
stop = start + 1
|
||||
assert stop > start
|
||||
|
||||
for i, r in enumerate(self.__ranges):
|
||||
# the added range is entirely before current item, insert here
|
||||
if stop < r.start:
|
||||
self.__ranges.insert(i, range(start, stop))
|
||||
return
|
||||
|
||||
# the added range is entirely after current item, keep looking
|
||||
if start > r.stop:
|
||||
continue
|
||||
|
||||
# the added range touches the current item, merge it
|
||||
start = min(start, r.start)
|
||||
stop = max(stop, r.stop)
|
||||
while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
|
||||
stop = max(self.__ranges[i + 1].stop, stop)
|
||||
self.__ranges.pop(i + 1)
|
||||
self.__ranges[i] = range(start, stop)
|
||||
return
|
||||
|
||||
# the added range is entirely after all existing items, append it
|
||||
self.__ranges.append(range(start, stop))
|
||||
|
||||
def bounds(self) -> range:
|
||||
return range(self.__ranges[0].start, self.__ranges[-1].stop)
|
||||
|
||||
def shift(self) -> range:
|
||||
return self.__ranges.pop(0)
|
||||
|
||||
def subtract(self, start: int, stop: int) -> None:
|
||||
assert stop > start
|
||||
|
||||
i = 0
|
||||
while i < len(self.__ranges):
|
||||
r = self.__ranges[i]
|
||||
|
||||
# the removed range is entirely before current item, stop here
|
||||
if stop <= r.start:
|
||||
return
|
||||
|
||||
# the removed range is entirely after current item, keep looking
|
||||
if start >= r.stop:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# the removed range completely covers the current item, remove it
|
||||
if start <= r.start and stop >= r.stop:
|
||||
self.__ranges.pop(i)
|
||||
continue
|
||||
|
||||
# the removed range touches the current item
|
||||
if start > r.start:
|
||||
self.__ranges[i] = range(r.start, start)
|
||||
if stop < r.stop:
|
||||
self.__ranges.insert(i + 1, range(stop, r.stop))
|
||||
else:
|
||||
self.__ranges[i] = range(stop, r.stop)
|
||||
|
||||
i += 1
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, val: Any) -> bool:
|
||||
for r in self.__ranges:
|
||||
if val in r:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, RangeSet):
|
||||
return NotImplemented
|
||||
|
||||
return self.__ranges == other.__ranges
|
||||
|
||||
def __getitem__(self, key: Any) -> range:
|
||||
return self.__ranges[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__ranges)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RangeSet({})".format(repr(self.__ranges))
|
||||
389
Code/venv/lib/python3.13/site-packages/aioquic/quic/recovery.py
Normal file
389
Code/venv/lib/python3.13/site-packages/aioquic/quic/recovery.py
Normal file
@ -0,0 +1,389 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from .congestion import cubic, reno # noqa
|
||||
from .congestion.base import K_GRANULARITY, create_congestion_control
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet_builder import QuicDeliveryState, QuicSentPacket
|
||||
from .rangeset import RangeSet
|
||||
|
||||
# loss detection
|
||||
K_PACKET_THRESHOLD = 3
|
||||
K_TIME_THRESHOLD = 9 / 8
|
||||
K_MICRO_SECOND = 0.000001
|
||||
K_SECOND = 1.0
|
||||
|
||||
|
||||
class QuicPacketSpace:
|
||||
def __init__(self) -> None:
|
||||
self.ack_at: Optional[float] = None
|
||||
self.ack_queue = RangeSet()
|
||||
self.discarded = False
|
||||
self.expected_packet_number = 0
|
||||
self.largest_received_packet = -1
|
||||
self.largest_received_time: Optional[float] = None
|
||||
|
||||
# sent packets and loss
|
||||
self.ack_eliciting_in_flight = 0
|
||||
self.largest_acked_packet = 0
|
||||
self.loss_time: Optional[float] = None
|
||||
self.sent_packets: Dict[int, QuicSentPacket] = {}
|
||||
|
||||
|
||||
class QuicPacketPacer:
|
||||
def __init__(self, *, max_datagram_size: int) -> None:
|
||||
self._max_datagram_size = max_datagram_size
|
||||
self.bucket_max: float = 0.0
|
||||
self.bucket_time: float = 0.0
|
||||
self.evaluation_time: float = 0.0
|
||||
self.packet_time: Optional[float] = None
|
||||
|
||||
def next_send_time(self, now: float) -> float:
|
||||
if self.packet_time is not None:
|
||||
self.update_bucket(now=now)
|
||||
if self.bucket_time <= 0:
|
||||
return now + self.packet_time
|
||||
return None
|
||||
|
||||
def update_after_send(self, now: float) -> None:
|
||||
if self.packet_time is not None:
|
||||
self.update_bucket(now=now)
|
||||
if self.bucket_time < self.packet_time:
|
||||
self.bucket_time = 0.0
|
||||
else:
|
||||
self.bucket_time -= self.packet_time
|
||||
|
||||
def update_bucket(self, now: float) -> None:
|
||||
if now > self.evaluation_time:
|
||||
self.bucket_time = min(
|
||||
self.bucket_time + (now - self.evaluation_time), self.bucket_max
|
||||
)
|
||||
self.evaluation_time = now
|
||||
|
||||
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
|
||||
pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
|
||||
self.packet_time = max(
|
||||
K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND)
|
||||
)
|
||||
|
||||
self.bucket_max = (
|
||||
max(
|
||||
2 * self._max_datagram_size,
|
||||
min(congestion_window // 4, 16 * self._max_datagram_size),
|
||||
)
|
||||
/ pacing_rate
|
||||
)
|
||||
if self.bucket_time > self.bucket_max:
|
||||
self.bucket_time = self.bucket_max
|
||||
|
||||
|
||||
class QuicPacketRecovery:
|
||||
"""
|
||||
Packet loss and congestion controller.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
congestion_control_algorithm: str,
|
||||
initial_rtt: float,
|
||||
max_datagram_size: int,
|
||||
peer_completed_address_validation: bool,
|
||||
send_probe: Callable[[], None],
|
||||
logger: Optional[logging.LoggerAdapter] = None,
|
||||
quic_logger: Optional[QuicLoggerTrace] = None,
|
||||
) -> None:
|
||||
self.max_ack_delay = 0.025
|
||||
self.peer_completed_address_validation = peer_completed_address_validation
|
||||
self.spaces: List[QuicPacketSpace] = []
|
||||
|
||||
# callbacks
|
||||
self._logger = logger
|
||||
self._quic_logger = quic_logger
|
||||
self._send_probe = send_probe
|
||||
|
||||
# loss detection
|
||||
self._pto_count = 0
|
||||
self._rtt_initial = initial_rtt
|
||||
self._rtt_initialized = False
|
||||
self._rtt_latest = 0.0
|
||||
self._rtt_min = math.inf
|
||||
self._rtt_smoothed = 0.0
|
||||
self._rtt_variance = 0.0
|
||||
self._time_of_last_sent_ack_eliciting_packet = 0.0
|
||||
|
||||
# congestion control
|
||||
self._cc = create_congestion_control(
|
||||
congestion_control_algorithm, max_datagram_size=max_datagram_size
|
||||
)
|
||||
self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size)
|
||||
|
||||
@property
|
||||
def bytes_in_flight(self) -> int:
|
||||
return self._cc.bytes_in_flight
|
||||
|
||||
@property
|
||||
def congestion_window(self) -> int:
|
||||
return self._cc.congestion_window
|
||||
|
||||
def discard_space(self, space: QuicPacketSpace) -> None:
|
||||
assert space in self.spaces
|
||||
|
||||
self._cc.on_packets_expired(
|
||||
packets=filter(lambda x: x.in_flight, space.sent_packets.values())
|
||||
)
|
||||
space.sent_packets.clear()
|
||||
|
||||
space.ack_at = None
|
||||
space.ack_eliciting_in_flight = 0
|
||||
space.loss_time = None
|
||||
|
||||
# reset PTO count
|
||||
self._pto_count = 0
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
|
||||
def get_loss_detection_time(self) -> float:
|
||||
# loss timer
|
||||
loss_space = self._get_loss_space()
|
||||
if loss_space is not None:
|
||||
return loss_space.loss_time
|
||||
|
||||
# packet timer
|
||||
if (
|
||||
not self.peer_completed_address_validation
|
||||
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
|
||||
):
|
||||
timeout = self.get_probe_timeout() * (2**self._pto_count)
|
||||
return self._time_of_last_sent_ack_eliciting_packet + timeout
|
||||
|
||||
return None
|
||||
|
||||
def get_probe_timeout(self) -> float:
|
||||
if not self._rtt_initialized:
|
||||
return 2 * self._rtt_initial
|
||||
return (
|
||||
self._rtt_smoothed
|
||||
+ max(4 * self._rtt_variance, K_GRANULARITY)
|
||||
+ self.max_ack_delay
|
||||
)
|
||||
|
||||
def on_ack_received(
|
||||
self,
|
||||
*,
|
||||
ack_rangeset: RangeSet,
|
||||
ack_delay: float,
|
||||
now: float,
|
||||
space: QuicPacketSpace,
|
||||
) -> None:
|
||||
"""
|
||||
Update metrics as the result of an ACK being received.
|
||||
"""
|
||||
is_ack_eliciting = False
|
||||
largest_acked = ack_rangeset.bounds().stop - 1
|
||||
largest_newly_acked = None
|
||||
largest_sent_time = None
|
||||
|
||||
if largest_acked > space.largest_acked_packet:
|
||||
space.largest_acked_packet = largest_acked
|
||||
|
||||
for packet_number in sorted(space.sent_packets.keys()):
|
||||
if packet_number > largest_acked:
|
||||
break
|
||||
if packet_number in ack_rangeset:
|
||||
# remove packet and update counters
|
||||
packet = space.sent_packets.pop(packet_number)
|
||||
if packet.is_ack_eliciting:
|
||||
is_ack_eliciting = True
|
||||
space.ack_eliciting_in_flight -= 1
|
||||
if packet.in_flight:
|
||||
self._cc.on_packet_acked(packet=packet, now=now)
|
||||
largest_newly_acked = packet_number
|
||||
largest_sent_time = packet.sent_time
|
||||
|
||||
# trigger callbacks
|
||||
for handler, args in packet.delivery_handlers:
|
||||
handler(QuicDeliveryState.ACKED, *args)
|
||||
|
||||
# nothing to do if there are no newly acked packets
|
||||
if largest_newly_acked is None:
|
||||
return
|
||||
|
||||
if largest_acked == largest_newly_acked and is_ack_eliciting:
|
||||
latest_rtt = now - largest_sent_time
|
||||
log_rtt = True
|
||||
|
||||
# limit ACK delay to max_ack_delay
|
||||
ack_delay = min(ack_delay, self.max_ack_delay)
|
||||
|
||||
# update RTT estimate, which cannot be < 1 ms
|
||||
self._rtt_latest = max(latest_rtt, 0.001)
|
||||
if self._rtt_latest < self._rtt_min:
|
||||
self._rtt_min = self._rtt_latest
|
||||
if self._rtt_latest > self._rtt_min + ack_delay:
|
||||
self._rtt_latest -= ack_delay
|
||||
|
||||
if not self._rtt_initialized:
|
||||
self._rtt_initialized = True
|
||||
self._rtt_variance = latest_rtt / 2
|
||||
self._rtt_smoothed = latest_rtt
|
||||
else:
|
||||
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
|
||||
self._rtt_min - self._rtt_latest
|
||||
)
|
||||
self._rtt_smoothed = (
|
||||
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
|
||||
)
|
||||
|
||||
# inform congestion controller
|
||||
self._cc.on_rtt_measurement(now=now, rtt=latest_rtt)
|
||||
self._pacer.update_rate(
|
||||
congestion_window=self._cc.congestion_window,
|
||||
smoothed_rtt=self._rtt_smoothed,
|
||||
)
|
||||
|
||||
else:
|
||||
log_rtt = False
|
||||
|
||||
self._detect_loss(now=now, space=space)
|
||||
|
||||
# reset PTO count
|
||||
self._pto_count = 0
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated(log_rtt=log_rtt)
|
||||
|
||||
def on_loss_detection_timeout(self, *, now: float) -> None:
|
||||
loss_space = self._get_loss_space()
|
||||
if loss_space is not None:
|
||||
self._detect_loss(now=now, space=loss_space)
|
||||
else:
|
||||
self._pto_count += 1
|
||||
self.reschedule_data(now=now)
|
||||
|
||||
def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
|
||||
space.sent_packets[packet.packet_number] = packet
|
||||
|
||||
if packet.is_ack_eliciting:
|
||||
space.ack_eliciting_in_flight += 1
|
||||
if packet.in_flight:
|
||||
if packet.is_ack_eliciting:
|
||||
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
|
||||
|
||||
# add packet to bytes in flight
|
||||
self._cc.on_packet_sent(packet=packet)
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
|
||||
def reschedule_data(self, *, now: float) -> None:
|
||||
"""
|
||||
Schedule some data for retransmission.
|
||||
"""
|
||||
# if there is any outstanding CRYPTO, retransmit it
|
||||
crypto_scheduled = False
|
||||
for space in self.spaces:
|
||||
packets = tuple(
|
||||
filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
|
||||
)
|
||||
if packets:
|
||||
self._on_packets_lost(now=now, packets=packets, space=space)
|
||||
crypto_scheduled = True
|
||||
if crypto_scheduled and self._logger is not None:
|
||||
self._logger.debug("Scheduled CRYPTO data for retransmission")
|
||||
|
||||
# ensure an ACK-elliciting packet is sent
|
||||
self._send_probe()
|
||||
|
||||
def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None:
|
||||
"""
|
||||
Check whether any packets should be declared lost.
|
||||
"""
|
||||
loss_delay = K_TIME_THRESHOLD * (
|
||||
max(self._rtt_latest, self._rtt_smoothed)
|
||||
if self._rtt_initialized
|
||||
else self._rtt_initial
|
||||
)
|
||||
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
|
||||
time_threshold = now - loss_delay
|
||||
|
||||
lost_packets = []
|
||||
space.loss_time = None
|
||||
for packet_number, packet in space.sent_packets.items():
|
||||
if packet_number > space.largest_acked_packet:
|
||||
break
|
||||
|
||||
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
|
||||
lost_packets.append(packet)
|
||||
else:
|
||||
packet_loss_time = packet.sent_time + loss_delay
|
||||
if space.loss_time is None or space.loss_time > packet_loss_time:
|
||||
space.loss_time = packet_loss_time
|
||||
|
||||
self._on_packets_lost(now=now, packets=lost_packets, space=space)
|
||||
|
||||
def _get_loss_space(self) -> Optional[QuicPacketSpace]:
|
||||
loss_space = None
|
||||
for space in self.spaces:
|
||||
if space.loss_time is not None and (
|
||||
loss_space is None or space.loss_time < loss_space.loss_time
|
||||
):
|
||||
loss_space = space
|
||||
return loss_space
|
||||
|
||||
def _log_metrics_updated(self, log_rtt=False) -> None:
|
||||
data: Dict[str, Any] = self._cc.get_log_data()
|
||||
|
||||
if log_rtt:
|
||||
data.update(
|
||||
{
|
||||
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
|
||||
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
|
||||
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
|
||||
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
|
||||
}
|
||||
)
|
||||
|
||||
self._quic_logger.log_event(
|
||||
category="recovery", event="metrics_updated", data=data
|
||||
)
|
||||
|
||||
def _on_packets_lost(
|
||||
self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace
|
||||
) -> None:
|
||||
lost_packets_cc = []
|
||||
for packet in packets:
|
||||
del space.sent_packets[packet.packet_number]
|
||||
|
||||
if packet.in_flight:
|
||||
lost_packets_cc.append(packet)
|
||||
|
||||
if packet.is_ack_eliciting:
|
||||
space.ack_eliciting_in_flight -= 1
|
||||
|
||||
if self._quic_logger is not None:
|
||||
self._quic_logger.log_event(
|
||||
category="recovery",
|
||||
event="packet_lost",
|
||||
data={
|
||||
"type": self._quic_logger.packet_type(packet.packet_type),
|
||||
"packet_number": packet.packet_number,
|
||||
},
|
||||
)
|
||||
self._log_metrics_updated()
|
||||
|
||||
# trigger callbacks
|
||||
for handler, args in packet.delivery_handlers:
|
||||
handler(QuicDeliveryState.LOST, *args)
|
||||
|
||||
# inform congestion controller
|
||||
if lost_packets_cc:
|
||||
self._cc.on_packets_lost(now=now, packets=lost_packets_cc)
|
||||
self._pacer.update_rate(
|
||||
congestion_window=self._cc.congestion_window,
|
||||
smoothed_rtt=self._rtt_smoothed,
|
||||
)
|
||||
if self._quic_logger is not None:
|
||||
self._log_metrics_updated()
|
||||
53
Code/venv/lib/python3.13/site-packages/aioquic/quic/retry.py
Normal file
53
Code/venv/lib/python3.13/site-packages/aioquic/quic/retry.py
Normal file
@ -0,0 +1,53 @@
|
||||
import ipaddress
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from ..buffer import Buffer
|
||||
from ..tls import pull_opaque, push_opaque
|
||||
from .connection import NetworkAddress
|
||||
|
||||
|
||||
def encode_address(addr: NetworkAddress) -> bytes:
|
||||
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
|
||||
|
||||
|
||||
class QuicRetryTokenHandler:
|
||||
def __init__(self) -> None:
|
||||
self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
def create_token(
|
||||
self,
|
||||
addr: NetworkAddress,
|
||||
original_destination_connection_id: bytes,
|
||||
retry_source_connection_id: bytes,
|
||||
) -> bytes:
|
||||
buf = Buffer(capacity=512)
|
||||
push_opaque(buf, 1, encode_address(addr))
|
||||
push_opaque(buf, 1, original_destination_connection_id)
|
||||
push_opaque(buf, 1, retry_source_connection_id)
|
||||
return self._key.public_key().encrypt(
|
||||
buf.data,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
|
||||
),
|
||||
)
|
||||
|
||||
def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]:
|
||||
buf = Buffer(
|
||||
data=self._key.decrypt(
|
||||
token,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
encoded_addr = pull_opaque(buf, 1)
|
||||
original_destination_connection_id = pull_opaque(buf, 1)
|
||||
retry_source_connection_id = pull_opaque(buf, 1)
|
||||
if encoded_addr != encode_address(addr):
|
||||
raise ValueError("Remote address does not match.")
|
||||
return original_destination_connection_id, retry_source_connection_id
|
||||
364
Code/venv/lib/python3.13/site-packages/aioquic/quic/stream.py
Normal file
364
Code/venv/lib/python3.13/site-packages/aioquic/quic/stream.py
Normal file
@ -0,0 +1,364 @@
|
||||
from typing import Optional
|
||||
|
||||
from . import events
|
||||
from .packet import (
|
||||
QuicErrorCode,
|
||||
QuicResetStreamFrame,
|
||||
QuicStopSendingFrame,
|
||||
QuicStreamFrame,
|
||||
)
|
||||
from .packet_builder import QuicDeliveryState
|
||||
from .rangeset import RangeSet
|
||||
|
||||
|
||||
class FinalSizeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StreamFinishedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QuicStreamReceiver:
|
||||
"""
|
||||
The receive part of a QUIC stream.
|
||||
|
||||
It finishes:
|
||||
- immediately for a send-only stream
|
||||
- upon reception of a STREAM_RESET frame
|
||||
- upon reception of a data frame with the FIN bit set
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: Optional[int], readable: bool) -> None:
|
||||
self.highest_offset = 0 # the highest offset ever seen
|
||||
self.is_finished = False
|
||||
self.stop_pending = False
|
||||
|
||||
self._buffer = bytearray()
|
||||
self._buffer_start = 0 # the offset for the start of the buffer
|
||||
self._final_size: Optional[int] = None
|
||||
self._ranges = RangeSet()
|
||||
self._stream_id = stream_id
|
||||
self._stop_error_code: Optional[int] = None
|
||||
|
||||
def get_stop_frame(self) -> QuicStopSendingFrame:
|
||||
self.stop_pending = False
|
||||
return QuicStopSendingFrame(
|
||||
error_code=self._stop_error_code,
|
||||
stream_id=self._stream_id,
|
||||
)
|
||||
|
||||
def starting_offset(self) -> int:
|
||||
return self._buffer_start
|
||||
|
||||
def handle_frame(
|
||||
self, frame: QuicStreamFrame
|
||||
) -> Optional[events.StreamDataReceived]:
|
||||
"""
|
||||
Handle a frame of received data.
|
||||
"""
|
||||
pos = frame.offset - self._buffer_start
|
||||
count = len(frame.data)
|
||||
frame_end = frame.offset + count
|
||||
|
||||
# we should receive no more data beyond FIN!
|
||||
if self._final_size is not None:
|
||||
if frame_end > self._final_size:
|
||||
raise FinalSizeError("Data received beyond final size")
|
||||
elif frame.fin and frame_end != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
if frame.fin:
|
||||
self._final_size = frame_end
|
||||
if frame_end > self.highest_offset:
|
||||
self.highest_offset = frame_end
|
||||
|
||||
# fast path: new in-order chunk
|
||||
if pos == 0 and count and not self._buffer:
|
||||
self._buffer_start += count
|
||||
if frame.fin:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
return events.StreamDataReceived(
|
||||
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
|
||||
)
|
||||
|
||||
# discard duplicate data
|
||||
if pos < 0:
|
||||
frame.data = frame.data[-pos:]
|
||||
frame.offset -= pos
|
||||
pos = 0
|
||||
count = len(frame.data)
|
||||
|
||||
# marked received range
|
||||
if frame_end > frame.offset:
|
||||
self._ranges.add(frame.offset, frame_end)
|
||||
|
||||
# add new data
|
||||
gap = pos - len(self._buffer)
|
||||
if gap > 0:
|
||||
self._buffer += bytearray(gap)
|
||||
self._buffer[pos : pos + count] = frame.data
|
||||
|
||||
# return data from the front of the buffer
|
||||
data = self._pull_data()
|
||||
end_stream = self._buffer_start == self._final_size
|
||||
if end_stream:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
if data or end_stream:
|
||||
return events.StreamDataReceived(
|
||||
data=data, end_stream=end_stream, stream_id=self._stream_id
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def handle_reset(
|
||||
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
|
||||
) -> Optional[events.StreamReset]:
|
||||
"""
|
||||
Handle an abrupt termination of the receiving part of the QUIC stream.
|
||||
"""
|
||||
if self._final_size is not None and final_size != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
|
||||
# we are done receiving
|
||||
self._final_size = final_size
|
||||
self.is_finished = True
|
||||
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
|
||||
|
||||
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
|
||||
"""
|
||||
Callback when a STOP_SENDING is ACK'd.
|
||||
"""
|
||||
if delivery != QuicDeliveryState.ACKED:
|
||||
self.stop_pending = True
|
||||
|
||||
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
|
||||
"""
|
||||
Request the peer stop sending data on the QUIC stream.
|
||||
"""
|
||||
self._stop_error_code = error_code
|
||||
self.stop_pending = True
|
||||
|
||||
def _pull_data(self) -> bytes:
|
||||
"""
|
||||
Remove data from the front of the buffer.
|
||||
"""
|
||||
try:
|
||||
has_data_to_read = self._ranges[0].start == self._buffer_start
|
||||
except IndexError:
|
||||
has_data_to_read = False
|
||||
if not has_data_to_read:
|
||||
return b""
|
||||
|
||||
r = self._ranges.shift()
|
||||
pos = r.stop - r.start
|
||||
data = bytes(self._buffer[:pos])
|
||||
del self._buffer[:pos]
|
||||
self._buffer_start = r.stop
|
||||
return data
|
||||
|
||||
|
||||
class QuicStreamSender:
|
||||
"""
|
||||
The send part of a QUIC stream.
|
||||
|
||||
It finishes:
|
||||
- immediately for a receive-only stream
|
||||
- upon acknowledgement of a STREAM_RESET frame
|
||||
- upon acknowledgement of a data frame with the FIN bit set
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: Optional[int], writable: bool) -> None:
|
||||
self.buffer_is_empty = True
|
||||
self.highest_offset = 0
|
||||
self.is_finished = not writable
|
||||
self.reset_pending = False
|
||||
|
||||
self._acked = RangeSet()
|
||||
self._acked_fin = False
|
||||
self._buffer = bytearray()
|
||||
self._buffer_fin: Optional[int] = None
|
||||
self._buffer_start = 0 # the offset for the start of the buffer
|
||||
self._buffer_stop = 0 # the offset for the stop of the buffer
|
||||
self._pending = RangeSet()
|
||||
self._pending_eof = False
|
||||
self._reset_error_code: Optional[int] = None
|
||||
self._stream_id = stream_id
|
||||
|
||||
@property
|
||||
def next_offset(self) -> int:
|
||||
"""
|
||||
The offset for the next frame to send.
|
||||
|
||||
This is used to determine the space needed for the frame's `offset` field.
|
||||
"""
|
||||
try:
|
||||
return self._pending[0].start
|
||||
except IndexError:
|
||||
return self._buffer_stop
|
||||
|
||||
def get_frame(
|
||||
self, max_size: int, max_offset: Optional[int] = None
|
||||
) -> Optional[QuicStreamFrame]:
|
||||
"""
|
||||
Get a frame of data to send.
|
||||
"""
|
||||
assert self._reset_error_code is None, "cannot call get_frame() after reset()"
|
||||
|
||||
# get the first pending data range
|
||||
try:
|
||||
r = self._pending[0]
|
||||
except IndexError:
|
||||
if self._pending_eof:
|
||||
# FIN only
|
||||
self._pending_eof = False
|
||||
return QuicStreamFrame(fin=True, offset=self._buffer_fin)
|
||||
|
||||
self.buffer_is_empty = True
|
||||
return None
|
||||
|
||||
# apply flow control
|
||||
start = r.start
|
||||
stop = min(r.stop, start + max_size)
|
||||
if max_offset is not None and stop > max_offset:
|
||||
stop = max_offset
|
||||
if stop <= start:
|
||||
return None
|
||||
|
||||
# create frame
|
||||
frame = QuicStreamFrame(
|
||||
data=bytes(
|
||||
self._buffer[start - self._buffer_start : stop - self._buffer_start]
|
||||
),
|
||||
offset=start,
|
||||
)
|
||||
self._pending.subtract(start, stop)
|
||||
|
||||
# track the highest offset ever sent
|
||||
if stop > self.highest_offset:
|
||||
self.highest_offset = stop
|
||||
|
||||
# if the buffer is empty and EOF was written, set the FIN bit
|
||||
if self._buffer_fin == stop:
|
||||
frame.fin = True
|
||||
self._pending_eof = False
|
||||
|
||||
return frame
|
||||
|
||||
def get_reset_frame(self) -> QuicResetStreamFrame:
|
||||
self.reset_pending = False
|
||||
return QuicResetStreamFrame(
|
||||
error_code=self._reset_error_code,
|
||||
final_size=self.highest_offset,
|
||||
stream_id=self._stream_id,
|
||||
)
|
||||
|
||||
def on_data_delivery(
|
||||
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
|
||||
) -> None:
|
||||
"""
|
||||
Callback when sent data is ACK'd.
|
||||
"""
|
||||
# If the frame had the FIN bit set, its end MUST match otherwise
|
||||
# we have a programming error.
|
||||
assert (
|
||||
not fin or stop == self._buffer_fin
|
||||
), "on_data_delivered() was called with inconsistent fin / stop"
|
||||
|
||||
# If a reset has been requested, stop processing data delivery.
|
||||
# The transition to the finished state only depends on the reset
|
||||
# being acknowledged.
|
||||
if self._reset_error_code is not None:
|
||||
return
|
||||
|
||||
if delivery == QuicDeliveryState.ACKED:
|
||||
if stop > start:
|
||||
# Some data has been ACK'd, discard it.
|
||||
self._acked.add(start, stop)
|
||||
first_range = self._acked[0]
|
||||
if first_range.start == self._buffer_start:
|
||||
size = first_range.stop - first_range.start
|
||||
self._acked.shift()
|
||||
self._buffer_start += size
|
||||
del self._buffer[:size]
|
||||
|
||||
if fin:
|
||||
# The FIN has been ACK'd.
|
||||
self._acked_fin = True
|
||||
|
||||
if self._buffer_start == self._buffer_fin and self._acked_fin:
|
||||
# All data and the FIN have been ACK'd, we're done sending.
|
||||
self.is_finished = True
|
||||
else:
|
||||
if stop > start:
|
||||
# Some data has been lost, reschedule it.
|
||||
self.buffer_is_empty = False
|
||||
self._pending.add(start, stop)
|
||||
|
||||
if fin:
|
||||
# The FIN has been lost, reschedule it.
|
||||
self.buffer_is_empty = False
|
||||
self._pending_eof = True
|
||||
|
||||
def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
|
||||
"""
|
||||
Callback when a reset is ACK'd.
|
||||
"""
|
||||
if delivery == QuicDeliveryState.ACKED:
|
||||
# The reset has been ACK'd, we're done sending.
|
||||
self.is_finished = True
|
||||
else:
|
||||
# The reset has been lost, reschedule it.
|
||||
self.reset_pending = True
|
||||
|
||||
def reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Abruptly terminate the sending part of the QUIC stream.
|
||||
"""
|
||||
assert self._reset_error_code is None, "cannot call reset() more than once"
|
||||
self._reset_error_code = error_code
|
||||
self.reset_pending = True
|
||||
|
||||
# Prevent any more data from being sent or re-sent.
|
||||
self.buffer_is_empty = True
|
||||
|
||||
def write(self, data: bytes, end_stream: bool = False) -> None:
|
||||
"""
|
||||
Write some data bytes to the QUIC stream.
|
||||
"""
|
||||
assert self._buffer_fin is None, "cannot call write() after FIN"
|
||||
assert self._reset_error_code is None, "cannot call write() after reset()"
|
||||
size = len(data)
|
||||
|
||||
if size:
|
||||
self.buffer_is_empty = False
|
||||
self._pending.add(self._buffer_stop, self._buffer_stop + size)
|
||||
self._buffer += data
|
||||
self._buffer_stop += size
|
||||
if end_stream:
|
||||
self.buffer_is_empty = False
|
||||
self._buffer_fin = self._buffer_stop
|
||||
self._pending_eof = True
|
||||
|
||||
|
||||
class QuicStream:
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: Optional[int] = None,
|
||||
max_stream_data_local: int = 0,
|
||||
max_stream_data_remote: int = 0,
|
||||
readable: bool = True,
|
||||
writable: bool = True,
|
||||
) -> None:
|
||||
self.is_blocked = False
|
||||
self.max_stream_data_local = max_stream_data_local
|
||||
self.max_stream_data_local_sent = max_stream_data_local
|
||||
self.max_stream_data_remote = max_stream_data_remote
|
||||
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
|
||||
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
|
||||
self.stream_id = stream_id
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
return self.receiver.is_finished and self.sender.is_finished
|
||||
2185
Code/venv/lib/python3.13/site-packages/aioquic/tls.py
Normal file
2185
Code/venv/lib/python3.13/site-packages/aioquic/tls.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user