papermario/tools/assist.py
Ethan Roseman 68792f4331
Misc decomp (#193)
* fixing segment

* redoing subtree

* git subrepo clone https://github.com/ethteck/splat.git tools/splat

subrepo:
  subdir:   "tools/splat"
  merged:   "99cf514297"
upstream:
  origin:   "https://github.com/ethteck/splat.git"
  branch:   "master"
  commit:   "99cf514297"
git-subrepo:
  version:  "0.4.3"
  origin:   "https://github.com/ingydotnet/git-subrepo"
  commit:   "2f68596"

* fix symbol_addrs

* sweet victory

* reformat collision_main_lateral.s

* Lots of matches, fixed diff.py issue, improved coverage.py

* some more

* update symbol_addrs

* git subrepo pull tools/splat

subrepo:
  subdir:   "tools/splat"
  merged:   "c46026725a"
upstream:
  origin:   "https://github.com/ethteck/splat.git"
  branch:   "master"
  commit:   "c46026725a"
git-subrepo:
  version:  "0.4.3"
  origin:   "https://github.com/ingydotnet/git-subrepo"
  commit:   "2f68596"

* update symbol_addrs

* fix disasm

* 2 mo

* Jenkinsfile

* PR comments

* check_input_status_menu, removing some macros

* a few funcs

* assist updates

* UnkAlphaFunc

* UnkPartnerPosFuncs

* Deleting stuff

* oopz

* dead_syms.txt, UnkNpcAIFunc14

* update symbol_addrs

* cleanup

* symbol_addrs

* cleanup

* Set80071270_0_16

* UnkFloatFunc2

* fix watt data start

* UnkNpcAIFunc15

* 🐱

* 🐱 🐱

* symbol_addrs

* fix assist.py

* delete unnecessary data file

* fix diff.py, remove parens

* parens
2021-02-22 22:00:49 +09:00

240 lines
6.7 KiB
Python
Executable File

#!/usr/bin/python3
import argparse
from collections import Counter, OrderedDict
from Levenshtein import ratio
import os
import re
import sys
script_dir = os.path.dirname(os.path.realpath(__file__))
root_dir = script_dir + "/../"
asm_dir = root_dir + "asm/nonmatchings/"
build_dir = root_dir + "build/"
def read_rom():
with open("baserom.z64", "rb") as f:
return f.read()
def find_dir(query):
for root, dirs, files in os.walk(asm_dir):
for d in dirs:
if d == query:
return os.path.join(root, d)
return None
def get_all_s_files():
ret = set()
for root, dirs, files in os.walk(asm_dir):
for f in files:
if f.endswith(".s"):
ret.add(f[:-2])
return ret
def get_symbol_length(sym_name):
if "end" in map_offsets[sym_name] and "start" in map_offsets[sym_name]:
return map_offsets[sym_name]["end"] - map_offsets[sym_name]["start"]
return 0
def get_symbol_bytes(offsets, func):
if func not in offsets or "start" not in offsets[func] or "end" not in offsets[func]:
return None
start = offsets[func]["start"]
end = offsets[func]["end"]
bs = list(rom_bytes[start:end])
while len(bs) > 0 and bs[-1] == 0:
bs.pop()
insns = bs[0::4]
ret = []
for ins in insns:
ret.append(ins >> 2)
return bytes(ret).decode('utf-8'), bs
def parse_map(fname):
ram_offset = None
cur_file = "<no file>"
syms = {}
prev_sym = None
prev_line = ""
with open(fname) as f:
for line in f:
if "load address" in line:
if "noload" in line or "noload" in prev_line:
ram_offset = None
continue
ram = int(line[16 : 16 + 18], 0)
rom = int(line[59 : 59 + 18], 0)
ram_offset = ram - rom
continue
prev_line = line
if (
ram_offset is None
or "=" in line
or "*fill*" in line
or " 0x" not in line
):
continue
ram = int(line[16 : 16 + 18], 0)
rom = ram - ram_offset
fn = line.split()[-1]
if "0x" in fn:
ram_offset = None
elif "/" in fn:
cur_file = fn
else:
syms[fn] = (rom, cur_file, prev_sym, ram)
prev_sym = fn
return syms
def get_map_offsets(syms):
offsets = {}
for sym in syms:
prev_sym = syms[sym][2]
if sym not in offsets:
offsets[sym] = {}
if prev_sym not in offsets:
offsets[prev_sym] = {}
offsets[sym]["start"] = syms[sym][0]
offsets[prev_sym]["end"] = syms[sym][0]
return offsets
def is_zeros(vals):
for val in vals:
if val != 0:
return False
return True
def diff_syms(qb, tb):
if len(tb[1]) < 8:
return 0
r = ratio(qb[0], tb[0])
if r == 1.0 and qb[1] != tb[1]:
r = 0.99
return r
def get_pair_score(query_bytes, b):
b_bytes = get_symbol_bytes(map_offsets, b)
if query_bytes and b_bytes:
return diff_syms(query_bytes, b_bytes)
return 0
def get_matches(query):
query_bytes = get_symbol_bytes(map_offsets, query)
if query_bytes is None:
sys.exit("Symbol '" + query + "' not found")
ret = {}
for symbol in map_offsets:
if symbol is not None and query != symbol:
score = get_pair_score(query_bytes, symbol)
if score >= args.threshold:
ret[symbol] = score
return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True))
def do_query(query):
matches = get_matches(query)
num_matches = len(matches)
if num_matches == 0:
print(query + " - found no matches")
return
i = 0
more_str = ":"
if args.num_out < num_matches:
more_str = " (showing only " + str(args.num_out) + "):"
print(query + " - found " + str(num_matches) + " matches total" + more_str)
for match in matches:
if i == args.num_out:
break
match_str = "{:.2f} - {}".format(matches[match], match)
if match not in s_files:
match_str += " (decompiled)"
print(match_str)
i += 1
print()
def do_cross_query():
ccount = Counter()
clusters = []
for sym_name in map_syms:
if not sym_name.startswith("D_") and \
not sym_name.startswith("_binary") and \
not sym_name.startswith("jtbl_") and \
not re.match(r"L[0-9A-F]{8}_[0-9A-F]{5,6}", sym_name):
if get_symbol_length(sym_name) > 16:
query_bytes = get_symbol_bytes(map_offsets, sym_name)
cluster_match = False
for cluster in clusters:
cluster_first = cluster[0]
cluster_score = get_pair_score(query_bytes, cluster_first)
if cluster_score >= args.threshold:
cluster_match = True
if sym_name.startswith("func") and not cluster_first.startswith("func"):
ccount[sym_name] = ccount[cluster_first]
del ccount[cluster_first]
cluster_first = sym_name
cluster.insert(0, cluster_first)
else:
cluster.append(sym_name)
if cluster_first.startswith("func"):
ccount[cluster_first] += 1
#if len(cluster) % 10 == 0 and len(cluster) >= 10:
print(f"Cluster {cluster_first} grew to size {len(cluster)} - {sym_name}: {str(cluster_score)}")
break
if not cluster_match:
clusters.append([sym_name])
print(ccount.most_common(100))
parser = argparse.ArgumentParser(description="Tools to assist with decomp")
parser.add_argument("query", help="function or file")
parser.add_argument("--threshold", help="score threshold between 0 and 1 (higher is more restrictive)", type=float, default=0.9, required=False)
parser.add_argument("--num-out", help="number of functions to display", type=int, default=100, required=False)
args = parser.parse_args()
rom_bytes = read_rom()
map_syms = parse_map(os.path.join(root_dir, "ver", "current", "build", "papermario.map"))
map_offsets = get_map_offsets(map_syms)
s_files = get_all_s_files()
query_dir = find_dir(args.query)
if query_dir is not None:
files = os.listdir(query_dir)
for f_name in files:
do_query(f_name[:-2])
else:
if args.query == "cross":
args.threshold = 0.985
do_cross_query()
else:
do_query(args.query)