Cyril666's picture
First model version
6250360
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import time
import os
import torch
from tqdm import tqdm
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
def compute_on_dataset(model, data_loader, device, timer=None):
model.eval()
results_dict = {}
cpu_device = torch.device("cpu")
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if timer:
torch.cuda.synchronize()
timer.toc()
output = [o.to(cpu_device) for o in output]
results_dict.update(
{img_id: result for img_id, result in zip(image_ids, output)}
)
return results_dict
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
all_predictions = all_gather(predictions_per_gpu)
if not is_main_process():
return
# merge the list of dicts
predictions = {}
for p in all_predictions:
predictions.update(p)
# convert a dict where the key is the index in a list
image_ids = list(sorted(predictions.keys()))
if len(image_ids) != image_ids[-1] + 1:
logger = logging.getLogger("maskrcnn_benchmark.inference")
logger.warning(
"Number of images that were gathered from multiple processes is not "
"a contiguous set. Some images might be missing from the evaluation"
)
# convert to a list
predictions = [predictions[i] for i in image_ids]
return predictions
def inference(
model,
data_loader,
dataset_name,
iou_types=("bbox",),
box_only=False,
device="cuda",
expected_results=(),
expected_results_sigma_tol=4,
output_folder=None,
):
logger = logging.getLogger("maskrcnn_benchmark.inference")
dataset = data_loader.dataset
logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
extra_args = dict(
box_only=box_only,
iou_types=iou_types,
expected_results=expected_results,
expected_results_sigma_tol=expected_results_sigma_tol,
)
# load predictions if exists
prediction_file = os.path.join(output_folder, 'predictions.pth')
if os.path.isfile(prediction_file):
predictions = torch.load(prediction_file)
logger.info("Found prediction results at {}".format(prediction_file))
return evaluate(dataset=dataset,
predictions=predictions,
output_folder=output_folder,
**extra_args)
# convert to a torch.device for efficiency
device = torch.device(device)
num_devices = get_world_size()
total_timer = Timer()
inference_timer = Timer()
total_timer.tic()
predictions = compute_on_dataset(model, data_loader, device, inference_timer)
# wait for all processes to complete before measuring the time
synchronize()
total_time = total_timer.toc()
total_time_str = get_time_str(total_time)
logger.info(
"Total run time: {} ({} s / img per device, on {} devices)".format(
total_time_str, total_time * num_devices / len(dataset), num_devices
)
)
total_infer_time = get_time_str(inference_timer.total_time)
logger.info(
"Model inference time: {} ({} s / img per device, on {} devices)".format(
total_infer_time,
inference_timer.total_time * num_devices / len(dataset),
num_devices,
)
)
predictions = _accumulate_predictions_from_multiple_gpus(predictions)
if not is_main_process():
return
if output_folder:
torch.save(predictions, os.path.join(output_folder, "predictions.pth"))
return evaluate(dataset=dataset,
predictions=predictions,
output_folder=output_folder,
**extra_args)