From f152efa31cf20ef404bddb92bfed84f05d40ab5e Mon Sep 17 00:00:00 2001 From: Ethan Date: Sun, 7 Jan 2024 12:30:26 +0900 Subject: [PATCH] fix / improve find_similar_areas.py --- tools/find_similar_areas.py | 243 ++++++++++++------------------------ 1 file changed, 81 insertions(+), 162 deletions(-) diff --git a/tools/find_similar_areas.py b/tools/find_similar_areas.py index 09125e68b0..650fc5bc35 100755 --- a/tools/find_similar_areas.py +++ b/tools/find_similar_areas.py @@ -8,7 +8,8 @@ import sys from collections import OrderedDict from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional +from mapfile_parser import MapFile, Symbol from sty import fg @@ -16,28 +17,12 @@ script_dir = Path(os.path.dirname(os.path.realpath(__file__))) root_dir = script_dir / ".." asm_dir = root_dir / "ver/current/asm/nonmatchings/" build_dir = root_dir / "ver/current/build/" -elf_path = build_dir / "papermario.elf" map_file_path = build_dir / "papermario.map" rom_path = root_dir / "ver/current/baserom.z64" OBJDUMP = "mips-linux-gnu-objdump" -@dataclass -class Symbol: - name: str - rom_start: int - ram: int - current_file: Path - prev_sym: str - is_decompiled: bool - rom_end: Optional[int] = None - - def size(self): - assert self.rom_end is not None - return self.rom_end - self.rom_start - - @dataclass class Bytes: normalized: str @@ -58,31 +43,13 @@ def get_all_unmatched_functions(): return ret -def get_func_sizes() -> Dict[str, int]: - try: - result = subprocess.run(["mips-linux-gnu-objdump", "-x", elf_path], stdout=subprocess.PIPE) - nm_lines = result.stdout.decode().split("\n") - except: - print(f"Error: Could not run objdump on {elf_path} - make sure that the project is built") - sys.exit(1) - - sizes: Dict[str, int] = {} - - for line in nm_lines: - if " F " in line: - components = line.split() - size = int(components[4], 16) - name = components[5] - sizes[name] = size - - return sizes - - -def get_symbol_bytes(func: str) -> Optional[Bytes]: - if func not in syms or syms[func].rom_end is None: +def get_symbol_bytes(symbol: Symbol, mapfile: MapFile) -> Optional[Bytes]: + if not symbol.vrom or not symbol.size: return None - sym = syms[func] - bs = list(rom_bytes[sym.rom_start : sym.rom_end]) + + sym_bytes = rom_bytes[symbol.vrom : symbol.vrom + symbol.size] + + bs = list(sym_bytes) # trim nops while len(bs) > 0 and bs[-1] == 0: @@ -94,67 +61,7 @@ def get_symbol_bytes(func: str) -> Optional[Bytes]: for ins in insns: ret.append(ins >> 2) - return Bytes(0, bytes(ret).decode("utf-8"), rom_bytes[sym.rom_start : sym.rom_end]) - - -def parse_map() -> OrderedDict[str, Symbol]: - ram_offset = None - cur_file = "" - syms: OrderedDict[str, Symbol] = OrderedDict() - prev_sym = "" - prev_line = "" - cur_sect = "" - sect_re = re.compile(r"\(\..*\)") - with open(map_file_path) as f: - for line in f: - sect = sect_re.search(line) - if sect: - sect_str = sect.group(0) - if sect_str in ["(.text*)", "(.data*)", "(.rodata*)", "(.bss*)"]: - cur_sect = sect_str - - if "load address" in line: - if "noload" in line or "noload" in prev_line: - ram_offset = None - continue - ram = int(line[16 : 16 + 18], 0) - rom = int(line[59 : 59 + 18], 0) - ram_offset = ram - rom - continue - prev_line = line - - if ram_offset is None or "=" in line or "*fill*" in line or " 0x" not in line: - continue - ram = int(line[16 : 16 + 18], 0) - rom = ram - ram_offset - fn = line.split()[-1] - if "0x" in fn: - ram_offset = None - elif "/" in fn: - cur_file = fn - else: - if cur_sect != "(.text*)": - continue - new_sym = Symbol( - name=fn, - rom_start=rom, - ram=ram, - current_file=Path(cur_file), - prev_sym=prev_sym, - is_decompiled=not fn in unmatched_functions, - ) - if fn in func_sizes: - new_sym.rom_end = rom + func_sizes[fn] - syms[fn] = new_sym - prev_sym = fn - - # Calc end offsets - for sym in syms: - prev_sym = syms[sym].prev_sym - if prev_sym and not syms[prev_sym].rom_end: - syms[prev_sym].rom_end = syms[sym].rom_start - - return syms + return Bytes(bytes(ret).decode("utf-8"), sym_bytes) @dataclass @@ -169,8 +76,8 @@ class Match: @dataclass class Result: - query: str - target: str + query: Symbol + target: Symbol query_start: int target_start: int length: int @@ -201,8 +108,8 @@ def get_hashes(bytes: Bytes, window_size: int) -> list[str]: def group_matches( - query: str, - target: str, + query: Symbol, + target: Symbol, matches: list[Match], window_size: int, min: Optional[int], @@ -358,80 +265,91 @@ def get_matches( contains: Optional[int], show_disasm: bool, ): - query_bytes: Optional[Bytes] = get_symbol_bytes(query) + query_sym_info = mapfile.findSymbolByName(query) + if not query_sym_info: + sys.exit("Symbol '" + query + "' not found") + + query_symbol = query_sym_info.symbol + + query_bytes: Optional[Bytes] = get_symbol_bytes(query_symbol, mapfile) if query_bytes is None: - sys.exit("Symbol '" + query + "' not found") + sys.exit("Symbol '" + query + "' bytes not found") query_hashes = get_hashes(query_bytes, window_size) ret: dict[str, float] = {} - for symbol in syms: - if query == symbol: - continue - sym_bytes: Optional[Bytes] = get_symbol_bytes(symbol) - if not sym_bytes: - continue + for segment in mapfile._segmentsList: + for file in segment._filesList: + if file.sectionType != ".text": + continue - if len(sym_bytes.bytes) / 4 < window_size: - continue + for symbol in file._symbols: + if query_symbol == symbol: + continue - sym_hashes = get_hashes(sym_bytes, window_size) + sym_bytes: Optional[Bytes] = get_symbol_bytes(symbol, mapfile) + if not sym_bytes: + continue - matches: list[Match] = get_pair_matches(query_hashes, sym_hashes) - if not matches: - continue + if len(sym_bytes.bytes) / 4 < window_size: + continue - results: list[Result] = group_matches(query, symbol, matches, window_size, min, max, contains) - if not results: - continue + sym_hashes = get_hashes(sym_bytes, window_size) - obj_file = syms[symbol].current_file + matches: list[Match] = get_pair_matches(query_hashes, sym_hashes) + if not matches: + continue - line_numbers = {} - tu_offset = None - decompiled_str = ":" - if syms[symbol].is_decompiled: - line_numbers = get_line_numbers(obj_file) - tu_offset = get_tu_offset(obj_file, symbol) - decompiled_str = fg.green + " (decompiled)" + fg.rs + ":" + results: list[Result] = group_matches(query, symbol, matches, window_size, min, max, contains) + if not results: + continue - print(symbol + decompiled_str) + obj_file = file.filepath - for result in results: - c_range = None - if tu_offset is not None and len(line_numbers) > 0: - c_range = get_c_range( - tu_offset + (result.target_start * 4), - tu_offset + (result.target_end * 4), - line_numbers, - ) + line_numbers = {} + tu_offset = None + decompiled_str = ":" - target_range_str = "" - if c_range: - target_range_str = fg.li_cyan + f" (line {c_range} in {obj_file.stem})" + fg.rs + if symbol.name not in unmatched_functions: + line_numbers = get_line_numbers(obj_file) + tu_offset = get_tu_offset(obj_file, symbol.name) + decompiled_str = fg.green + " (decompiled)" + fg.rs + ":" - query_str = f"query [{result.query_start}-{result.query_end}]" - target_str = ( - f"{symbol} [insn {result.target_start}-{result.target_end}] ({result.length} total){target_range_str}" - ) - print(f"\t{query_str} matches {target_str}") + print(symbol.name + decompiled_str) - if show_disasm: - try: - import rabbitizer - except ImportError: - print("rabbitizer not found, cannot show disassembly") - sys.exit(1) - result_query_bytes = query_bytes.bytes[result.query_start * 4 : result.query_end * 4] - result_target_bytes = sym_bytes.bytes[result.target_start * 4 : result.target_end * 4] + for result in results: + c_range = None + if tu_offset is not None and len(line_numbers) > 0: + c_range = get_c_range( + tu_offset + (result.target_start * 4), + tu_offset + (result.target_end * 4), + line_numbers, + ) - for i in range(0, len(result_query_bytes), 4): - q_insn = rabbitizer.Instruction(int.from_bytes(result_query_bytes[i : i + 4], "big")) - t_insn = rabbitizer.Instruction(int.from_bytes(result_target_bytes[i : i + 4], "big")) + target_range_str = "" + if c_range: + target_range_str = fg.li_cyan + f" (line {c_range} in {obj_file.stem})" + fg.rs - print(f"\t\t{q_insn.disassemble():35} | {t_insn.disassemble()}") + query_str = f"query [{result.query_start}-{result.query_end}]" + target_str = f"{symbol.name} [insn {result.target_start}-{result.target_end}] ({result.length} total){target_range_str}" + print(f"\t{query_str} matches {target_str}") + + if show_disasm: + try: + import rabbitizer + except ImportError: + print("rabbitizer not found, cannot show disassembly") + sys.exit(1) + result_query_bytes = query_bytes.bytes[result.query_start * 4 : result.query_end * 4] + result_target_bytes = sym_bytes.bytes[result.target_start * 4 : result.target_end * 4] + + for i in range(0, len(result_query_bytes), 4): + q_insn = rabbitizer.Instruction(int.from_bytes(result_query_bytes[i : i + 4], "big")) + t_insn = rabbitizer.Instruction(int.from_bytes(result_target_bytes[i : i + 4], "big")) + + print(f"\t\t{q_insn.disassemble():35} | {t_insn.disassemble()}") return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True)) @@ -482,8 +400,9 @@ args = parser.parse_args() if __name__ == "__main__": rom_bytes = read_rom() unmatched_functions = get_all_unmatched_functions() - func_sizes = get_func_sizes() - syms = parse_map() + + mapfile = MapFile() + mapfile.readMapFile(map_file_path) do_query( args.query,