https://github.com/alvinwan/neural-backed-decision-trees
Revision 3f52a9633fc8f4215f4c5a2f5073afaf52973f96 authored by Alvin Wan on 12 March 2020, 00:58:24 UTC, committed by Alvin Wan on 12 March 2020, 00:58:52 UTC
1 parent d6766d5
Tip revision: 3f52a9633fc8f4215f4c5a2f5073afaf52973f96 authored by Alvin Wan on 12 March 2020, 00:58:24 UTC
instantiate custom criterion directly, instead of calling model
instantiate custom criterion directly, instead of calling model
Tip revision: 3f52a96
test_generated_hierarchy.py
from utils.utils import DATASETS, METHODS, DATASET_TO_FOLDER_NAME, Colors
from utils.graph import get_parser, get_wnids_from_dataset, read_graph, \
get_leaves, generate_fname, get_directory, get_graph_path_from_args, \
get_roots
from pathlib import Path
import argparse
import os
def get_seen_wnids(wnid_set, nodes):
    """Return the distinct entries of ``nodes`` and strike them off ``wnid_set``.

    NOTE: ``wnid_set`` is modified in place. Every entry that also occurs in
    ``nodes`` is removed from it, so after the call it holds only the WNIDs
    that were *not* found. Callers in this file rely on that side effect.

    Args:
        wnid_set: Mutable set of WNIDs still waiting to be matched.
        nodes: Iterable of hashable node identifiers (duplicates allowed).

    Returns:
        A new set containing every distinct element of ``nodes``.
    """
    # Materialise first so ``nodes`` is consumed exactly once, even if it
    # is a one-shot iterator.
    leaves_seen = set(nodes)
    wnid_set.difference_update(leaves_seen)
    return leaves_seen
def match_wnid_leaves(wnids, G, tree_name):
    """Match dataset WNIDs against the result of ``get_leaves(G)``.

    Args:
        wnids: Iterable of WNID strings; each is stripped of surrounding
            whitespace before matching.
        G: Graph object passed to ``get_leaves``.
        tree_name: Unused here; kept so the signature mirrors
            ``match_wnid_nodes``.

    Returns:
        A ``(leaves_seen, wnid_set)`` tuple: the set of everything yielded by
        ``get_leaves(G)``, and the set of stripped WNIDs that were not among
        them.
    """
    # Same construction as match_wnid_nodes, so the two checks normalise
    # their input identically.
    wnid_set = {wnid.strip() for wnid in wnids}
    # get_seen_wnids removes matched entries from wnid_set in place.
    leaves_seen = get_seen_wnids(wnid_set, get_leaves(G))
    return leaves_seen, wnid_set
def match_wnid_nodes(wnids, G, tree_name):
    """Match dataset WNIDs against every node in ``G.nodes``.

    Args:
        wnids: Iterable of WNID strings; surrounding whitespace is stripped.
        G: Graph object exposing a ``nodes`` attribute.
        tree_name: Unused here.

    Returns:
        A ``(seen, unmatched)`` tuple: the set of all entries in ``G.nodes``,
        and the set of stripped WNIDs that did not appear among them.
    """
    unmatched = {raw.strip() for raw in wnids}
    # get_seen_wnids prunes ``unmatched`` in place as it finds matches.
    seen = get_seen_wnids(unmatched, G.nodes)
    return seen, unmatched
def print_stats(leaves_seen, wnid_set, tree_name, node_type):
    """Print a one-line summary of a WNID matching pass.

    Args:
        leaves_seen: Collection of graph entries that were visited.
        wnid_set: Collection of WNIDs that were not found.
        tree_name: Label printed in brackets at the start of the line.
        node_type: Name of what was matched against (e.g. ``'leaves'``).

    If ``wnid_set`` is non-empty, a warning is additionally emitted through
    ``Colors.red``.
    """
    num_seen = len(leaves_seen)
    num_missing = len(wnid_set)
    print(f"[{tree_name}] \t {node_type}: {num_seen} \t WNIDs missing from {node_type}: {num_missing}")
    if num_missing == 0:
        return
    Colors.red(f"==> Warning: WNIDs in wnid.txt are missing from {tree_name} {node_type}")
def main():
    """Check a generated hierarchy against the dataset's WNIDs.

    Reads the graph selected by the command-line arguments, then reports:
    how many dataset WNIDs are missing from the graph's leaves, how many are
    missing from its nodes, and whether the graph has exactly one root. A
    final pass/fail line is printed; the function itself returns nothing.
    """
    args = get_parser().parse_args()
    wnids = get_wnids_from_dataset(args.dataset)

    graph_path = get_graph_path_from_args(args)
    print('==> Reading from {}'.format(graph_path))
    graph = read_graph(graph_path)
    graph_name = Path(graph_path).stem

    # Check 1: every WNID should show up among the leaves.
    seen_leaves, missing_from_leaves = match_wnid_leaves(wnids, graph, graph_name)
    print_stats(seen_leaves, missing_from_leaves, graph_name, 'leaves')

    # Check 2: every WNID should show up among the nodes.
    seen_nodes, missing_from_nodes = match_wnid_nodes(wnids, graph, graph_name)
    print_stats(seen_nodes, missing_from_nodes, graph_name, 'nodes')

    # Check 3: the hierarchy must have exactly one root.
    num_roots = len(list(get_roots(graph)))
    has_single_root = num_roots == 1
    if has_single_root:
        Colors.green('Found just 1 root.')
    else:
        Colors.red(f'Found {num_roots} roots. Should be only 1.')

    nothing_missing = len(missing_from_leaves) == 0 and len(missing_from_nodes) == 0
    if nothing_missing and has_single_root:
        Colors.green("==> All checks pass!")
    else:
        Colors.red('==> Test failed')
if __name__ == '__main__':
main()
Computing file changes ...