Skip to content

Commit 2e9ad0a

Browse files
committed
feat: make the c struct def parser more flexible to types and includes
1 parent e57c629 commit 2e9ad0a

1 file changed

Lines changed: 65 additions & 5 deletions

File tree

libdestruct/c/struct_parser.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from __future__ import annotations
88

99
import ctypes
10+
import re
11+
import subprocess
12+
import tempfile
1013
from typing import TYPE_CHECKING
1114

1215
from pycparser import c_ast, c_parser
@@ -21,12 +24,24 @@ def definition_to_type(definition: str) -> type[obj]:
2124
"""Converts a C struct definition to a struct object."""
2225
parser = c_parser.CParser()
2326

24-
ast = parser.parse(definition)
27+
# If the definition contains includes, we must expand them.
28+
if "#include" in definition:
29+
definition = cleanup_attributes(expand_includes(definition))
30+
force_more_tops = True
31+
elif "typedef" in definition:
32+
force_more_tops = True
2533

26-
if len(ast.ext) != 1:
34+
try:
35+
ast = parser.parse(definition)
36+
except c_parser.ParseError as e:
37+
raise ValueError("Invalid definition. Please add the necessary includes if using non-standard type definitions.") from e
38+
39+
if not force_more_tops and len(ast.ext) != 1:
2740
raise ValueError("Definition must contain exactly one top object.")
2841

29-
root = ast.ext[0].type
42+
# If force_more_tops is True, we take the last top object.
43+
# This is useful when a struct definition is preceded by typedefs.
44+
root = ast.ext[-1].type if force_more_tops else ast.ext[0].type
3045

3146
if not isinstance(root, c_ast.Struct):
3247
raise TypeError("Definition must be a struct.")
@@ -65,14 +80,59 @@ def type_decl_to_type(decl: c_ast.TypeDecl) -> type[obj]:
6580
raise TypeError("Unsupported type.")
6681

6782

83+
def to_uniform_name(name: str) -> str:
    """Rewrite a C type spelling into the suffix of the matching ctypes name.

    The input is expected to be the type's tokens already glued together
    (e.g. ``"unsignedint"``); the result is meant to be prefixed with ``c_``
    by the caller.

    Args:
        name: The concatenated C type name.

    Returns:
        The normalised name, e.g. ``"unsignedint"`` -> ``"uint"`` or
        ``"uint16_t"`` -> ``"ushort"``.
    """
    # Applied strictly in this order: "unsigned" has to collapse to "u" before
    # the "uchar" rule can fire, and "uint8_t" has to be handled before the
    # generic "int8_t" rule would turn it into "uchar".
    rewrites = (
        ("unsigned", "u"),
        ("_Bool", "bool"),
        ("uchar", "ubyte"),  # uchar is not a valid ctypes type
        # Fixed-width aliases are folded back onto char, short, int, ...
        ("uint8_t", "ubyte"),
        ("int8_t", "char"),
        ("int16_t", "short"),
        ("int32_t", "int"),
        ("int64_t", "longlong"),
    )
    for old, new in rewrites:
        name = name.replace(old, new)

    # Only size_t, ssize_t and time_t may keep their "_t"; checking for "size"
    # also covers "ssize".
    keeps_suffix = "size" in name or "time" in name
    return name if keeps_suffix else name.replace("_t", "")
101+
102+
103+
def expand_includes(definition: str) -> str:
    """Run the system C preprocessor over a definition and return its output.

    The text is written to a temporary ``.c`` file which is then handed to
    ``cc -std=c99 -E``. Because the command runs with ``check=True``, a
    non-zero exit status raises ``subprocess.CalledProcessError``.

    Args:
        definition: C source text, possibly containing ``#include`` lines.

    Returns:
        Whatever the preprocessor wrote to standard output.
    """
    # TODO: cache this result between subsequent runs of the same script
    with tempfile.NamedTemporaryFile(mode="w", suffix=".c") as source_file:
        source_file.write(definition)
        # Push the buffered text to disk before the compiler opens the file by name.
        source_file.flush()

        preprocessed = subprocess.run(["cc", "-std=c99", "-E", source_file.name], capture_output=True, text=True, check=True)  # noqa: S607
        return preprocessed.stdout
113+
114+
115+
def cleanup_attributes(definition: str) -> str:
    """Strip GCC-style ``__attribute__((...))`` specifiers from a C definition.

    The attribute body is located with a linear balanced-parenthesis scan
    rather than a single regular expression. A regex built from nested
    quantifiers backtracks exponentially when an opener cannot be matched
    (for instance an unterminated attribute) and can only describe a fixed
    nesting depth; the scan has neither limitation.

    Args:
        definition: C source text, typically preprocessor output.

    Returns:
        ``definition`` with every balanced ``__attribute__((...))`` removed.
        An attribute whose parentheses never close is left untouched.
    """
    pieces = []
    copied_upto = 0

    for opener in re.finditer(r"__attribute__\s*\(\s*\(", definition):
        if opener.start() < copied_upto:
            # This opener sits inside an attribute that was already removed.
            continue

        # The opener consumed two "(", so scan until both of them are closed.
        depth = 2
        position = opener.end()
        while depth and position < len(definition):
            char = definition[position]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
            position += 1

        if depth:
            # Unterminated attribute: keep the text exactly as it is.
            continue

        pieces.append(definition[copied_upto:opener.start()])
        copied_upto = position

    pieces.append(definition[copied_upto:])
    return "".join(pieces)
120+
121+
68122
def identifier_to_type(identifier: c_ast.IdentifierType) -> type[obj]:
69123
"""Converts a C identifier to a type."""
70124
if not isinstance(identifier, c_ast.IdentifierType):
71125
raise TypeError("Definition must be an identifier.")
72126

73-
identifier_name = "_".join(identifier.names)
127+
identifier_name = "".join(identifier.names)
128+
129+
ctypes_name = "c_" + identifier_name
130+
131+
if hasattr(ctypes, ctypes_name):
132+
return getattr(ctypes, ctypes_name)
74133

75-
ctypes_name = f"c_{identifier_name}"
134+
# Convert the identifier name to a uniform name, e.g., "unsigned int" -> "uint".
135+
ctypes_name = "c_" + to_uniform_name(identifier_name)
76136

77137
if hasattr(ctypes, ctypes_name):
78138
return getattr(ctypes, ctypes_name)

0 commit comments

Comments
 (0)