Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.9' ]
python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.9' ]
exclude:
- os: windows-latest
python-version: pypy3.9
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ sphinx:
build:
os: "ubuntu-22.04"
tools:
python: "3.8"
python: "3.13"

python:
install:
Expand Down
106 changes: 69 additions & 37 deletions construct/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@
from construct.lib import *
from construct.expr import *
from construct.version import *
import logging

def _emit_function_expression_or_const(code, func, parameters="this"):
if isinstance(func, ExprMixin) or (not callable(func)):
return repr(func)
else:
aid = code.allocateId()
code.userfunction[aid] = func
return f"userfunction[{aid}]({parameters})"

def _emit_source_or_use_linked(code, obj):
try:
if eval(repr(obj)) == obj:
return repr(obj)
except Exception as _:
aid = code.allocateId()
code.userfunction[aid] = obj
return f"userfunction[{aid}]"


#===============================================================================
Expand Down Expand Up @@ -980,7 +998,7 @@ def _sizeof(self, context, path):
raise SizeofError("cannot calculate size, key not found in context", path=path)

def _emitparse(self, code):
return f"io.read({self.length})"
return f"io.read({_emit_function_expression_or_const(code, self.length)})"

def _emitbuild(self, code):
return f"(io.write(obj), obj)[1]"
Expand Down Expand Up @@ -1916,6 +1934,17 @@ def new(intvalue, stringvalue):
ret.intvalue = intvalue
return ret

def __eq__(self, other):
if isinstance(other, int):
return (self.intvalue == other)
elif isinstance(other, type(self)):
return (self.intvalue == other.intvalue)
elif isinstance(other, str):
return str(self) == other
raise NotImplementedError(f"Cont compare {type(self)} to {type(other)} {other}")

def __hash__(self):
return self.intvalue

class Enum(Adapter):
r"""
Expand Down Expand Up @@ -1966,9 +1995,12 @@ def __init__(self, subcon, *merge, **mapping):
for enum in merge:
for enumentry in enum:
mapping[enumentry.name] = enumentry.value
self.encmapping = {EnumIntegerString.new(v,k):v for k,v in mapping.items()}
self.decmapping = {v:EnumIntegerString.new(v,k) for k,v in mapping.items()}
self.ksymapping = {v:k for k,v in mapping.items()}
encmappingFromNames = {k:v for k, v in mapping.items()}
encmappingFromValues = {v:v for _, v in mapping.items()}

self.encmapping = {**encmappingFromNames, **encmappingFromValues}
self.decmapping = {v:EnumIntegerString.new(v,k) for k, v in mapping.items()}
self.ksymapping = {v:k for k, v in mapping.items()}

def __getattr__(self, name):
if name in self.encmapping:
Expand All @@ -1981,6 +2013,9 @@ def _decode(self, obj, context, path):
except KeyError:
return EnumInteger(obj)

def _emitdecode(self, code):
return f"{_emit_source_or_use_linked(self.decmapping)}.get(obj, EnumInteger(obj))"

def _encode(self, obj, context, path):
try:
if isinstance(obj, int):
Expand All @@ -1989,22 +2024,14 @@ def _encode(self, obj, context, path):
except KeyError:
raise MappingError("building failed, no mapping for %r" % (obj,), path=path)

def _emitparse(self, code):
fname = f"factory_{code.allocateId()}"
code.append(f"{fname} = {repr(self.decmapping)}")
return f"reuse(({self.subcon._compileparse(code)}), lambda x: {fname}.get(x, EnumInteger(x)))"

def _emitbuild(self, code):
fname = f"factory_{code.allocateId()}"
code.append(f"{fname} = {repr(self.encmapping)}")
return f"reuse({fname}.get(obj, obj), lambda obj: ({self.subcon._compilebuild(code)}))"
def _emitencode(self, code):
return f"{_emit_source_or_use_linked(self.encmapping)}.get(obj, EnumInteger(obj))"

def _emitprimitivetype(self, ksy, bitwise):
name = "enum_%s" % ksy.allocateId()
ksy.enums[name] = self.ksymapping
return name


class BitwisableString(str):
"""Used internally."""

Expand Down Expand Up @@ -2096,9 +2123,6 @@ def _encode(self, obj, context, path):
except KeyError:
raise MappingError("building failed, unknown label: %r" % (obj,), path=path)

def _emitparse(self, code):
return f"reuse(({self.subcon._compileparse(code)}), lambda x: Container({', '.join(f'{k}=bool(x & {v} == {v})' for k,v in self.flags.items()) }))"

def _emitseq(self, ksy, bitwise):
bitstotal = self.subcon.sizeof() * 8
seq = []
Expand Down Expand Up @@ -2145,15 +2169,11 @@ def _encode(self, obj, context, path):
except (KeyError, TypeError):
raise MappingError("building failed, no encoding mapping for %r" % (obj,), path=path)

def _emitparse(self, code):
fname = f"factory_{code.allocateId()}"
code.append(f"{fname} = {repr(self.decmapping)}")
return f"{fname}[{self.subcon._compileparse(code)}]"
def _emitdecode(self, code):
return f"{_emit_source_or_use_linked(code, self.decmapping)}[obj]"

def _emitbuild(self, code):
fname = f"factory_{code.allocateId()}"
code.append(f"{fname} = {repr(self.encmapping)}")
return f"reuse({fname}[obj], lambda obj: ({self.subcon._compilebuild(code)}))"
def _emitencode(self, code):
return f"{_emit_source_or_use_linked(code, self.encmapping)}[obj]"


#===============================================================================
Expand Down Expand Up @@ -2924,10 +2944,10 @@ def _sizeof(self, context, path):
return 0

def _emitparse(self, code):
return repr(self.func)
return _emit_function_expression_or_const(code, self.func)

def _emitbuild(self, code):
return repr(self.func)
return _emit_function_expression_or_const(code, self.func)


@singleton
Expand Down Expand Up @@ -3119,14 +3139,14 @@ def _emitparse(self, code):
def parse_check(condition):
if not condition: raise CheckError
""")
return f"parse_check({repr(self.func)})"
return f"parse_check({_emit_function_expression_or_const(code, self.func)})"

def _emitbuild(self, code):
code.append(f"""
def build_check(condition):
if not condition: raise CheckError
""")
return f"build_check({repr(self.func)})"
return f"build_check({_emit_function_expression_or_const(code, self.func)})"


@singleton
Expand Down Expand Up @@ -3987,10 +4007,10 @@ def _sizeof(self, context, path):
return sc._sizeof(context, path)

def _emitparse(self, code):
return "((%s) if (%s) else (%s))" % (self.thensubcon._compileparse(code), self.condfunc, self.elsesubcon._compileparse(code), )
return "((%s) if (%s) else (%s))" % (self.thensubcon._compileparse(code), _emit_function_expression_or_const(code, self.condfunc), self.elsesubcon._compileparse(code), )

def _emitbuild(self, code):
return f"(({self.thensubcon._compilebuild(code)}) if ({repr(self.condfunc)}) else ({self.elsesubcon._compilebuild(code)}))"
return f"(({self.thensubcon._compilebuild(code)}) if ({_emit_function_expression_or_const(code, self.condfunc)}) else ({self.elsesubcon._compilebuild(code)}))"

def _emitseq(self, ksy, bitwise):
return [
Expand Down Expand Up @@ -4060,20 +4080,30 @@ def _sizeof(self, context, path):
def _emitparse(self, code):
fname = f"switch_cases_{code.allocateId()}"
code.append(f"{fname} = {{}}")
selector = ""
for key,sc in self.cases.items():
code.append(f"{fname}[{repr(key)}] = lambda io,this: {sc._compileparse(code)}")
if isinstance(key, EnumIntegerString):
selector = ".intvalue"
code.append(f"{fname}[{int(key)}] = lambda io,this: {sc._compileparse(code)}")
else:
code.append(f"{fname}[{_emit_source_or_use_linked(code, key)}] = lambda io,this: {sc._compileparse(code)}")
defaultfname = f"switch_defaultcase_{code.allocateId()}"
code.append(f"{defaultfname} = lambda io,this: {self.default._compileparse(code)}")
return f"{fname}.get({repr(self.keyfunc)}, {defaultfname})(io, this)"
return f"{fname}.get({_emit_source_or_use_linked(code, self.keyfunc)}{selector}, {defaultfname})(io, this)"

def _emitbuild(self, code):
fname = f"switch_cases_{code.allocateId()}"
code.append(f"{fname} = {{}}")
selector = ""
for key,sc in self.cases.items():
code.append(f"{fname}[{repr(key)}] = lambda obj,io,this: {sc._compilebuild(code)}")
if isinstance(key, EnumIntegerString):
selector = ".intvalue"
code.append(f"{fname}[{int(key)}] = lambda obj,io,this: {sc._compilebuild(code)}")
else:
code.append(f"{fname}[{_emit_source_or_use_linked(code, key)}] = lambda obj,io,this: {sc._compilebuild(code)}")
defaultfname = f"switch_defaultcase_{code.allocateId()}"
code.append(f"{defaultfname} = lambda obj,io,this: {self.default._compilebuild(code)}")
return f"{fname}.get({repr(self.keyfunc)}, {defaultfname})(obj, io, this)"
return f"{fname}.get({_emit_source_or_use_linked(code, self.keyfunc)}{selector}, {defaultfname})(obj, io, this)"


class StopIf(Construct):
Expand Down Expand Up @@ -4249,9 +4279,11 @@ def _sizeof(self, context, path):
raise SizeofError("cannot calculate size, key not found in context", path=path)

def _emitparse(self, code):
assert isinstance(self.length, int), "Padding needs to be known at compile time"
return f"({self.subcon._compileparse(code)}, io.read(({self.length})-({self.subcon.sizeof()}) ))[0]"

def _emitbuild(self, code):
assert isinstance(self.length, int), "Padding needs to be known at compile time"
return f"({self.subcon._compilebuild(code)}, io.write({repr(self.pattern)}*(({self.length})-({self.subcon.sizeof()})) ))[0]"

def _emitfulltype(self, ksy, bitwise):
Expand Down Expand Up @@ -4442,7 +4474,7 @@ def parse_pointer(io, offset, func):
io.seek(fallback)
return obj
""")
return f"parse_pointer(io, {self.offset}, lambda: {self.subcon._compileparse(code)})"
return f"parse_pointer(io, {_emit_function_expression_or_const(code, self.offset)}, lambda: {self.subcon._compileparse(code)})"

def _emitbuild(self, code):
code.append(f"""
Expand All @@ -4453,7 +4485,7 @@ def build_pointer(obj, io, offset, func):
io.seek(fallback)
return ret
""")
return f"build_pointer(obj, io, {self.offset}, lambda: {self.subcon._compilebuild(code)})"
return f"build_pointer(obj, io, {_emit_function_expression_or_const(code, self.offset)}, lambda: {self.subcon._compilebuild(code)})"

def _emitprimitivetype(self, ksy, bitwise):
offset = self.offset.__getfield__() if callable(self.offset) else self.offset
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
author = "Arkadiusz Bulski, Tomer Filiba, Corbin Simpson",
author_email = "arek.bulski@gmail.com, tomerfiliba@gmail.com, MostAwesomeDude@gmail.com",
python_requires = ">=3.8",
python_requires = ">=3.9",
install_requires = [],
extras_require = {
"extras": [
Expand Down
20 changes: 9 additions & 11 deletions tests/declarativeunittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,33 @@ def raises(func, *args, **kw):
except Exception as e:
return e.__class__

def common(format, datasample, objsample, sizesample=SizeofError, **kw):
def common(format, datasample, objsample, sizesample=SizeofError, compile=True, **kw):
# following are implied (re-parse and re-build)
# assert format.parse(format.build(obj)) == obj
# assert format.build(format.parse(data)) == data
obj = format.parse(datasample, **kw)
assert obj == objsample
assert obj == objsample, f"{obj} != {objsample} (obj, non compiled)"
data = format.build(objsample, **kw)
assert data == datasample
assert data == datasample, f"{data} != {datasample} (data, non compiled)"

if isinstance(sizesample, int):
size = format.sizeof(**kw)
assert size == sizesample
assert size == sizesample, f"{size} != {sizesample} (size (int), non compiled)"
else:
size = raises(format.sizeof, **kw)
assert size == sizesample
assert size == sizesample, f"{size} != {sizesample} (size, non compiled)"

# attemps to compile, ignores if compilation fails
# following was added to test compiling functionality
# and implies: format.parse(data) == cformat.parse(data)
# and implies: format.build(obj) == cformat.build(obj)
try:
if compile:
cformat = format.compile()
except Exception:
pass
else:

obj = cformat.parse(datasample, **kw)
assert obj == objsample
assert obj == objsample, f"{obj} != {objsample} (obj, compiled)"
data = cformat.build(objsample, **kw)
assert data == datasample
assert data == datasample, f"{data} != {datasample} (data, compiled)"

def commonhex(format, hexdata):
commonbytes(format, binascii.unhexlify(hexdata))
Expand Down
14 changes: 7 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,7 +1690,7 @@ def test_operators_issue_87():

def test_from_issue_76():
d = Aligned(4, Struct("a"/Byte, "f"/Bytes(lambda ctx: ctx.a)))
common(d, b"\x02\xab\xcd\x00", Container(a=2, f=b"\xab\xcd"))
common(d, b"\x02\xab\xcd\x00", Container(a=2, f=b"\xab\xcd"), compile=False)

def test_from_issue_60():
Header = Struct(
Expand Down Expand Up @@ -1973,21 +1973,21 @@ def test_exposing_members_context():
"data" / Bytes(lambda this: this.count - this._subcons.count.sizeof()),
Check(lambda this: this._subcons.count.sizeof() == 1),
)
common(d, b"\x05four", Container(count=5, data=b"four"))
common(d, b"\x05four", Container(count=5, data=b"four"), compile=False)

d = Sequence(
"count" / Byte,
"data" / Bytes(lambda this: this.count - this._subcons.count.sizeof()),
Check(lambda this: this._subcons.count.sizeof() == 1),
)
common(d, b"\x05four", [5,b"four",None])
common(d, b"\x05four", [5,b"four",None], compile=False)

d = FocusedSeq("count",
"count" / Byte,
"data" / Padding(lambda this: this.count - this._subcons.count.sizeof()),
Check(lambda this: this._subcons.count.sizeof() == 1),
)
common(d, b'\x04\x00\x00\x00', 4, SizeofError)
common(d, b'\x04\x00\x00\x00', 4, SizeofError, compile=False)

d = Union(None,
"chars" / Byte[4],
Expand Down Expand Up @@ -2366,9 +2366,9 @@ def test_switch_issue_913_using_enum():
}

d = Switch(keyfunc = this.x, cases = mapping)
common(d, b"", None, 0, x="Zero")
common(d, b"\xab", 171, 1, x="One")
common(d, b"\x09\x00", 9, 2, x="Two")
common(d, b"", None, 0, x=enum.Zero)
common(d, b"\xab", 171, 1, x=enum.One)
common(d, b"\x09\x00", 9, 2, x=enum.Two)

def test_switch_issue_913_using_strings():
mapping = {
Expand Down