fix / improve find_similar_areas.py

This commit is contained in:
Ethan 2024-01-07 12:30:26 +09:00
parent 6741e3f848
commit f152efa31c
No known key found for this signature in database
GPG Key ID: 9BCC97FDA5482E7A

View File

@ -8,7 +8,8 @@ import sys
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
from mapfile_parser import MapFile, Symbol
from sty import fg from sty import fg
@ -16,28 +17,12 @@ script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
root_dir = script_dir / ".." root_dir = script_dir / ".."
asm_dir = root_dir / "ver/current/asm/nonmatchings/" asm_dir = root_dir / "ver/current/asm/nonmatchings/"
build_dir = root_dir / "ver/current/build/" build_dir = root_dir / "ver/current/build/"
elf_path = build_dir / "papermario.elf"
map_file_path = build_dir / "papermario.map" map_file_path = build_dir / "papermario.map"
rom_path = root_dir / "ver/current/baserom.z64" rom_path = root_dir / "ver/current/baserom.z64"
OBJDUMP = "mips-linux-gnu-objdump" OBJDUMP = "mips-linux-gnu-objdump"
@dataclass
class Symbol:
name: str
rom_start: int
ram: int
current_file: Path
prev_sym: str
is_decompiled: bool
rom_end: Optional[int] = None
def size(self):
assert self.rom_end is not None
return self.rom_end - self.rom_start
@dataclass @dataclass
class Bytes: class Bytes:
normalized: str normalized: str
@ -58,31 +43,13 @@ def get_all_unmatched_functions():
return ret return ret
def get_func_sizes() -> Dict[str, int]: def get_symbol_bytes(symbol: Symbol, mapfile: MapFile) -> Optional[Bytes]:
try: if not symbol.vrom or not symbol.size:
result = subprocess.run(["mips-linux-gnu-objdump", "-x", elf_path], stdout=subprocess.PIPE)
nm_lines = result.stdout.decode().split("\n")
except:
print(f"Error: Could not run objdump on {elf_path} - make sure that the project is built")
sys.exit(1)
sizes: Dict[str, int] = {}
for line in nm_lines:
if " F " in line:
components = line.split()
size = int(components[4], 16)
name = components[5]
sizes[name] = size
return sizes
def get_symbol_bytes(func: str) -> Optional[Bytes]:
if func not in syms or syms[func].rom_end is None:
return None return None
sym = syms[func]
bs = list(rom_bytes[sym.rom_start : sym.rom_end]) sym_bytes = rom_bytes[symbol.vrom : symbol.vrom + symbol.size]
bs = list(sym_bytes)
# trim nops # trim nops
while len(bs) > 0 and bs[-1] == 0: while len(bs) > 0 and bs[-1] == 0:
@ -94,67 +61,7 @@ def get_symbol_bytes(func: str) -> Optional[Bytes]:
for ins in insns: for ins in insns:
ret.append(ins >> 2) ret.append(ins >> 2)
return Bytes(0, bytes(ret).decode("utf-8"), rom_bytes[sym.rom_start : sym.rom_end]) return Bytes(bytes(ret).decode("utf-8"), sym_bytes)
def parse_map() -> OrderedDict[str, Symbol]:
ram_offset = None
cur_file = "<no file>"
syms: OrderedDict[str, Symbol] = OrderedDict()
prev_sym = ""
prev_line = ""
cur_sect = ""
sect_re = re.compile(r"\(\..*\)")
with open(map_file_path) as f:
for line in f:
sect = sect_re.search(line)
if sect:
sect_str = sect.group(0)
if sect_str in ["(.text*)", "(.data*)", "(.rodata*)", "(.bss*)"]:
cur_sect = sect_str
if "load address" in line:
if "noload" in line or "noload" in prev_line:
ram_offset = None
continue
ram = int(line[16 : 16 + 18], 0)
rom = int(line[59 : 59 + 18], 0)
ram_offset = ram - rom
continue
prev_line = line
if ram_offset is None or "=" in line or "*fill*" in line or " 0x" not in line:
continue
ram = int(line[16 : 16 + 18], 0)
rom = ram - ram_offset
fn = line.split()[-1]
if "0x" in fn:
ram_offset = None
elif "/" in fn:
cur_file = fn
else:
if cur_sect != "(.text*)":
continue
new_sym = Symbol(
name=fn,
rom_start=rom,
ram=ram,
current_file=Path(cur_file),
prev_sym=prev_sym,
is_decompiled=not fn in unmatched_functions,
)
if fn in func_sizes:
new_sym.rom_end = rom + func_sizes[fn]
syms[fn] = new_sym
prev_sym = fn
# Calc end offsets
for sym in syms:
prev_sym = syms[sym].prev_sym
if prev_sym and not syms[prev_sym].rom_end:
syms[prev_sym].rom_end = syms[sym].rom_start
return syms
@dataclass @dataclass
@ -169,8 +76,8 @@ class Match:
@dataclass @dataclass
class Result: class Result:
query: str query: Symbol
target: str target: Symbol
query_start: int query_start: int
target_start: int target_start: int
length: int length: int
@ -201,8 +108,8 @@ def get_hashes(bytes: Bytes, window_size: int) -> list[str]:
def group_matches( def group_matches(
query: str, query: Symbol,
target: str, target: Symbol,
matches: list[Match], matches: list[Match],
window_size: int, window_size: int,
min: Optional[int], min: Optional[int],
@ -358,80 +265,91 @@ def get_matches(
contains: Optional[int], contains: Optional[int],
show_disasm: bool, show_disasm: bool,
): ):
query_bytes: Optional[Bytes] = get_symbol_bytes(query) query_sym_info = mapfile.findSymbolByName(query)
if not query_sym_info:
sys.exit("Symbol '" + query + "' not found")
query_symbol = query_sym_info.symbol
query_bytes: Optional[Bytes] = get_symbol_bytes(query_symbol, mapfile)
if query_bytes is None: if query_bytes is None:
sys.exit("Symbol '" + query + "' not found") sys.exit("Symbol '" + query + "' bytes not found")
query_hashes = get_hashes(query_bytes, window_size) query_hashes = get_hashes(query_bytes, window_size)
ret: dict[str, float] = {} ret: dict[str, float] = {}
for symbol in syms:
if query == symbol:
continue
sym_bytes: Optional[Bytes] = get_symbol_bytes(symbol) for segment in mapfile._segmentsList:
if not sym_bytes: for file in segment._filesList:
continue if file.sectionType != ".text":
continue
if len(sym_bytes.bytes) / 4 < window_size: for symbol in file._symbols:
continue if query_symbol == symbol:
continue
sym_hashes = get_hashes(sym_bytes, window_size) sym_bytes: Optional[Bytes] = get_symbol_bytes(symbol, mapfile)
if not sym_bytes:
continue
matches: list[Match] = get_pair_matches(query_hashes, sym_hashes) if len(sym_bytes.bytes) / 4 < window_size:
if not matches: continue
continue
results: list[Result] = group_matches(query, symbol, matches, window_size, min, max, contains) sym_hashes = get_hashes(sym_bytes, window_size)
if not results:
continue
obj_file = syms[symbol].current_file matches: list[Match] = get_pair_matches(query_hashes, sym_hashes)
if not matches:
continue
line_numbers = {} results: list[Result] = group_matches(query, symbol, matches, window_size, min, max, contains)
tu_offset = None if not results:
decompiled_str = ":" continue
if syms[symbol].is_decompiled:
line_numbers = get_line_numbers(obj_file)
tu_offset = get_tu_offset(obj_file, symbol)
decompiled_str = fg.green + " (decompiled)" + fg.rs + ":"
print(symbol + decompiled_str) obj_file = file.filepath
for result in results: line_numbers = {}
c_range = None tu_offset = None
if tu_offset is not None and len(line_numbers) > 0: decompiled_str = ":"
c_range = get_c_range(
tu_offset + (result.target_start * 4),
tu_offset + (result.target_end * 4),
line_numbers,
)
target_range_str = "" if symbol.name not in unmatched_functions:
if c_range: line_numbers = get_line_numbers(obj_file)
target_range_str = fg.li_cyan + f" (line {c_range} in {obj_file.stem})" + fg.rs tu_offset = get_tu_offset(obj_file, symbol.name)
decompiled_str = fg.green + " (decompiled)" + fg.rs + ":"
query_str = f"query [{result.query_start}-{result.query_end}]" print(symbol.name + decompiled_str)
target_str = (
f"{symbol} [insn {result.target_start}-{result.target_end}] ({result.length} total){target_range_str}"
)
print(f"\t{query_str} matches {target_str}")
if show_disasm: for result in results:
try: c_range = None
import rabbitizer if tu_offset is not None and len(line_numbers) > 0:
except ImportError: c_range = get_c_range(
print("rabbitizer not found, cannot show disassembly") tu_offset + (result.target_start * 4),
sys.exit(1) tu_offset + (result.target_end * 4),
result_query_bytes = query_bytes.bytes[result.query_start * 4 : result.query_end * 4] line_numbers,
result_target_bytes = sym_bytes.bytes[result.target_start * 4 : result.target_end * 4] )
for i in range(0, len(result_query_bytes), 4): target_range_str = ""
q_insn = rabbitizer.Instruction(int.from_bytes(result_query_bytes[i : i + 4], "big")) if c_range:
t_insn = rabbitizer.Instruction(int.from_bytes(result_target_bytes[i : i + 4], "big")) target_range_str = fg.li_cyan + f" (line {c_range} in {obj_file.stem})" + fg.rs
print(f"\t\t{q_insn.disassemble():35} | {t_insn.disassemble()}") query_str = f"query [{result.query_start}-{result.query_end}]"
target_str = f"{symbol.name} [insn {result.target_start}-{result.target_end}] ({result.length} total){target_range_str}"
print(f"\t{query_str} matches {target_str}")
if show_disasm:
try:
import rabbitizer
except ImportError:
print("rabbitizer not found, cannot show disassembly")
sys.exit(1)
result_query_bytes = query_bytes.bytes[result.query_start * 4 : result.query_end * 4]
result_target_bytes = sym_bytes.bytes[result.target_start * 4 : result.target_end * 4]
for i in range(0, len(result_query_bytes), 4):
q_insn = rabbitizer.Instruction(int.from_bytes(result_query_bytes[i : i + 4], "big"))
t_insn = rabbitizer.Instruction(int.from_bytes(result_target_bytes[i : i + 4], "big"))
print(f"\t\t{q_insn.disassemble():35} | {t_insn.disassemble()}")
return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True)) return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True))
@ -482,8 +400,9 @@ args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
rom_bytes = read_rom() rom_bytes = read_rom()
unmatched_functions = get_all_unmatched_functions() unmatched_functions = get_all_unmatched_functions()
func_sizes = get_func_sizes()
syms = parse_map() mapfile = MapFile()
mapfile.readMapFile(map_file_path)
do_query( do_query(
args.query, args.query,