# Source: sesr/test.py (exported from a Hugging Face model repository page)
# Author: zhengrongzhang -- commit 2071132, "init model"
# File size: 1.89 kB; Hub security scan: No virus
import torch
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import utility
import data
from option import args
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference
def prepare(a, b, device):
    """Move a pair of tensors onto ``device``.

    Args:
        a: first tensor (in this script, the low-resolution input batch).
        b: second tensor (in this script, the high-resolution reference batch).
        device: target passed straight to ``Tensor.to``.

    Returns:
        Tuple ``(a_on_device, b_on_device)`` in the same order as the inputs.
    """
    moved_a = a.to(device)
    moved_b = b.to(device)
    return moved_a, moved_b
def test_model(session, loader, device):
    """Evaluate an ONNX Runtime super-resolution session on the test loaders.

    For every loader in ``loader.loader_test`` and every evaluated scale,
    each low-resolution batch is upscaled with ``tiling_inference``,
    quantized with ``utility.quantize(sr, 255)`` and scored against the
    high-resolution reference with ``metric.calc_psnr`` and
    ``metric.calc_ssim``. The per-batch scores are averaged over the
    batches actually processed, then printed.

    Args:
        session: ONNX Runtime inference session handed to ``tiling_inference``.
        loader: object exposing ``loader_test``, an iterable of data loaders
            yielding ``(lr, hr, filename)`` batches.
        device: torch device the HR reference and the SR output are moved to.

    Returns:
        ``(mean_psnr, mean_ssim)`` for the last evaluated loader/scale, or
        ``(nan, nan)`` if nothing was evaluated.
    """
    scales = [2]
    # Defined up front so an empty ``loader_test`` cannot leave these unbound.
    mean_psnr = mean_ssim = float("nan")
    # Scope gradient disabling to this evaluation instead of flipping the
    # process-wide autograd switch permanently.
    with torch.no_grad():
        for d in loader.loader_test:
            for idx_scale, scale in enumerate(scales):
                d.dataset.set_scale(idx_scale)
                # Reset per scale so scores from different scales are never
                # mixed into one average.
                eval_psnr = 0
                eval_ssim = 0
                num_batches = 0
                for lr, hr, _filename in tqdm(d, ncols=80):
                    lr, hr = prepare(lr, hr, device)
                    # tiling_inference consumes a NumPy array; 8 and (56, 56)
                    # are its tiling parameters (meaning not visible here --
                    # TODO confirm against data.data_tiling).
                    sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                    sr = torch.from_numpy(sr).to(device)
                    sr = utility.quantize(sr, 255)
                    eval_psnr += metric.calc_psnr(sr, hr, scale, 255, benchmark=d)
                    eval_ssim += metric.calc_ssim(sr, hr, scale, 255, dataset=d)
                    num_batches += 1
                if num_batches == 0:
                    # Nothing to average; avoid a ZeroDivisionError.
                    continue
                mean_psnr = eval_psnr / num_batches
                mean_ssim = eval_ssim / num_batches
                print("psnr: %s, ssim: %s" % (mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim
def main():
    """Build the test data, open the ONNX model and run the evaluation.

    On IPU (``args.ipu``) the VitisAI execution provider is configured from
    ``args.provider_config``. Otherwise CUDA and CPU providers are requested,
    in that priority order. Evaluation tensors stay on the CPU.
    """
    loader = data.Data(args)
    providers, provider_options = (
        (["VitisAIExecutionProvider"], [{"config_file": args.provider_config}])
        if args.ipu
        else (['CUDAExecutionProvider', 'CPUExecutionProvider'], None)
    )
    ort_session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )
    test_model(ort_session, loader, device="cpu")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()