amd
/

ONNX
PyTorch
English
RyzenAI
super resolution
SISR
File size: 1,890 Bytes
2071132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import sys
import pathlib
# Directory that contains this script.
CURRENT_DIR = pathlib.Path(__file__).parent
# Append that directory to the module search path. Presumably this lets the
# project-local imports below (utility, data, option, metric) resolve when the
# script is launched from a different working directory -- TODO confirm.
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import utility
import data
from option import args
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference


def prepare(a, b, device):
    """Move ``a`` and ``b`` to ``device`` and return them as a pair.

    Each argument only needs a ``.to(device)`` method; the results of those
    calls are returned in the same order as the arguments.
    """
    moved_a = a.to(device)
    moved_b = b.to(device)
    return moved_a, moved_b


def test_model(session, loader, device):
    """Run the ONNX session over every test set and report PSNR / SSIM.

    Args:
        session: Inference session handed unchanged to ``tiling_inference``.
        loader: Object exposing ``loader_test``, an iterable of test loaders.
            Each loader must yield ``(lr, hr, filename)`` triples, support
            ``len()`` and expose ``dataset.set_scale``.
        device: Device the input tensors and the reconstructed output are
            moved to before the metrics are computed.

    Returns:
        ``(mean_psnr, mean_ssim)`` of the last test set evaluated.

    Raises:
        ValueError: If ``loader.loader_test`` yields no test set, so there is
            nothing to report.

    Note:
        Gradient tracking is switched off globally with
        ``torch.set_grad_enabled(False)`` and is not restored afterwards.
    """
    torch.set_grad_enabled(False)
    scales = [2]
    mean_psnr = mean_ssim = None
    for d in loader.loader_test:
        for idx_scale, scale in enumerate(scales):
            # Start from zero for every scale so that the sums of one scale
            # never leak into the mean of the next one.
            eval_psnr = 0
            eval_ssim = 0
            d.dataset.set_scale(idx_scale)
            for lr, hr, _ in tqdm(d, ncols=80):
                lr, hr = prepare(lr, hr, device)
                # tiling_inference is fed a NumPy array and its result is
                # converted back with torch.from_numpy. The arguments 8 and
                # (56, 56) are presumably tiling parameters -- TODO confirm
                # against data.data_tiling.
                sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                sr = torch.from_numpy(sr).to(device)
                sr = utility.quantize(sr, 255)
                eval_psnr += metric.calc_psnr(
                    sr, hr, scale, 255, benchmark=d)
                eval_ssim += metric.calc_ssim(
                    sr, hr, scale, 255, dataset=d)
            # Averages are taken over len(d), the number of items the loader
            # reports, as in the original implementation.
            mean_ssim = eval_ssim / len(d)
            mean_psnr = eval_psnr / len(d)
            print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
    if mean_psnr is None:
        # Previously this path failed with an UnboundLocalError at the return.
        raise ValueError("loader.loader_test is empty; nothing to evaluate")
    return mean_psnr, mean_ssim

def main():
    """Build the data loader and ONNX Runtime session, then run the evaluation.

    When ``args.ipu`` is truthy the session is created with the
    ``VitisAIExecutionProvider`` and ``args.provider_config`` as its config
    file. Otherwise the CUDA and CPU execution providers are requested, in
    that order, with no provider options. The model path comes from
    ``args.onnx_path`` and evaluation always runs with ``device="cpu"``.
    """
    loader = data.Data(args)
    session_kwargs = (
        {
            "providers": ["VitisAIExecutionProvider"],
            "provider_options": [{"config_file": args.provider_config}],
        }
        if args.ipu
        else {
            "providers": ['CUDAExecutionProvider', 'CPUExecutionProvider'],
            "provider_options": None,
        }
    )
    ort_session = onnxruntime.InferenceSession(args.onnx_path, **session_kwargs)
    test_model(ort_session, loader, device="cpu")


# Run the evaluation only when this file is executed as a script.
if "__main__" == __name__:
    main()