Skip to content

Commit e3254be

Browse files
authored
Add x86-preprocessor
Preprocessor for the x86_64 arch
2 parents 32f114a + e17f247 commit e3254be

2 files changed

Lines changed: 351 additions & 0 deletions

File tree

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import re
from collections.abc import Callable, Iterator

from asmtransformers.operands import is_offset


# Mnemonics that can transfer control: their offset operands are rewritten
# into JUMP_ADDR_* tokens by the preprocessor below.
# based mostly on: https://en.wikipedia.org/wiki/List_of_x86_instructions
BRANCH_INSTRUCTIONS = (
    # unconditional jump and call/return
    'jmp',
    'call',
    'ret',
    'retn',
    'retf',
    # equal / zero
    'je',
    'jz',
    # not equal / not zero
    'jne',
    'jnz',
    # signed less / less-or-equal
    'jl',
    'jnge',
    'jle',
    'jng',
    # signed greater / greater-or-equal
    'jg',
    'jnle',
    'jge',
    'jnl',
    # unsigned above / above-or-equal (carry clear)
    'ja',
    'jnbe',
    'jae',
    'jnb',
    # unsigned below / below-or-equal (carry set)
    'jb',
    'jnae',
    'jbe',
    'jna',
    # sign / overflow
    'js',
    'jns',
    'jo',
    'jno',
    # parity
    'jp',
    'jpe',
    'jnp',
    'jpo',
    # cx/ecx/rcx zero (short-range loop exits)
    'jcxz',
    'jecxz',
    'jrcxz',
    # loop instructions (branch if cx != 0)
    'loop',
    'loope',
    'loopne',
)
60+
61+
62+
# Operand-size keywords that precede the mandatory "ptr" in memory operands.
SIZE_QUALIFIERS = (
    'byte',
    'word',
    'dword',
    'qword',
    'tbyte',
    'xmmword',
    'ymmword',
    'zmmword',
)

# frozenset for O(1) membership tests inside the tokenizer loop
_SIZE_QUALIFIER_SET = frozenset(SIZE_QUALIFIERS)

# Matches one token that stops before commas, spaces, or memory-expression operators
_OPERAND_TOKEN = re.compile(r'[^\s,+\-*\[\]]+')

_MEM_OPERATORS = frozenset('+-*')


def parse_operands(operands: str) -> Iterator[str]:
    """
    Walk an x86 operand string left to right, yielding one token at a time.

    Memory operands are emitted as '[', the inner tokens, then ']'; a size
    qualifier and its mandatory 'ptr' keyword are merged into a single token
    such as 'dword_ptr'; all tokens are lowercased.
    """
    pos = 0
    end = len(operands)

    while pos < end:
        ch = operands[pos]
        if ch in ' ,':
            # separators carry no information of their own
            pos += 1
        elif ch == '[':
            close = operands.index(']', pos)
            yield '['
            yield from _parse_mem_expr(operands[pos + 1 : close])
            yield ']'
            pos = close + 1
        elif ch in _MEM_OPERATORS:
            yield ch
            pos += 1
        else:
            found = _OPERAND_TOKEN.match(operands, pos)
            if found is None:
                pos += 1
                continue

            # strip segment-override colon (e.g. "fs:") and normalize case
            word = found.group().rstrip(':').lower()
            pos = found.end()

            if word in _SIZE_QUALIFIER_SET:
                # consume the mandatory following "ptr" keyword and merge
                # both words into one token
                tail = operands[pos:].lstrip()
                if tail.lower().startswith('ptr'):
                    pos += (len(operands) - pos - len(tail)) + 3
                    yield f'{word}_ptr'
                    continue
            yield word


def _parse_mem_expr(expr: str) -> Iterator[str]:
    """
    Split the inside of a [...] memory expression into tokens.

    Needed because Ghidra writes subtraction as e.g. "[rbp + -0x8]": a '-'
    that directly follows an operator (or starts the expression) is a unary
    sign and is glued onto the number that follows, yielding one token like
    '-0x8' instead of an operator plus a number.
    """
    expr = expr.strip()
    pos = 0
    end = len(expr)

    while pos < end:
        ch = expr[pos]
        if ch == ' ':
            pos += 1
        elif ch in '+*':
            yield ch
            pos += 1
        elif ch == '-':
            # unary minus when preceded by an operator or at the start
            before = expr[:pos].rstrip()
            if not before or before[-1] in '+-*':
                number = _OPERAND_TOKEN.match(expr, pos + 1)
                if number:
                    yield '-' + number.group().lower()
                    pos = number.end()
                    continue
            yield '-'
            pos += 1
        else:
            token = _OPERAND_TOKEN.match(expr, pos)
            if token:
                yield token.group().lower()
                pos = token.end()
            else:
                pos += 1
158+
159+
160+
class X86Preprocessor:
    """
    Flatten an x86_64 function's basic blocks into one token sequence.

    Based on the ARM64 preprocessor but adjusted for amd64 (x86_64) arch.
    """

    def __init__(
        self,
        *,
        branch_instructions: tuple[str, ...] = BRANCH_INSTRUCTIONS,
        parse_operands: Callable[[str], Iterator[str]] = parse_operands,
        context_length: int = 512,
        prefix_tokens: tuple[str, ...] | None = None,
        operand_formatters: tuple[Callable, ...] | None = None,
    ):
        # mnemonics whose offset operand is rewritten into a JUMP_ADDR_* token
        self.branch_instructions = frozenset(branch_instructions)
        self.parse_operands = parse_operands
        # jump targets at or beyond this token index are reported as exceeded
        self.context_length = context_length
        self.prefix_tokens = prefix_tokens or ()
        self.operand_formatters = operand_formatters or ()

    def format_jump(self, operand: str, target_index: int | None) -> str:
        """Render the token for a jump whose target block starts at *target_index*."""
        if target_index is None:
            return 'UNK_JUMP_ADDR'
        if target_index < self.context_length:
            return f'JUMP_ADDR_{target_index}'
        return 'JUMP_ADDR_EXCEEDED'

    def format_operand(self, operand: str) -> str | None:
        """Return the first truthy replacement from the formatters, or None."""
        for formatter in self.operand_formatters:
            if replacement := formatter(operand):
                return replacement
        return None

    def preprocess(self, function_blocks: dict[int, list[str]]) -> list[str]:
        """
        Tokenize *function_blocks* (block offset -> instruction strings).

        Blocks are visited in ascending offset order; branch targets are
        resolved to the target block's token index in a second pass, once
        every block's position in the token stream is known.
        """
        tokens: list[str] = list(self.prefix_tokens)
        block_offsets: dict[int, int] = {}
        pending_jumps: dict[int, int] = {}  # token index -> target block id

        for block_id in sorted(function_blocks):
            block_offsets[block_id] = len(tokens)
            for instruction in function_blocks[block_id]:
                parts = instruction.lower().split(maxsplit=1)
                mnemonic = parts[0]
                tokens.append(mnemonic)
                if len(parts) == 1:
                    continue

                is_branch = mnemonic in self.branch_instructions
                for operand in self.parse_operands(parts[1]):
                    if is_branch and (offset := is_offset(operand)):
                        # can't slice at place 2 because negative hex values
                        # so therefore use value from is_offset regex
                        pending_jumps[len(tokens)] = int(offset.group('value'), base=16)
                    else:
                        operand = self.format_operand(operand) or operand
                    tokens.append(operand)

        # second pass: rewrite the recorded jump operands in place
        for position, target in pending_jumps.items():
            if replacement := self.format_jump(tokens[position], block_offsets.get(target)):
                tokens[position] = replacement

        return tokens

asmtransformers/tests/test_x86.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import pytest
2+
3+
from asmtransformers import x86
4+
from asmtransformers.operands import is_offset
5+
6+
7+
@pytest.fixture
def tokenizer():
    # a preprocessor with default settings, shared by the tokenize tests
    return x86.X86Preprocessor()
10+
11+
12+
def test_parse_plain_operands():
    # register and immediate operands, alone and comma-separated
    cases = [
        ('rax', ['rax']),
        ('0x10', ['0x10']),
        ('rax, rbx', ['rax', 'rbx']),
        ('rax, 0x10', ['rax', '0x10']),
    ]
    for operands, expected in cases:
        assert list(x86.parse_operands(operands)) == expected
17+
18+
19+
def test_parse_memory_operand():
    # brackets become their own tokens; a unary minus sticks to its number
    cases = [
        ('[rax]', ['[', 'rax', ']']),
        ('[rbp + -0x8]', ['[', 'rbp', '+', '-0x8', ']']),
    ]
    for operands, expected in cases:
        assert list(x86.parse_operands(operands)) == expected
22+
23+
24+
def test_parse_size_qualifier():
    # "<size> ptr" collapses into a single <size>_ptr token
    cases = [
        ('dword ptr [rbp + -0x8]', ['dword_ptr', '[', 'rbp', '+', '-0x8', ']']),
        ('xmmword ptr [rax]', ['xmmword_ptr', '[', 'rax', ']']),
    ]
    for operands, expected in cases:
        assert list(x86.parse_operands(operands)) == expected
27+
28+
29+
def test_parse_segment_override():
    # the segment register is kept as its own token, with the colon dropped
    assert list(x86.parse_operands('fs:[rax]')) == ['fs', '[', 'rax', ']']
31+
32+
33+
def test_parse_complex_memory():
    # scaled-index addressing: base + index*scale + displacement
    expected = '[ rax + rcx * 4 + 0x10 ]'.split()
    assert list(x86.parse_operands('[rax + rcx*4 + 0x10]')) == expected
45+
46+
47+
def test_tokenize_single_block(tokenizer):
    # a straight-line block: mnemonics and operands in program order
    graph = {0: ['mov rax, 0x1234', 'add rax, 0x1234', 'ret']}

    expected = 'mov rax 0x1234 add rax 0x1234 ret'.split()
    assert tokenizer.preprocess(graph) == expected
60+
61+
62+
def test_tokenize_branching_blocks(tokenizer):
    # NB: nodes are in 'reverse order', tokenizer should reorder these based on their node ids
    graph = {42: ['add rcx, 0x290', 'je 0x0'], 0: ['sub rcx, 0x290', 'jmp 0x2a']}  # branch to offset 0

    expected = 'sub rcx 0x290 jmp JUMP_ADDR_5 add rcx 0x290 je JUMP_ADDR_0'.split()
    assert tokenizer.preprocess(graph) == expected
79+
80+
81+
def test_context_length_boundary():
    graph = {
        0x12: ['mov rax, 0x1234', 'jmp 0x34'],
        0x34: ['add rcx, 0x290', 'je 0x56'],
        0x56: ['sub rdx, 0x10', 'jmp 0x12'],
    }
    tokens = x86.X86Preprocessor(context_length=10).preprocess(graph)

    # je's target block (0x56) starts at token offset 10, which equals
    # context_length — exceeded
    assert tokens.index('JUMP_ADDR_EXCEEDED') - tokens.index('je') == 1
92+
93+
94+
def test_jump_to_unknown_block(tokenizer):
    # 0x999 is not a block id in this graph, so the target is unresolvable
    graph = {
        0x12: ['jg 0x999', 'add rax, 0x12'],
        0x34: ['jmp 0x12', 'sub rax, 0x12'],
    }
    tokens = tokenizer.preprocess(graph)

    assert tokens.index('UNK_JUMP_ADDR') - tokens.index('jg') == 1
99+
100+
101+
def test_offset_prefix_tokens(tokenizer):
    graph = {0x12: ['jmp 0x34'], 0x34: ['jmp 0x12']}
    plain = tokenizer.preprocess(graph)
    tokenizer.prefix_tokens = ('[CLS]', '[PAD]')
    prefixed = tokenizer.preprocess(graph)

    assert plain != prefixed
    assert prefixed[:2] == ['[CLS]', '[PAD]']
    # same code, jumps shift by the two prefix tokens
    assert (plain[1], plain[3]) == ('JUMP_ADDR_2', 'JUMP_ADDR_0')
    assert (prefixed[3], prefixed[5]) == ('JUMP_ADDR_4', 'JUMP_ADDR_2')
113+
114+
115+
def test_format_operand():
    # overriding format_operand lets a subclass rewrite non-jump operands
    class ObfuscatingPreprocessor(x86.X86Preprocessor):
        def format_operand(self, operand):
            if is_offset(operand):
                return 'OBFUSCATED'
            else:
                return operand

    graph = {0x12: ['add rax, 0x78', 'jmp 0x78'], 0x78: ['sub rbx, 0x78', 'ret']}
    tokens = ObfuscatingPreprocessor().preprocess(graph)

    # jmp 0x78 resolves to a JUMP_ADDR token; the other two 0x78 occurrences are obfuscated
    assert 'JUMP_ADDR_5' in tokens
    assert '0x78' not in tokens
    assert tokens.count('OBFUSCATED') == 2

0 commit comments

Comments
 (0)