#!/usr/bin/python3
import argparse
from collections import Counter, OrderedDict
from datetime import datetime
from Levenshtein import ratio
import os
import re
import sys

script_dir = os.path.dirname(os.path.realpath(__file__))
root_dir = script_dir + "/../"
asm_dir = root_dir + "ver/current/asm/nonmatchings/"
build_dir = root_dir + "build/"


def read_rom():
    with open(root_dir + "ver/current/baserom.z64", "rb") as f:
        return f.read()


def find_dir(query):
    for root, dirs, files in os.walk(asm_dir):
        for d in dirs:
            if d == query:
                return os.path.join(root, d)
    return None


def get_all_s_files():
    ret = set()
    for root, dirs, files in os.walk(asm_dir):
        for f in files:
            if f.endswith(".s"):
                # Strip the ".s" extension so entries match symbol names
                ret.add(f[:-2])
    return ret


def get_symbol_length(sym_name):
    if "end" in map_offsets[sym_name] and "start" in map_offsets[sym_name]:
        return map_offsets[sym_name]["end"] - map_offsets[sym_name]["start"]
    return 0


def get_symbol_bytes(offsets, func):
    if func not in offsets or "start" not in offsets[func] or "end" not in offsets[func]:
        return None
    start = offsets[func]["start"]
    end = offsets[func]["end"]

    bs = list(rom_bytes[start:end])

    # Strip trailing zero bytes (alignment padding)
    while len(bs) > 0 and bs[-1] == 0:
        bs.pop()
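
    # Each MIPS instruction is 4 bytes (big-endian), with the 6-bit opcode in
    # the top bits of the first byte, so taking bs[0::4] and shifting right by
    # 2 reduces the function to a string of opcodes, ignoring registers and
    # immediates. Illustrative example (not taken from the ROM): the encoding
    # 27 BD FF E8 (addiu $sp, $sp, -0x18) contributes 0x27 >> 2 == 0x09, the
    # addiu opcode.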
    insns = bs[0::4]

    ret = []
    for ins in insns:
        ret.append(ins >> 2)

    # Opcode values fit in 6 bits, so the result decodes cleanly as ASCII
    return bytes(ret).decode('utf-8'), bs


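# parse_map() relies on the fixed-column layout of GNU ld map files: the RAM
# address sits at column 16 and, on "load address" lines, the ROM address at
# column 59. An illustrative (not real) section line:
#   .text           0x0000000080025c00    0x11290 load address 0x0000000000001000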
def parse_map(fname):
    ram_offset = None
    cur_file = "<no file>"
    syms = {}
    prev_sym = None
    prev_line = ""
    with open(fname) as f:
        for line in f:
            if "load address" in line:
                # noload sections (e.g. bss) have no ROM presence
                if "noload" in line or "noload" in prev_line:
                    ram_offset = None
                    continue
                ram = int(line[16:16 + 18], 0)
                rom = int(line[59:59 + 18], 0)
                ram_offset = ram - rom
                continue
            prev_line = line

            if (
                ram_offset is None
                or "=" in line
                or "*fill*" in line
                or " 0x" not in line
            ):
                continue

            ram = int(line[16:16 + 18], 0)
            rom = ram - ram_offset
            fn = line.split()[-1]

            if "0x" in fn:
                ram_offset = None
            elif "/" in fn:
                cur_file = fn
            else:
                syms[fn] = (rom, cur_file, prev_sym, ram)
                prev_sym = fn
    return syms


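# Each symbol's ROM range is derived from map order: a symbol's end offset is
# simply the next symbol's start offset. Example with illustrative addresses:
# if func_A starts at ROM 0x1000 and func_B follows at 0x1040, then
# offsets["func_A"]["end"] == offsets["func_B"]["start"] == 0x1040.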
def get_map_offsets(syms):
    offsets = {}
    for sym in syms:
        prev_sym = syms[sym][2]
        if sym not in offsets:
            offsets[sym] = {}
        if prev_sym not in offsets:
            offsets[prev_sym] = {}
        offsets[sym]["start"] = syms[sym][0]
        offsets[prev_sym]["end"] = syms[sym][0]
    return offsets


def is_zeros(vals):
    for val in vals:
        if val != 0:
            return False
    return True


def diff_syms(qb, tb):
    # Ignore very short functions
    if len(tb[1]) < 8:
        return 0

    # The minimum edit distance for two strings of different lengths is
    # abs(l1 - l2), so quickly check whether beating the threshold is
    # impossible. If it is, return 0
    l1, l2 = len(qb[0]), len(tb[0])
    if abs(l1 - l2) / (l1 + l2) > 1.0 - args.threshold:
        return 0
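
    # Worked example at the default threshold of 0.9: for opcode strings of
    # lengths l1 = 100 and l2 = 150, abs(l1 - l2) / (l1 + l2) = 50 / 250 = 0.2,
    # which exceeds 1.0 - 0.9 = 0.1, so ratio() can never reach 0.9 and the
    # expensive comparison is skipped.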

    r = ratio(qb[0], tb[0])

    # Identical opcodes but different raw bytes: demote an exact 1.0 to 0.99
    # so operand-only differences remain distinguishable from true duplicates
    if r == 1.0 and qb[1] != tb[1]:
        r = 0.99
    return r


def get_pair_score(query_bytes, b):
    b_bytes = get_symbol_bytes(map_offsets, b)
    if query_bytes and b_bytes:
        return diff_syms(query_bytes, b_bytes)
    return 0


def get_matches(query):
    query_bytes = get_symbol_bytes(map_offsets, query)
    if query_bytes is None:
        sys.exit("Symbol '" + query + "' not found")

    ret = {}
    for symbol in map_offsets:
        if symbol is not None and query != symbol:
            score = get_pair_score(query_bytes, symbol)
            if score >= args.threshold:
                ret[symbol] = score
    return OrderedDict(sorted(ret.items(), key=lambda kv: kv[1], reverse=True))


def do_query(query):
    matches = get_matches(query)
    num_matches = len(matches)
    if num_matches == 0:
        print(query + " - found no matches")
        return

    i = 0
    more_str = ":"
    if args.num_out < num_matches:
        more_str = " (showing only " + str(args.num_out) + "):"
    print(query + " - found " + str(num_matches) + " matches total" + more_str)
    for match in matches:
        if i == args.num_out:
            break
        match_str = "{:.2f} - {}".format(matches[match], match)
        if match not in s_files:
            match_str += " (decompiled)"
        print(match_str)
        i += 1
    print()


def all_matches(all_funcs_flag):
    match_dict = dict()
    to_match_files = list(s_files.copy())

    # Assume that once half of the functions have been matched, nothing of
    # significance is left, since already-discovered duplicates are removed
    # from to_match_files
    if all_funcs_flag:
        iter_limit = 0
    else:
        iter_limit = len(s_files) / 2

    num_decomped_dupes = 0
    num_undecomped_dupes = 0
    num_perfect_dupes = 0
    i = 0
    while len(to_match_files) > iter_limit:
        file = to_match_files[0]
        i += 1
        print("File matching progress: {:%}".format(i / (len(s_files) - iter_limit)), end='\r')

        if get_symbol_length(file) < 16:
            to_match_files.remove(file)
            continue

        matches = get_matches(file)
        num_matches = len(matches)
        if num_matches == 0:
            to_match_files.remove(file)
            continue

        num_undecomped_dupes += 1
        match_list = []
        for match in matches:
            if match in to_match_files:
                i += 1
                to_match_files.remove(match)
            match_str = "{:.2f} - {}".format(matches[match], match)
            if matches[match] >= 0.995:
                num_perfect_dupes += 1
            if match not in s_files:
                num_decomped_dupes += 1
                match_str += " (decompiled)"
            else:
                num_undecomped_dupes += 1
            match_list.append(match_str)

        match_dict.update({file: (num_matches, match_list)})
        to_match_files.remove(file)

    output_match_dict(match_dict, num_decomped_dupes, num_undecomped_dupes, num_perfect_dupes, i)


def output_match_dict(match_dict, num_decomped_dupes, num_undecomped_dupes, num_perfect_dupes, num_checked_files):
    out_file = open(datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + "_all_matches.txt", "w+")
    out_file.write("Number of s-files: " + str(len(s_files)) + "\n"
                   "Number of checked s-files: " + str(round(num_checked_files)) + "\n"
                   "Number of decompiled duplicates found: " + str(num_decomped_dupes) + "\n"
                   "Number of undecompiled duplicates found: " + str(num_undecomped_dupes) + "\n"
                   "Number of overall exact duplicates found: " + str(num_perfect_dupes) + "\n\n")

    sorted_dict = OrderedDict(sorted(match_dict.items(), key=lambda item: item[1][0], reverse=True))
    print("Creating output file: " + out_file.name, end='\n')
    for file_name, matches in sorted_dict.items():
        out_file.write(file_name + " - found " + str(matches[0]) + " matches total:\n")
        for match in matches[1]:
            out_file.write(match + "\n")
        out_file.write("\n")
    out_file.close()


def is_decompiled(sym):
    # A symbol with no remaining .s file has already been decompiled
    return sym not in s_files


def do_cross_query():
    ccount = Counter()
    clusters = []

    # Collect opcode strings for all real functions, skipping data symbols,
    # binary blobs, jump tables, and auto-generated local labels
    sym_bytes = {}
    for sym_name in map_syms:
        if not sym_name.startswith("D_") and \
           not sym_name.startswith("_binary") and \
           not sym_name.startswith("jtbl_") and \
           not re.match(r"L[0-9A-F]{8}_[0-9A-F]{5,6}", sym_name):
            if get_symbol_length(sym_name) > 16:
                sym_bytes[sym_name] = get_symbol_bytes(map_offsets, sym_name)
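
    # Greedy clustering: each cluster is a list whose first element serves as
    # its representative, and every new symbol is compared against
    # representatives only. Decompiled symbols are promoted to representative,
    # so ccount only accumulates opcode bytes for clusters whose
    # representative still needs decompiling.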
    for sym_name, query_bytes in sym_bytes.items():
        cluster_match = False
        for cluster in clusters:
            cluster_first = cluster[0]
            cluster_score = diff_syms(query_bytes, sym_bytes[cluster_first])
            if cluster_score >= args.threshold:
                cluster_match = True
                if is_decompiled(sym_name) and not is_decompiled(cluster_first):
                    ccount[sym_name] = ccount[cluster_first]
                    del ccount[cluster_first]
                    cluster_first = sym_name
                    cluster.insert(0, cluster_first)
                else:
                    cluster.append(sym_name)

                if not is_decompiled(cluster_first):
                    ccount[cluster_first] += len(sym_bytes[cluster_first][0])

                if len(cluster) % 10 == 0 and len(cluster) >= 10:
                    print(f"Cluster {cluster_first} grew to size {len(cluster)} - {sym_name}: {str(cluster_score)}")
                break
        if not cluster_match:
            clusters.append([sym_name])

    print(ccount.most_common(100))


parser = argparse.ArgumentParser(description="Tool to find duplicates for a specific function or to find all duplicates across the codebase.")
group = parser.add_mutually_exclusive_group()
group.add_argument("-a", "--all", help="find ALL duplicates and output them into a file", action='store_true', required=False)
group.add_argument("-c", "--cross", help="do a cross query over the codebase", action='store_true', required=False)
group.add_argument("-s", "--short", help="find MOST duplicates, skipping some very small ones. Cuts the runtime in half with minimal loss", action='store_true', required=False)
parser.add_argument("query", help="function or file", nargs='?', default=None)
parser.add_argument("-t", "--threshold", help="score threshold between 0 and 1 (higher is more restrictive)", type=float, default=0.9, required=False)
parser.add_argument("-n", "--num-out", help="number of functions to display", type=int, default=100, required=False)
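
# Example invocations (script and symbol names here are illustrative):
#   ./find_duplicates.py some_func          # duplicates of a single function
#   ./find_duplicates.py -a                 # scan everything into a report file
#   ./find_duplicates.py some_func -t 0.95  # stricter similarity threshold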

args = parser.parse_args()

if __name__ == "__main__":
    rom_bytes = read_rom()
    map_syms = parse_map(os.path.join(root_dir, "ver", "current", "build", "papermario.map"))
    map_offsets = get_map_offsets(map_syms)

    s_files = get_all_s_files()

    query_dir = find_dir(args.query)

    if query_dir is not None:
        files = os.listdir(query_dir)
        for f_name in files:
            do_query(f_name[:-2])
    else:
        if args.cross:
            args.threshold = 0.985
            do_cross_query()
        elif args.all:
            args.threshold = 0.985
            all_matches(True)
        elif args.short:
            args.threshold = 0.985
            all_matches(False)
        else:
            if args.query is None:
                parser.print_help()
            else:
                do_query(args.query)