https://github.com/alvinwan/neural-backed-decision-trees
Raw File
Tip revision: a7a2ee6f735bbc1b3d8c7c4f9ecdd02c6a75fc1e authored by Alvin Wan on 03 June 2021, 04:38:35 UTC
Merge pull request #20 from alvinwan/dependabot/pip/examples/app/flask-cors-3.0.9
Tip revision: a7a2ee6
gen_train_eval_resnet.sh
# For each (dataset, tree-supervision weight) pair: generate an induced
# hierarchy, train with the soft tree supervision loss, then evaluate with
# soft and then hard inference.
#
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` in
# `common_args` below instead of the induced hierarchy.

# Backbone used by every step. Steps 2-3 refer to the hierarchy as
# `induced-<arch>`, presumably the name nbdt-hierarchy gives its output for
# that arch, so all steps must share this one value.
model=ResNet18

for entry in "CIFAR10 1" "CIFAR100 1" "TinyImagenet200 10"; do
  # Each entry is "<dataset> <tree-supervision-weight>".
  read -r dataset weight <<< "${entry}"

  # Flags shared by training and every evaluation run, kept in one place so
  # the train and eval invocations cannot drift apart.
  common_args=(
    --dataset="${dataset}"
    --arch="${model}"
    --hierarchy="induced-${model}"
    --loss=SoftTreeSupLoss
    --tree-supervision-weight="${weight}"
  )

  # 1. generate hierarchy
  nbdt-hierarchy --dataset="${dataset}" --arch="${model}"

  # 2. train with soft tree supervision loss
  python main.py "${common_args[@]}"

  # 3. evaluate with soft then hard inference
  for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do
    python main.py "${common_args[@]}" --eval --resume --analysis="${analysis}"
  done
done
back to top