Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import logging | |
import time | |
import os | |
import torch | |
from tqdm import tqdm | |
from maskrcnn_benchmark.data.datasets.evaluation import evaluate | |
from ..utils.comm import is_main_process, get_world_size | |
from ..utils.comm import all_gather | |
from ..utils.comm import synchronize | |
from ..utils.timer import Timer, get_time_str | |
def compute_on_dataset(model, data_loader, device, timer=None):
    """Run ``model`` over every batch of ``data_loader`` and collect outputs.

    Args:
        model: the model to evaluate; it is switched to eval mode here.
        data_loader: iterable yielding ``(images, targets, image_ids)``
            batches; ``targets`` is not used.
        device: device (``torch.device`` or device string) the images are
            moved to before the forward pass.
        timer: optional timer exposing ``tic()``/``toc()``; when given, only
            the forward pass is timed.

    Returns:
        dict mapping each image id to the corresponding model output, moved
        to the CPU.
    """
    model.eval()
    cpu_device = torch.device("cpu")
    # CUDA kernels run asynchronously, so the timer has to wait for them before
    # stopping. torch.cuda.synchronize() raises when CUDA is unavailable, so it
    # is only called when the images actually live on a CUDA device.
    sync_cuda = torch.device(device).type == "cuda"
    results_dict = {}
    for images, _targets, image_ids in tqdm(data_loader):
        images = images.to(device)
        with torch.no_grad():
            if timer:
                timer.tic()
            output = model(images)
            if timer:
                if sync_cuda:
                    torch.cuda.synchronize()
                timer.toc()
        output = [o.to(cpu_device) for o in output]
        results_dict.update(zip(image_ids, output))
    return results_dict
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
    """Gather per-process prediction dicts and merge them on the main process.

    Args:
        predictions_per_gpu: dict mapping image id to prediction, as produced
            by ``compute_on_dataset`` on this process.

    Returns:
        On the main process, a list of predictions ordered by sorted image id
        (an empty list if nothing was gathered). On every other process,
        ``None``.
    """
    # All processes call all_gather before the non-main ones return;
    # presumably it is a collective operation that every rank must reach.
    all_predictions = all_gather(predictions_per_gpu)
    if not is_main_process():
        return None
    logger = logging.getLogger("maskrcnn_benchmark.inference")
    # merge the list of dicts
    predictions = {}
    for per_process in all_predictions:
        predictions.update(per_process)
    image_ids = sorted(predictions)
    if not image_ids:
        # Guard: image_ids[-1] below would raise IndexError on an empty gather.
        logger.warning("No predictions were gathered from any process")
        return []
    # The returned list is meant to be indexed by image id, which only holds
    # when the gathered ids are exactly 0..N-1 (assumes integer ids).
    if len(image_ids) != image_ids[-1] + 1:
        logger.warning(
            "Number of images that were gathered from multiple processes is not "
            "a contiguous set. Some images might be missing from the evaluation"
        )
    # convert to a list
    return [predictions[i] for i in image_ids]
def _log_timing(logger, label, seconds, num_images, num_devices):
    """Log a duration together with its per-image, per-device cost."""
    logger.info(
        "{}: {} ({} s / img per device, on {} devices)".format(
            label,
            get_time_str(seconds),
            # max(..., 1) avoids ZeroDivisionError on an empty dataset.
            seconds * num_devices / max(num_images, 1),
            num_devices,
        )
    )


def inference(
    model,
    data_loader,
    dataset_name,
    iou_types=("bbox",),
    box_only=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
):
    """Run ``model`` over ``data_loader.dataset`` and evaluate the predictions.

    If ``output_folder`` is given and already contains ``predictions.pth``,
    those cached predictions are evaluated instead of re-running the model.
    Otherwise predictions are computed on every process, gathered on the main
    process, saved to ``output_folder/predictions.pth`` when a folder is
    given, and evaluated.

    Args:
        model: the model to run.
        data_loader: loader whose ``.dataset`` is evaluated.
        dataset_name: name used in the log output.
        iou_types, box_only, expected_results, expected_results_sigma_tol:
            forwarded unchanged to ``evaluate``.
        device: device (or device string) inference runs on.
        output_folder: optional folder for cached/saved predictions and for
            ``evaluate``'s output.

    Returns:
        The result of ``evaluate`` on the main process; ``None`` on every
        other process.
    """
    logger = logging.getLogger("maskrcnn_benchmark.inference")
    dataset = data_loader.dataset
    logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
    extra_args = dict(
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )

    # output_folder defaults to None; os.path.join(None, ...) would raise, so
    # predictions are only cached/loaded when a folder was actually given.
    prediction_file = (
        os.path.join(output_folder, "predictions.pth") if output_folder else None
    )

    # load predictions if they already exist
    if prediction_file is not None and os.path.isfile(prediction_file):
        # Only the main process evaluates, matching the compute path below;
        # this avoids several processes writing the same evaluation outputs.
        if not is_main_process():
            return None
        # NOTE(security): torch.load unpickles the file, so only point
        # output_folder at trusted locations.
        predictions = torch.load(prediction_file)
        logger.info("Found prediction results at {}".format(prediction_file))
        return evaluate(dataset=dataset,
                        predictions=predictions,
                        output_folder=output_folder,
                        **extra_args)

    # convert to a torch.device for efficiency
    device = torch.device(device)
    num_devices = get_world_size()
    total_timer = Timer()
    inference_timer = Timer()
    total_timer.tic()
    predictions = compute_on_dataset(model, data_loader, device, inference_timer)
    # wait for all processes to complete before measuring the time
    synchronize()
    total_time = total_timer.toc()
    num_images = len(dataset)
    _log_timing(logger, "Total run time", total_time, num_images, num_devices)
    _log_timing(
        logger, "Model inference time", inference_timer.total_time, num_images, num_devices
    )

    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
    if not is_main_process():
        return None
    if prediction_file is not None:
        torch.save(predictions, prediction_file)
    return evaluate(dataset=dataset,
                    predictions=predictions,
                    output_folder=output_folder,
                    **extra_args)