papermario/tools/find_similar_areas.py
2024-01-07 12:30:26 +09:00

415 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import os
import re
import subprocess
import sys
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
from mapfile_parser import MapFile, Symbol
from sty import fg
# Directory containing this script; every other path below is derived from it.
script_dir = Path(os.path.realpath(__file__)).parent
# Parent of the script directory (kept in its un-normalised "<script_dir>/.." form).
root_dir = script_dir.joinpath("..")
# Tree that get_all_unmatched_functions walks looking for ".s" files.
asm_dir = root_dir.joinpath("ver/current/asm/nonmatchings/")
build_dir = root_dir.joinpath("ver/current/build/")
# Linker map loaded with MapFile.readMapFile in the __main__ section.
map_file_path = build_dir.joinpath("papermario.map")
# ROM image read in full by read_rom.
rom_path = root_dir.joinpath("ver/current/baserom.z64")
# objdump executable invoked by get_line_numbers.
OBJDUMP = "mips-linux-gnu-objdump"
@dataclass
class Bytes:
    """A symbol's ROM bytes paired with the normalized string used for window matching."""

    # One character per 4-byte group of the zero-trimmed symbol data, as built
    # by get_symbol_bytes.
    normalized: str
    # The symbol's bytes exactly as sliced from the ROM (not trimmed).
    bytes: bytes
def read_rom() -> bytes:
    """Return the complete contents of the file at ``rom_path`` as bytes."""
    return rom_path.read_bytes()
def get_all_unmatched_functions():
    """Collect the base names of every ``.s`` file found anywhere under ``asm_dir``.

    Returns:
        A set of file names with the trailing ``.s`` removed.
    """
    suffix = ".s"
    return {
        filename[: -len(suffix)]
        for _root, _dirs, filenames in os.walk(asm_dir)
        for filename in filenames
        if filename.endswith(suffix)
    }
def get_symbol_bytes(symbol: Symbol, mapfile: MapFile) -> Optional[Bytes]:
    """Slice a symbol's bytes out of the ROM and build its normalized form.

    Args:
        symbol: Symbol whose ``vrom`` and ``size`` locate it in ``rom_bytes``.
        mapfile: Accepted for call compatibility; not used here.

    Returns:
        ``None`` when the symbol has no (or a zero) ``vrom`` or ``size``;
        otherwise a ``Bytes`` holding the normalized string and the untrimmed
        slice of the ROM.
    """
    if not (symbol.vrom and symbol.size):
        return None

    raw = rom_bytes[symbol.vrom : symbol.vrom + symbol.size]

    # trim nops: drop every trailing zero byte before normalizing
    trimmed = raw.rstrip(b"\x00")

    # Keep the top six bits of the first byte of each 4-byte group
    # (presumably the MIPS opcode field). Values stay below 64, so the
    # result always decodes as plain ASCII.
    normalized = bytes(first_byte >> 2 for first_byte in trimmed[0::4])
    return Bytes(normalized.decode("utf-8"), raw)
@dataclass
class Match:
    """A single window found in both the query and a target, located by offset in each."""

    query_offset: int
    target_offset: int
    length: int

    def __str__(self):
        parts = (self.query_offset, self.target_offset, self.length)
        return " ".join(str(part) for part in parts)
@dataclass
class Result:
    """One grouped run of matching windows between a query symbol and a target symbol."""

    query: Symbol
    target: Symbol
    # Start positions and length as produced by group_matches. get_matches
    # multiplies them by 4 to slice symbol bytes, i.e. they count instructions.
    query_start: int
    target_start: int
    length: int

    @property
    def query_end(self):
        """End position of the run in the query (``query_start + length``)."""
        return self.query_start + self.length

    @property
    def target_end(self):
        """End position of the run in the target (``target_start + length``)."""
        return self.target_start + self.length
def get_pair_matches(query_hashes: list[str], sym_hashes: list[str]) -> list[Match]:
    """Pair up windows that occur in both hash lists.

    Each distinct window present in both lists yields exactly one ``Match``
    (length 1) pointing at the window's *first* offset in the query and its
    first offset in the target, the same pairing ``list.index`` would give.

    Offsets are looked up through dictionaries built in one pass per list, so
    the cost is linear in the list sizes. Calling ``list.index`` per shared
    window would instead be quadratic. The result is ordered by ascending
    query offset, which keeps output stable between runs.

    Args:
        query_hashes: Windows of the query, indexed by starting offset.
        sym_hashes: Windows of the candidate symbol, indexed by starting offset.

    Returns:
        One ``Match`` per distinct shared window; empty when nothing is shared.
    """

    def first_offsets(windows: list[str]) -> dict[str, int]:
        # setdefault keeps the earliest offset when a window repeats.
        offsets: dict[str, int] = {}
        for offset, window in enumerate(windows):
            offsets.setdefault(window, offset)
        return offsets

    query_first = first_offsets(query_hashes)
    sym_first = first_offsets(sym_hashes)

    return [
        Match(query_offset, sym_first[window], 1)
        for window, query_offset in query_first.items()
        if window in sym_first
    ]
def get_hashes(bytes: "Bytes", window_size: int) -> list[str]:
    """Return every ``window_size``-long slice of ``bytes.normalized``.

    Entry ``i`` of the result is ``normalized[i : i + window_size]``, so a
    list index is also the starting offset of that window.

    Args:
        bytes: Object whose ``normalized`` string is cut into windows.
        window_size: Number of characters (one per instruction) in each window.

    Returns:
        All full windows in order; empty when the string is shorter than
        ``window_size``.
    """
    normalized = bytes.normalized
    # "+ 1" so the last full window, starting at len - window_size, is kept.
    # Without it a string of exactly window_size characters yields nothing.
    last_start = len(normalized) - window_size
    return [normalized[i : i + window_size] for i in range(last_start + 1)]


def group_matches(
    query: "Symbol",
    target: "Symbol",
    matches: "list[Match]",
    window_size: int,
    min: Optional[int],
    max: Optional[int],
    contains: Optional[int],
) -> "list[Result]":
    """Merge matches at consecutive query offsets into runs and filter them.

    Matches are ordered by ``query_offset``; each match whose offset is exactly
    one more than the previous match's joins the current run, anything else
    starts a new run. The caller's list is not modified.

    NOTE(review): only query offsets are checked for contiguity. Target
    offsets within a run are not compared; the run's target start is simply
    the target offset of its first match.

    Args:
        query: Symbol stored as ``Result.query``.
        target: Symbol stored as ``Result.target``.
        matches: Single-window matches to group; may be empty.
        window_size: Window length used to produce the matches.
        min: If given, drop runs whose ``query_start + length`` is below it.
        max: If given, drop runs whose ``query_start`` is above it.
        contains: If given, drop runs for which it lies outside
            ``query_start .. query_start + length``.

    Returns:
        One ``Result`` per surviving run, ordered by query offset.
    """
    if not matches:
        return []

    ordered = sorted(matches, key=lambda m: m.query_offset)

    match_groups: List[List[Match]] = []
    previous_offset: Optional[int] = None
    for match in ordered:
        if previous_offset is not None and match.query_offset == previous_offset + 1:
            match_groups[-1].append(match)
        else:
            match_groups.append([match])
        previous_offset = match.query_offset

    ret = []
    for group in match_groups:
        query_start = group[0].query_offset
        target_start = group[0].target_offset
        # k overlapping windows that each start one position after the
        # previous one cover k + window_size - 1 positions in total.
        length = len(group) + window_size - 1

        if min is not None and query_start + length < min:
            continue
        if max is not None and query_start > max:
            continue
        if contains is not None and (query_start > contains or query_start + length < contains):
            continue

        ret.append(Result(query, target, query_start, target_start, length))
    return ret
def get_line_numbers(obj_file: Path) -> Dict[int, int]:
    """Map starting addresses to source line numbers using ``objdump -WL``.

    The first seven lines of objdump's output are skipped. After that, every
    line with at least three whitespace-separated fields is read as
    ``<file> <line number> <starting address> ...``. Lines whose first field
    is the objdump executable name or begins with ``<`` are ignored, and so
    is any line whose line number or address does not parse as an integer.

    Args:
        obj_file: Object file handed to objdump.

    Returns:
        Dict of ``starting address -> line number`` in the order objdump
        listed them (a later entry for the same address overwrites an earlier
        one). Empty when objdump prints nothing usable.
    """
    objdump_out = subprocess.run(
        [OBJDUMP, "-WL", obj_file],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    ).stdout.decode("utf-8")

    ret: Dict[int, int] = {}
    for line in objdump_out.split("\n")[7:]:
        pieces = line.split()
        # Also covers blank and whitespace-only lines (zero fields).
        if len(pieces) < 3:
            continue

        fn = pieces[0]
        if fn == OBJDUMP or fn[0] == "<":
            continue

        # Both conversions sit inside the try so that a row with a
        # non-numeric column is skipped instead of aborting the whole run.
        try:
            starting_addr = int(pieces[2], 0)
            line_num = int(pieces[1])
        except ValueError:
            continue
        ret[starting_addr] = line_num
    return ret
def get_tu_offset(obj_file: Path, symbol: str) -> Optional[int]:
    """Look up a symbol's value in the output of ``objdump -t``.

    The first four output lines are skipped. The first remaining line whose
    last whitespace-separated field equals ``symbol`` supplies the result: its
    first field parsed as a hexadecimal integer.

    Args:
        obj_file: Object file handed to objdump.
        symbol: Exact symbol name to look for.

    Returns:
        The parsed value, or ``None`` when no line ends with ``symbol``.
    """
    objdump_out = subprocess.run([OBJDUMP, "-t", obj_file], stdout=subprocess.PIPE).stdout.decode("utf-8")

    for line in objdump_out.split("\n")[4:]:
        pieces = line.split()
        # An empty field list (blank or whitespace-only line) cannot match and
        # must not be indexed.
        if pieces and pieces[-1] == symbol:
            return int(pieces[0], 16)
    return None
@dataclass
class CRange:
    """A range of source line numbers whose bounds may be exact or approximate.

    ``start_exact`` and ``end_exact`` carry type annotations so the dataclass
    machinery treats them as real fields. They are therefore accepted by the
    constructor and included in ``repr`` and equality, not left as shared
    class attributes.
    """

    start: Optional[int] = None
    end: Optional[int] = None
    start_exact: bool = False
    end_exact: bool = False

    def has_info(self):
        """Return True when at least one bound is known."""
        return self.start is not None or self.end is not None

    @staticmethod
    def _format_bound(value, exact):
        # "?" for an unknown bound, "~N" for an approximate one, "N" when exact.
        if value is None:
            return "?"
        return f"{value}" if exact else f"~{value}"

    def __str__(self):
        start_str = self._format_bound(self.start, self.start_exact)
        end_str = self._format_bound(self.end, self.end_exact)
        return f"{start_str} - {end_str}"
def get_c_range(insn_start: int, insn_end: int, line_numbers: Dict[int, int]) -> CRange:
    """Translate an address range into a range of source line numbers.

    A bound is exact when its address is itself a key of ``line_numbers``.
    Otherwise the keys are scanned in the dict's own order:

    * start: the line of the first key whose *following* key is greater than
      ``insn_start``;
    * end: the line of the first key greater than ``insn_end``.

    A bound for which the scan finds nothing is left as ``None``.

    Args:
        insn_start: Address where the range starts.
        insn_end: Address where the range ends.
        line_numbers: Mapping of address to source line number.

    Returns:
        A ``CRange`` with whichever bounds could be determined.
    """
    result = CRange()
    addrs = list(line_numbers)

    if insn_start in line_numbers:
        result.start = line_numbers[insn_start]
        result.start_exact = True
    else:
        for addr, following in zip(addrs, addrs[1:]):
            if following > insn_start:
                result.start = line_numbers[addr]
                break

    if insn_end in line_numbers:
        result.end = line_numbers[insn_end]
        result.end_exact = True
    else:
        first_after = next((addr for addr in addrs if addr > insn_end), None)
        if first_after is not None:
            result.end = line_numbers[first_after]

    return result
def _print_disasm(query_bytes: bytes, target_bytes: bytes, result: Result) -> None:
    """Print the matched query and target instructions side by side.

    rabbitizer is imported here so that it is only required once there is
    something to disassemble. If it is missing, a message is printed and the
    process exits with status 1.
    """
    try:
        import rabbitizer
    except ImportError:
        print("rabbitizer not found, cannot show disassembly")
        sys.exit(1)

    # Result positions count instructions; each instruction is 4 bytes.
    result_query_bytes = query_bytes[result.query_start * 4 : result.query_end * 4]
    result_target_bytes = target_bytes[result.target_start * 4 : result.target_end * 4]

    for i in range(0, len(result_query_bytes), 4):
        q_insn = rabbitizer.Instruction(int.from_bytes(result_query_bytes[i : i + 4], "big"))
        t_insn = rabbitizer.Instruction(int.from_bytes(result_target_bytes[i : i + 4], "big"))
        print(f"\t\t{q_insn.disassemble():35} | {t_insn.disassemble()}")


def get_matches(
    query: str,
    window_size: int,
    min: Optional[int],
    max: Optional[int],
    contains: Optional[int],
    show_disasm: bool,
):
    """Print every region of another ``.text`` symbol that matches part of ``query``.

    The query symbol is looked up in the module-level ``mapfile`` and its
    windows are compared with those of each symbol in every ``.text`` file of
    the map. For each symbol with at least one surviving grouped match, a
    header line and one line per match are printed. When the symbol is not in
    ``unmatched_functions``, source line ranges are derived from objdump
    output for the file's ``filepath`` and added to those lines.

    Args:
        query: Name of the symbol to search for.
        window_size: Number of instructions per comparison window.
        min: Passed to ``group_matches``.
        max: Passed to ``group_matches``.
        contains: Passed to ``group_matches``.
        show_disasm: When true, also print a side-by-side disassembly of each
            match (requires rabbitizer).

    Returns:
        An ``OrderedDict``. NOTE(review): nothing in this function ever adds
        to the underlying dict, so it is currently always empty.

    Exits the process (``sys.exit``) when the symbol or its bytes cannot be
    found.
    """
    query_sym_info = mapfile.findSymbolByName(query)
    if not query_sym_info:
        sys.exit("Symbol '" + query + "' not found")
    query_symbol = query_sym_info.symbol

    query_bytes: Optional[Bytes] = get_symbol_bytes(query_symbol, mapfile)
    if query_bytes is None:
        sys.exit("Symbol '" + query + "' bytes not found")

    query_hashes = get_hashes(query_bytes, window_size)

    ret: dict[str, float] = {}
    for segment in mapfile._segmentsList:
        for file in segment._filesList:
            if file.sectionType != ".text":
                continue

            for symbol in file._symbols:
                if query_symbol == symbol:
                    continue

                sym_bytes: Optional[Bytes] = get_symbol_bytes(symbol, mapfile)
                if not sym_bytes:
                    continue
                # Symbols shorter than one window cannot produce a match.
                if len(sym_bytes.bytes) / 4 < window_size:
                    continue

                sym_hashes = get_hashes(sym_bytes, window_size)
                matches: list[Match] = get_pair_matches(query_hashes, sym_hashes)
                if not matches:
                    continue

                # Pass the Symbol object (not the query name) so that
                # Result.query really holds a Symbol, as its annotation says.
                results: list[Result] = group_matches(
                    query_symbol, symbol, matches, window_size, min, max, contains
                )
                if not results:
                    continue

                obj_file = file.filepath
                line_numbers = {}
                tu_offset = None
                decompiled_str = ":"
                # Line information is only looked up for symbols that have no
                # entry in unmatched_functions.
                if symbol.name not in unmatched_functions:
                    line_numbers = get_line_numbers(obj_file)
                    tu_offset = get_tu_offset(obj_file, symbol.name)
                    decompiled_str = fg.green + " (decompiled)" + fg.rs + ":"

                print(symbol.name + decompiled_str)

                for result in results:
                    c_range = None
                    if tu_offset is not None and len(line_numbers) > 0:
                        c_range = get_c_range(
                            tu_offset + (result.target_start * 4),
                            tu_offset + (result.target_end * 4),
                            line_numbers,
                        )

                    target_range_str = ""
                    if c_range:
                        target_range_str = fg.li_cyan + f" (line {c_range} in {obj_file.stem})" + fg.rs

                    query_str = f"query [{result.query_start}-{result.query_end}]"
                    target_str = f"{symbol.name} [insn {result.target_start}-{result.target_end}] ({result.length} total){target_range_str}"
                    print(f"\t{query_str} matches {target_str}")

                    if show_disasm:
                        _print_disasm(query_bytes.bytes, sym_bytes.bytes, result)

    return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True))
def do_query(query, window_size, min, max, contains, show_disasm):
    """Run ``get_matches`` for ``query``; its return value is discarded."""
    get_matches(
        query=query,
        window_size=window_size,
        min=min,
        max=max,
        contains=contains,
        show_disasm=show_disasm,
    )
# Command-line interface. The parser is built at module level, but arguments
# are only parsed when the file is run as a script, so importing the module
# neither reads sys.argv nor exits.
parser = argparse.ArgumentParser(
    description="Tool to find duplicate portions of code from one function in code across the codebase"
)
parser.add_argument("query", help="function")
parser.add_argument(
    "-w",
    "--window-size",
    # The window slides over one normalized character per instruction, so the
    # unit is instructions rather than bytes.
    help="number of instructions to compare",
    type=int,
    default=20,
    required=False,
)
parser.add_argument(
    "--min",
    help="lower bound of instruction for matches against query",
    type=int,
    required=False,
)
parser.add_argument(
    "--max",
    help="upper bound of instruction for matches against query",
    type=int,
    required=False,
)
parser.add_argument(
    "--contains",
    help="All matches must contain this number'th instruction from the query",
    type=int,
    required=False,
)
parser.add_argument(
    "--show-disasm",
    help="Show disassembly of matches",
    action="store_true",
    required=False,
)

if __name__ == "__main__":
    args = parser.parse_args()

    # These are module-level globals read by get_symbol_bytes and get_matches.
    rom_bytes = read_rom()
    unmatched_functions = get_all_unmatched_functions()
    mapfile = MapFile()
    mapfile.readMapFile(map_file_path)

    do_query(
        args.query,
        args.window_size,
        args.min,
        args.max,
        args.contains,
        args.show_disasm,
    )