Revision e1467a79dc6580ae009d827b5e6f274faff3b339 authored by liqunfu on 27 March 2020, 21:42:04 UTC, committed by GitHub on 27 March 2020, 21:42:04 UTC
support Pooling ops with Sequence axis
run_fast_rcnn.py
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import os
import numpy as np
from FastRCNN_train import prepare, train_fast_rcnn
from FastRCNN_eval import compute_test_set_aps, FastRCNN_Evaluator
from utils.config_helpers import merge_configs
from utils.plot_helpers import plot_test_set_results
def get_configuration():
    """Assemble the configuration used by this script.

    Imports the detector, base-network and data-set config objects and
    returns the result of passing them, in that order, to ``merge_configs``.
    """
    # Detector settings.
    from FastRCNN_config import cfg as detector_settings

    # Base network settings. To switch the base model, import ``cfg`` from:
    #   VGG16:   utils.configs.VGG16_config
    #   AlexNet: utils.configs.AlexNet_config
    from utils.configs.AlexNet_config import cfg as network_settings

    # Data set settings. To switch the data set, import ``cfg`` from:
    #   Pascal VOC 2007: utils.configs.Pascal_config
    #   Grocery:         utils.configs.Grocery_config
    from utils.configs.Grocery_config import cfg as dataset_settings

    config_parts = [detector_settings, network_settings, dataset_settings]
    return merge_configs(config_parts)
# Script entry point: trains a Fast R-CNN model and evaluates it.
if __name__ == '__main__':
    cfg = get_configuration()
    prepare(cfg, True)

    # Train first, then compute the test-set results with the trained model.
    trained_model = train_fast_rcnn(cfg)
    eval_results = compute_test_set_aps(trained_model, cfg)

    # Print one line per class, followed by the mean over all classes.
    # np.nanmean leaves NaN entries out of the average.
    ap_line = 'AP for {:>15} = {:.4f}'
    for class_name, class_ap in eval_results.items():
        print(ap_line.format(class_name, class_ap))
    mean_ap = np.nanmean(list(eval_results.values()))
    print('Mean AP = {:.4f}'.format(mean_ap))

    # Optionally plot results on test set images (at most 100 of them).
    if cfg.VISUALIZE_RESULTS:
        data_cfg = cfg["DATA"]
        num_eval = min(data_cfg.NUM_TEST_IMAGES, 100)
        results_folder = os.path.join(cfg.OUTPUT_PATH, data_cfg.DATASET)
        evaluator = FastRCNN_Evaluator(trained_model, cfg)
        plot_test_set_results(evaluator, num_eval, results_folder, cfg)
Computing file changes ...