|
| 1 | +import re |
| 2 | +from collections.abc import Callable, Iterator |
| 3 | + |
| 4 | +from asmtransformers.operands import is_offset |
| 5 | + |
| 6 | + |
| 7 | +# based mostly on: https://en.wikipedia.org/wiki/List_of_x86_instructions |
| 8 | +BRANCH_INSTRUCTIONS = ( |
| 9 | + # unconditional jump and call/return |
| 10 | + 'jmp', |
| 11 | + 'call', |
| 12 | + 'ret', |
| 13 | + 'retn', |
| 14 | + 'retf', |
| 15 | + # equal / zero |
| 16 | + 'je', |
| 17 | + 'jz', |
| 18 | + # not equal / not zero |
| 19 | + 'jne', |
| 20 | + 'jnz', |
| 21 | + # signed less / less-or-equal |
| 22 | + 'jl', |
| 23 | + 'jnge', |
| 24 | + 'jle', |
| 25 | + 'jng', |
| 26 | + # signed greater / greater-or-equal |
| 27 | + 'jg', |
| 28 | + 'jnle', |
| 29 | + 'jge', |
| 30 | + 'jnl', |
| 31 | + # unsigned above / above-or-equal (carry clear) |
| 32 | + 'ja', |
| 33 | + 'jnbe', |
| 34 | + 'jae', |
| 35 | + 'jnb', |
| 36 | + # unsigned below / below-or-equal (carry set) |
| 37 | + 'jb', |
| 38 | + 'jnae', |
| 39 | + 'jbe', |
| 40 | + 'jna', |
| 41 | + # sign / overflow |
| 42 | + 'js', |
| 43 | + 'jns', |
| 44 | + 'jo', |
| 45 | + 'jno', |
| 46 | + # parity |
| 47 | + 'jp', |
| 48 | + 'jpe', |
| 49 | + 'jnp', |
| 50 | + 'jpo', |
| 51 | + # cx/ecx/rcx zero (short-range loop exits) |
| 52 | + 'jcxz', |
| 53 | + 'jecxz', |
| 54 | + 'jrcxz', |
| 55 | + # loop instructions (branch if cx != 0) |
| 56 | + 'loop', |
| 57 | + 'loope', |
| 58 | + 'loopne', |
| 59 | +) |
| 60 | + |
| 61 | + |
| 62 | +SIZE_QUALIFIERS = ( |
| 63 | + 'byte', |
| 64 | + 'word', |
| 65 | + 'dword', |
| 66 | + 'qword', |
| 67 | + 'tbyte', |
| 68 | + 'xmmword', |
| 69 | + 'ymmword', |
| 70 | + 'zmmword', |
| 71 | +) |
| 72 | + |
| 73 | +_SIZE_QUALIFIER_SET = frozenset(SIZE_QUALIFIERS) |
| 74 | + |
| 75 | +# Matches one token that stops before commas, spaces, or memory-expression operators |
| 76 | +_OPERAND_TOKEN = re.compile(r'[^\s,+\-*\[\]]+') |
| 77 | + |
| 78 | +_MEM_OPERATORS = frozenset('+-*') |
| 79 | + |
| 80 | + |
| 81 | +def parse_operands(operands: str) -> Iterator[str]: |
| 82 | + """ |
| 83 | + move through x86 string operands linearly |
| 84 | + """ |
| 85 | + offset = 0 |
| 86 | + length = len(operands) |
| 87 | + |
| 88 | + while offset < length: |
| 89 | + match operands[offset]: |
| 90 | + case ' ' | ',': |
| 91 | + offset += 1 |
| 92 | + case '[': |
| 93 | + end = operands.index(']', offset) |
| 94 | + yield '[' |
| 95 | + yield from _parse_mem_expr(operands[offset + 1 : end]) |
| 96 | + yield ']' |
| 97 | + offset = end + 1 |
| 98 | + case ch if ch in _MEM_OPERATORS: |
| 99 | + yield ch |
| 100 | + offset += 1 |
| 101 | + case _: |
| 102 | + m = _OPERAND_TOKEN.match(operands, offset) |
| 103 | + if not m: |
| 104 | + offset += 1 |
| 105 | + continue |
| 106 | + |
| 107 | + token = m.group().rstrip(':') # strip segment-override colon (e.g. "fs:") |
| 108 | + lower = token.lower() |
| 109 | + |
| 110 | + if lower in _SIZE_QUALIFIER_SET: |
| 111 | + # consume the mandatory following "ptr" keyword and merge into one token |
| 112 | + rest = operands[m.end() :].lstrip() |
| 113 | + if rest.lower().startswith('ptr'): |
| 114 | + yield f'{lower}_ptr' |
| 115 | + offset = m.end() + (len(operands[m.end() :]) - len(rest)) + 3 |
| 116 | + else: |
| 117 | + yield lower |
| 118 | + offset = m.end() |
| 119 | + else: |
| 120 | + yield lower |
| 121 | + offset = m.end() |
| 122 | + |
| 123 | + |
| 124 | +def _parse_mem_expr(expr: str) -> Iterator[str]: |
| 125 | + """ |
| 126 | + We need this because subtracting a memory expression in ghidra is done by [rbp + -0x8]. |
| 127 | + Split a memory address expression into tokens. |
| 128 | + treating negative displacements like "-0x8" as a single token rather than an operator followed by a number.""" |
| 129 | + expr = expr.strip() |
| 130 | + i = 0 |
| 131 | + length = len(expr) |
| 132 | + |
| 133 | + while i < length: |
| 134 | + match expr[i]: |
| 135 | + case ' ': |
| 136 | + i += 1 |
| 137 | + case '+' | '*': |
| 138 | + yield expr[i] |
| 139 | + i += 1 |
| 140 | + case '-': |
| 141 | + # unary minus: attach to the number that follows when preceded by an operator or start |
| 142 | + prev = expr[:i].rstrip() |
| 143 | + if not prev or prev[-1] in ('+', '-', '*'): |
| 144 | + m = _OPERAND_TOKEN.match(expr, i + 1) |
| 145 | + if m: |
| 146 | + yield f'-{m.group().lower()}' |
| 147 | + i = m.end() |
| 148 | + continue |
| 149 | + yield '-' |
| 150 | + i += 1 |
| 151 | + case _: |
| 152 | + m = _OPERAND_TOKEN.match(expr, i) |
| 153 | + if m: |
| 154 | + yield m.group().lower() |
| 155 | + i = m.end() |
| 156 | + else: |
| 157 | + i += 1 |
| 158 | + |
| 159 | + |
| 160 | +class X86Preprocessor: |
| 161 | + """ |
| 162 | + Based on the ARM64 preprocessor but adjusted for amd64 (x86_64) arch. |
| 163 | + """ |
| 164 | + |
| 165 | + def __init__( |
| 166 | + self, |
| 167 | + *, |
| 168 | + branch_instructions: tuple[str, ...] = BRANCH_INSTRUCTIONS, |
| 169 | + parse_operands: Callable[[str], Iterator[str]] = parse_operands, |
| 170 | + context_length: int = 512, |
| 171 | + prefix_tokens: tuple[str, ...] | None = None, |
| 172 | + operand_formatters: tuple[Callable, ...] | None = None, |
| 173 | + ): |
| 174 | + self.branch_instructions = frozenset(branch_instructions) |
| 175 | + self.parse_operands = parse_operands |
| 176 | + self.context_length = context_length |
| 177 | + self.prefix_tokens = prefix_tokens or () |
| 178 | + self.operand_formatters = operand_formatters or () |
| 179 | + |
| 180 | + def format_jump(self, operand: str, target_index: int | None) -> str: |
| 181 | + if target_index is None: |
| 182 | + return 'UNK_JUMP_ADDR' |
| 183 | + elif target_index < self.context_length: |
| 184 | + return f'JUMP_ADDR_{target_index}' |
| 185 | + else: |
| 186 | + return 'JUMP_ADDR_EXCEEDED' |
| 187 | + |
| 188 | + def format_operand(self, operand: str) -> str | None: |
| 189 | + for formatter in self.operand_formatters: |
| 190 | + if replacement := formatter(operand): |
| 191 | + return replacement |
| 192 | + |
| 193 | + def preprocess(self, function_blocks: dict[int, list[str]]) -> list[str]: |
| 194 | + block_offsets = {} |
| 195 | + jump_offsets = {} |
| 196 | + tokens = list(self.prefix_tokens) |
| 197 | + |
| 198 | + function_blocks = dict(sorted(function_blocks.items())) |
| 199 | + |
| 200 | + for block_id, block in function_blocks.items(): |
| 201 | + block_offsets[block_id] = len(tokens) |
| 202 | + for instruction in block: |
| 203 | + parts = instruction.lower().split(maxsplit=1) |
| 204 | + mnemonic = parts[0] |
| 205 | + operand_str = parts[1] if len(parts) > 1 else '' |
| 206 | + |
| 207 | + tokens.append(mnemonic) |
| 208 | + for operand in self.parse_operands(operand_str) if operand_str else (): |
| 209 | + if mnemonic in self.branch_instructions and (offset := is_offset(operand)): |
| 210 | + # can't slice at place 2 because negative hex values so therefore use value from is_offset regex |
| 211 | + jump_target = int(offset.group('value'), base=16) |
| 212 | + jump_offsets[len(tokens)] = jump_target |
| 213 | + else: |
| 214 | + operand = self.format_operand(operand) or operand |
| 215 | + tokens.append(operand) |
| 216 | + |
| 217 | + for offset, jump_target in jump_offsets.items(): |
| 218 | + token = tokens[offset] |
| 219 | + if replacement := self.format_jump(token, block_offsets.get(jump_target)): |
| 220 | + tokens[offset] = replacement |
| 221 | + |
| 222 | + return tokens |
0 commit comments