from __future__ import annotations

import dataclasses
import inspect
import typing
from decimal import Context
from decimal import Decimal
from enum import IntEnum
from io import BufferedIOBase
from struct import Struct as _Struct
from struct import error as _StructError
from typing import (
    Annotated,
    Any,
    Callable,
    Generic,
    Iterable,
    Iterator,
    Literal,
    TypeVar,
    Union,
)
# Version string of this module.
__version__ = "0.5.0"
class BstructError(Exception):
    """
    Error raised by this module when packing or unpacking binary data fails.

    Errors coming from the underlying `struct` module are re-raised as this
    type so that callers only need to handle a single exception class.
    """
# `D` is the type variable used for dataclass types.  The `DataclassInstance`
# bound only exists in the type checker's stubs, so at runtime an unbounded
# type variable is used instead.
if typing.TYPE_CHECKING:
    from _typeshed import DataclassInstance

    D = TypeVar("D", bound=DataclassInstance)
else:
    D = TypeVar("D")

# Byte order accepted by every encode/decode entry point.
ByteOrder = Literal["big", "little"]

# Values that are handed to / received from the `struct` module as-is.
# `float` is part of the union because the native `f16`/`f32`/`f64` encodings
# pass floats straight through the raw encoder and decoder.
NativeValue = Union[bool, bytes, float, int]

# Stream of raw values produced by unpacking, consumed by the decoders.
ValueIterator = Iterator[NativeValue]

# Accumulator of raw values filled by the encoders before packing.
ValueList = list[NativeValue]

T = TypeVar("T")

# A decoder pulls as many raw values as it needs from the iterator and
# returns the decoded value.
Decoder = Callable[[ValueIterator, ByteOrder], T]

# An encoder appends the raw values representing the given value to the list.
Encoder = Callable[[T, ValueList, ByteOrder], None]
def _raw_encode(value: Any, attributes: ValueList, _byte_order: ByteOrder) -> None:
attributes.append(value)
def _raw_decode(attributes: ValueIterator, _byteorder: ByteOrder) -> Any:
return next(attributes)
class Encoding(Generic[T]):
    """
    Binary encoding of values of type `T`.

    An encoding pairs a `struct` format string (without byte-order prefix)
    with two functions: one that turns the unpacked raw attributes into a
    value, and one that turns a value into the raw attributes to pack.
    """

    def __init__(
        self,
        target: type[T],
        format: str,
        decode_attributes: Decoder[T],
        encode_attributes: Encoder[T],
    ):
        self.target = target
        self.format = format
        self.decode_attributes = decode_attributes
        self.encode_attributes = encode_attributes
        # Compile one struct per byte order up front; `get_struct` only
        # has to pick the right one afterwards.
        self.le_struct = _Struct(f"<{format}")
        self.be_struct = _Struct(f">{format}")
        self.size = self.le_struct.size

    def get_struct(self, byteorder: ByteOrder) -> _Struct:
        """
        Return the underlying `struct.Struct` instance used to (un)pack
        the binary data.

        The little-endian struct is returned for `"little"`; any other
        value selects the big-endian struct.
        """
        return self.le_struct if byteorder == "little" else self.be_struct

    def decode(self, data: bytes, byteorder: ByteOrder = "little") -> T:
        """
        Decode a single value from `data`.

        Errors raised by the `struct` module (for example when `data` does
        not have the size this encoding expects) are re-raised as
        `BstructError`.
        """
        try:
            unpacked = self.get_struct(byteorder).unpack(data)
            return self.decode_attributes(iter(unpacked), byteorder)
        except _StructError as error:
            raise BstructError(str(error)) from error

    def decode_all(self, data: bytes, byteorder: ByteOrder = "little") -> Iterator[T]:
        """
        Lazily decode consecutive values from `data`.

        This is a generator: nothing is unpacked, and no error is raised,
        until iteration starts.  Errors raised by the `struct` module are
        re-raised as `BstructError`.
        """
        try:
            for unpacked in self.get_struct(byteorder).iter_unpack(data):
                yield self.decode_attributes(iter(unpacked), byteorder)
        except _StructError as error:
            raise BstructError(str(error)) from error

    def read(self, buffer: BufferedIOBase, byteorder: ByteOrder = "little") -> T:
        """
        Read `self.size` bytes from `buffer` and decode them into a value.
        """
        return self.decode(buffer.read(self.size), byteorder)

    def read_many(
        self,
        buffer: BufferedIOBase,
        count: int,
        byteorder: ByteOrder = "little",
    ) -> list[T]:
        """
        Read `self.size * count` bytes from `buffer` in one call and decode
        them into a list of values.
        """
        payload = buffer.read(self.size * count)
        decoded = self.decode_all(payload, byteorder)
        return list(decoded)

    def encode(self, value: T, byteorder: ByteOrder = "little") -> bytes:
        """
        Encode `value` into bytes using the given byte order.

        Errors raised by the `struct` module while packing are re-raised
        as `BstructError`.
        """
        collected: list[Any] = []
        self.encode_attributes(value, collected, byteorder)
        packer = self.get_struct(byteorder)
        try:
            return packer.pack(*collected)
        except _StructError as error:
            raise BstructError(str(error)) from error

    def write(
        self, value: T, buffer: BufferedIOBase, byteorder: ByteOrder = "little"
    ) -> None:
        """
        Encode `value` and write the resulting bytes into `buffer`.
        """
        buffer.write(self.encode(value, byteorder))

    def write_many(
        self,
        items: Iterable[T],
        buffer: BufferedIOBase,
        byteorder: ByteOrder = "little",
    ) -> None:
        """
        Encode every element of `items` and write each one into `buffer`,
        in iteration order.
        """
        for element in items:
            buffer.write(self.encode(element, byteorder))

    def get_size(self) -> int:
        """
        Return the size in bytes of a value serialized with this encoding.
        """
        return self.size
class _NativeEncoding(Encoding[T]):
    """
    Encoding for values the `struct` module handles directly.

    The raw value is passed through unchanged in both directions.
    """

    def __init__(self, target: type[T], format: str):
        super().__init__(
            target,
            format,
            decode_attributes=_raw_decode,
            encode_attributes=_raw_encode,
        )
class Bytes(_NativeEncoding[bytes]):
    """
    Encoding for a `bytes` value of exactly `size` bytes (`struct` format
    `"<size>s"`).
    """

    def __init__(self, size: int):
        assert size > 0
        bytes_format = f"{size}s"
        super().__init__(bytes, bytes_format)
class Array:
    """
    Annotation marker that declares the length of a list field.
    """

    def __init__(self, length: int):
        # Only strictly positive lengths are accepted.
        assert length > 0
        self.length = length
class CustomEncoding(Encoding[T]):
    """
    Encoding whose conversion between raw attributes and values is done by
    user-supplied functions.
    """

    def __init__(
        self,
        target: type[T],
        format: str,
        decode_attributes: Decoder[T],
        encode_attributes: Encoder[T],
    ):
        super().__init__(
            target,
            format,
            decode_attributes=decode_attributes,
            encode_attributes=encode_attributes,
        )

    @staticmethod
    def create(
        target: type[T],
        fields: Iterable[_NativeEncoding[Any]],
        decode_attributes: Decoder[T],
        encode_attributes: Encoder[T],
    ) -> CustomEncoding[T]:
        """
        Build a `CustomEncoding` whose binary layout is the given native
        `fields`, in order.

        The format string is the concatenation of the fields' own format
        strings.
        """
        # The format is assembled here directly so that this method does
        # not depend on a module-level helper being defined before the
        # first `create()` call.
        combined_format = "".join(field.format for field in fields)
        return CustomEncoding(
            target,
            format=combined_format,
            decode_attributes=decode_attributes,
            encode_attributes=encode_attributes,
        )
class String(CustomEncoding[str]):
    """
    Encoding for a `str` stored as UTF-8 in a fixed-size field of `size`
    bytes.
    """

    def __init__(self, size: int):
        # Reuse the `Bytes` encoding to obtain the format of the field.
        bytes_format = Bytes(size).format
        super().__init__(
            target=str,
            format=bytes_format,
            decode_attributes=_decode_str,
            encode_attributes=_encode_str,
        )
def _decode_str(attributes: ValueIterator, _byteorder: ByteOrder) -> str:
value = next(attributes)
assert isinstance(value, bytes)
return value.rstrip(b"\0").decode("utf-8")
def _encode_str(value: str, attributes: ValueList, _byteorder: ByteOrder) -> None:
# The `struct` library automatically adds zeros to the end of
# the encoded string to fill the necessary `data.size` bytes.
data = value.encode("utf-8")
attributes.append(data)
def _decode_uint(attributes: ValueIterator, byteorder: ByteOrder) -> int:
value = next(attributes)
assert isinstance(value, bytes)
return int.from_bytes(value, byteorder, signed=False)
def _encode_uint128(value: int, attributes: ValueList, byteorder: ByteOrder) -> None:
data = value.to_bytes(16, byteorder, signed=False)
attributes.append(data)
def _encode_uint256(value: int, attributes: ValueList, byteorder: ByteOrder) -> None:
data = value.to_bytes(32, byteorder, signed=False)
attributes.append(data)
def _decode_int(attributes: ValueIterator, byteorder: ByteOrder) -> int:
value = next(attributes)
assert isinstance(value, bytes)
return int.from_bytes(value, byteorder, signed=True)
def _encode_int128(value: int, attributes: ValueList, byteorder: ByteOrder) -> None:
data = value.to_bytes(16, byteorder, signed=True)
attributes.append(data)
def _encode_int256(value: int, attributes: ValueList, byteorder: ByteOrder) -> None:
data = value.to_bytes(32, byteorder, signed=True)
attributes.append(data)
# The last 6 bytes (== 6 * 8 == 48 bits) of the 16 byte value is the fractional part.
# Therefore, divide by 2^48.
I80F48_DIVISOR = Decimal(2**48)
def _decode_I80F48(attributes: ValueIterator, byte_order: ByteOrder) -> Decimal:
value = next(attributes)
assert isinstance(value, bytes)
return int.from_bytes(value, byte_order, signed=True) / I80F48_DIVISOR
def _encode_I80F48(
value: Decimal, attributes: ValueList, byte_order: ByteOrder
) -> None:
data = int(value * I80F48_DIVISOR).to_bytes(16, byte_order, signed=True)
attributes.append(data)
class Encodings:
    """
    Namespace holding the predefined encodings.

    The encodings backed by a native `struct` format pass values through
    unchanged; the 128/256 bit integers and `I80F48` are stored as raw
    bytes and converted by dedicated functions.
    """

    bool: _NativeEncoding[bool] = _NativeEncoding(bool, format="?")

    # Unsigned integers.
    u8 = _NativeEncoding(int, format="B")
    u16 = _NativeEncoding(int, format="H")
    u32 = _NativeEncoding(int, format="I")
    u64 = _NativeEncoding(int, format="Q")
    u128 = CustomEncoding.create(
        int,
        fields=[Bytes(16)],
        decode_attributes=_decode_uint,
        encode_attributes=_encode_uint128,
    )
    u256 = CustomEncoding.create(
        int,
        fields=[Bytes(32)],
        decode_attributes=_decode_uint,
        encode_attributes=_encode_uint256,
    )

    # Signed integers.
    i8 = _NativeEncoding(int, format="b")
    i16 = _NativeEncoding(int, format="h")
    i32 = _NativeEncoding(int, format="i")
    i64 = _NativeEncoding(int, format="q")
    i128 = CustomEncoding.create(
        int,
        fields=[Bytes(16)],
        decode_attributes=_decode_int,
        encode_attributes=_encode_int128,
    )
    i256 = CustomEncoding.create(
        int,
        fields=[Bytes(32)],
        decode_attributes=_decode_int,
        encode_attributes=_encode_int256,
    )

    # Floating point numbers.
    f16 = _NativeEncoding(float, format="e")
    f32 = _NativeEncoding(float, format="f")
    f64 = _NativeEncoding(float, format="d")

    # Fixed-point number stored in 16 bytes.
    I80F48 = CustomEncoding.create(
        Decimal,
        fields=[Bytes(16)],
        decode_attributes=_decode_I80F48,
        encode_attributes=_encode_I80F48,
    )
# Type aliases for annotating fields: each alias pairs the Python value type
# with the matching entry of `Encodings` as `Annotated` metadata.

# Unsigned integers.
u8 = Annotated[int, Encodings.u8]
u16 = Annotated[int, Encodings.u16]
u32 = Annotated[int, Encodings.u32]
u64 = Annotated[int, Encodings.u64]
u128 = Annotated[int, Encodings.u128]
u256 = Annotated[int, Encodings.u256]
# Signed integers.
i8 = Annotated[int, Encodings.i8]
i16 = Annotated[int, Encodings.i16]
i32 = Annotated[int, Encodings.i32]
i64 = Annotated[int, Encodings.i64]
i128 = Annotated[int, Encodings.i128]
i256 = Annotated[int, Encodings.i256]
# Floating point numbers.
f16 = Annotated[float, Encodings.f16]
f32 = Annotated[float, Encodings.f32]
f64 = Annotated[float, Encodings.f64]
# Fixed-point number decoded as a `Decimal`.
I80F48 = Annotated[Decimal, Encodings.I80F48]
def derive(attribute_type: type[T]) -> Encoding[T]:
    """
    Derive the `Encoding` for `attribute_type`.

    This is the typed public entry point; the resolution itself is
    delegated to `_derive`.
    """
    encoding: Encoding[T] = _derive(attribute_type)
    return encoding
def _derive(
attribute_type: Any,
) -> Encoding[Any]:
if isinstance(attribute_type, str):
raise Exception(
"Do not use 'from __future__ import annotations' in the same file in which 'binary.derive()' is used."
)
if attribute_type is bool:
return Encodings.bool
elif inspect.isclass(attribute_type) and dataclasses.is_dataclass(attribute_type):
return _resolve_dataclass_encoding(attribute_type)
elif typing.get_origin(attribute_type) is typing.Annotated:
annotated_type, *annotation_args = typing.get_args(attribute_type)
if inspect.isclass(annotated_type) and issubclass(annotated_type, IntEnum):
return _resolve_int_enum_encoding(annotated_type, annotation_args)
elif annotated_type is list:
raise TypeError("list is missing inner type")
elif typing.get_origin(annotated_type) is list:
return _resolve_array_encoding(annotated_type, annotation_args)
else:
return _resolve_simple_encoding(annotated_type, annotation_args)
else:
raise TypeError(f"Missing annotation for type {attribute_type.__name__}")
def _encode_native_list(
l: list[NativeValue], attributes: ValueList, _byteorder: ByteOrder
) -> None:
attributes.extend(l)
def _resolve_array_encoding(
    array_type: Any, metadata: list[Any]
) -> CustomEncoding[list[Any]]:
    """
    Build the encoding for a fixed-length list type.

    `array_type` is the parameterized list type (e.g. `list[X]`) and
    `metadata` the `Annotated` metadata, which must contain an `Array`
    giving the length.  The resulting format is the element format repeated
    `length` times.

    The returned encoder raises `BstructError` when the value does not
    contain exactly `length` items.  Without that check a list that is too
    long in one field and one that is too short in another could be packed
    into a silently shifted layout.
    """
    array_length = _resolve_array_length(metadata)
    array_type_args = typing.get_args(array_type)
    assert len(array_type_args) == 1
    inner_type = array_type_args[0]
    inner_encoding = derive(inner_type)
    array_format = inner_encoding.format * array_length
    inner_decode = inner_encoding.decode_attributes
    inner_encode = inner_encoding.encode_attributes

    def _check_length(count: int) -> None:
        if count != array_length:
            raise BstructError(
                f"Wrong array length: expected {array_length} items, got {count}"
            )

    def _decode(attributes: Iterator[Any], byte_order: ByteOrder) -> list[Any]:
        return [inner_decode(attributes, byte_order) for _ in range(array_length)]

    # Optimization: If the encoding function of the element is just inserting the raw element,
    # use a batch insert to speed things up.
    if inner_encode == _raw_encode:

        def _encode(
            items: list[Any], attributes: list[Any], byte_order: ByteOrder
        ) -> None:
            # Measure the growth of `attributes` rather than `len(items)`
            # so that any iterable keeps working.
            start = len(attributes)
            attributes.extend(items)
            _check_length(len(attributes) - start)

    else:

        def _encode(
            items: list[Any], attributes: list[Any], byte_order: ByteOrder
        ) -> None:
            # Count the items themselves: one item may contribute several
            # raw attributes.
            count = 0
            for item in items:
                inner_encode(item, attributes, byte_order)
                count += 1
            _check_length(count)

    return CustomEncoding(array_type, array_format, _decode, _encode)
def _resolve_array_length(metadata: list[Any]) -> int:
for data in metadata:
if isinstance(data, Array):
return data.length
raise TypeError("Cannot find length annotation for list")
I = TypeVar("I", bound=IntEnum)


def _resolve_int_enum_encoding(cls: type[I], metadata: list[Any]) -> CustomEncoding[I]:
    """
    Build the encoding for the `IntEnum` type `cls`.

    The integer encoding found in `metadata` defines the binary layout.
    Decoded integers are converted into members of `cls`; encoding uses the
    integer encoder as-is.
    """
    int_encoding = _resolve_simple_encoding(int, metadata)

    def _decode_member(attributes: Any, byteorder: Any) -> I:
        return cls(int_encoding.decode_attributes(attributes, byteorder))

    return CustomEncoding(
        cls,
        format=int_encoding.format,
        decode_attributes=_decode_member,
        encode_attributes=int_encoding.encode_attributes,
    )
def _resolve_simple_encoding(target: type[T], metadata: list[Any]) -> Encoding[T]:
for data in metadata:
if isinstance(data, (_NativeEncoding, CustomEncoding)):
data: Encoding[Any] = data
if data.target is not target:
raise TypeError(
f"Wrong encoding: Expected Encoding[{target.__name__}], got Encoding[{data.target.__name__}]"
)
else:
return data
raise TypeError(f"Cannot find encoding for type {target.__name__}")
def _resolve_dataclass_encoding(cls: type[D]) -> CustomEncoding[D]:
    """
    Build the encoding for the dataclass `cls`.

    Every field's encoding is derived from its type; the formats are
    concatenated in field order.  Decoding passes the decoded attributes to
    `cls` positionally, and encoding reads each attribute from the instance
    by field name.
    """
    fields = dataclasses.fields(cls)
    field_encodings = [derive(field.type) for field in fields]

    full_format = "".join(encoding.format for encoding in field_encodings)
    attribute_decoders: list[Decoder[Any]] = [
        encoding.decode_attributes for encoding in field_encodings
    ]
    attribute_encoders: list[Encoder[Any]] = [
        encoding.encode_attributes for encoding in field_encodings
    ]
    attribute_names = [field.name for field in fields]

    def _decode(raw_attributes: ValueIterator, byteorder: ByteOrder) -> D:
        decoded = [decode(raw_attributes, byteorder) for decode in attribute_decoders]
        return cls(*decoded)

    def _encode(value: D, raw_attributes: ValueList, byteorder: ByteOrder) -> None:
        for name, encode in zip(attribute_names, attribute_encoders):
            encode(getattr(value, name), raw_attributes, byteorder)

    return CustomEncoding(
        cls,
        format=full_format,
        decode_attributes=_decode,
        encode_attributes=_encode,
    )