Update code and model to support NHWC input format
Browse files- PAN_int8.onnx +2 -2
- data/benchmark.py +2 -2
- data/common.py +6 -5
- data/data_tiling.py +6 -5
- data/srdata.py +4 -2
- eval_onnx.py +4 -0
- infer_onnx.py +2 -2
- utility.py +3 -0
PAN_int8.onnx
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:138287c52ea8f1a58857397964bf06a77d2e2314d4824796ed560c0dc245990d
|
3 |
+
size 1263653
|
data/benchmark.py
CHANGED
@@ -3,9 +3,9 @@ import os
|
|
3 |
from data import srdata
|
4 |
|
5 |
class Benchmark(srdata.SRData):
|
6 |
-
def __init__(self, args, name='', benchmark=True):
|
7 |
super(Benchmark, self).__init__(
|
8 |
-
args, name=name, benchmark=True
|
9 |
)
|
10 |
|
11 |
def _set_filesystem(self, dir_data):
|
|
|
3 |
from data import srdata
|
4 |
|
5 |
class Benchmark(srdata.SRData):
|
6 |
+
def __init__(self, args, name='', benchmark=True, input_data_format ='NHWC'):
|
7 |
super(Benchmark, self).__init__(
|
8 |
+
args, name=name, benchmark=True, input_data_format=input_data_format
|
9 |
)
|
10 |
|
11 |
def _set_filesystem(self, dir_data):
|
data/common.py
CHANGED
@@ -20,12 +20,13 @@ def set_channel(*args, n_channels=3):
|
|
20 |
|
21 |
return [_set_channel(a) for a in args]
|
22 |
|
23 |
-
def np2Tensor(*args, rgb_range=255):
|
24 |
-
def _np2Tensor(img):
|
25 |
-
|
26 |
-
|
|
|
27 |
tensor.mul_(rgb_range / 255)
|
28 |
|
29 |
return tensor
|
30 |
|
31 |
-
return [_np2Tensor(a) for a in args]
|
|
|
20 |
|
21 |
return [_set_channel(a) for a in args]
|
22 |
|
23 |
+
def np2Tensor(*args, rgb_range=255, format='NCHW'):
|
24 |
+
def _np2Tensor(img, channel_format):
|
25 |
+
assert channel_format in ('NCHW', 'NHWC')
|
26 |
+
img = np.ascontiguousarray(img.transpose((2, 0, 1))) if channel_format == ('NCHW') else img
|
27 |
+
tensor = torch.from_numpy(img).float()
|
28 |
tensor.mul_(rgb_range / 255)
|
29 |
|
30 |
return tensor
|
31 |
|
32 |
+
return [_np2Tensor(a, format) for a in args]
|
data/data_tiling.py
CHANGED
@@ -11,8 +11,8 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
|
|
11 |
- patch_size: a tuple of (height, width) that specifies the size of each patch
|
12 |
Returns: - a numpy array that represents the enhanced image
|
13 |
"""
|
14 |
-
_,
|
15 |
-
sr = np.zeros((1,
|
16 |
n_h = math.ceil(h / float(patch_size[0] - overlapping))
|
17 |
n_w = math.ceil(w / float(patch_size[1] - overlapping))
|
18 |
#every tilling input has same size of patch_size
|
@@ -23,8 +23,9 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
|
|
23 |
w_idx = iw * (patch_size[1] - overlapping)
|
24 |
w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
|
25 |
|
26 |
-
|
27 |
-
|
|
|
28 |
|
29 |
left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
|
30 |
left += overlapping//2
|
@@ -42,5 +43,5 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
|
|
42 |
right += overlapping//2
|
43 |
|
44 |
#get preditions
|
45 |
-
sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
|
46 |
return sr
|
|
|
11 |
- patch_size: a tuple of (height, width) that specifies the size of each patch
|
12 |
Returns: - a numpy array that represents the enhanced image
|
13 |
"""
|
14 |
+
_, h, w, _ = lr.shape
|
15 |
+
sr = np.zeros((1, 2*h, 2*w, 3))
|
16 |
n_h = math.ceil(h / float(patch_size[0] - overlapping))
|
17 |
n_w = math.ceil(w / float(patch_size[1] - overlapping))
|
18 |
#every tilling input has same size of patch_size
|
|
|
23 |
w_idx = iw * (patch_size[1] - overlapping)
|
24 |
w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
|
25 |
|
26 |
+
tiling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1], :]
|
27 |
+
# import pdb; pdb.set_trace()
|
28 |
+
sr_tiling = session.run(None, {session.get_inputs()[0].name: tiling_lr})[0]
|
29 |
|
30 |
left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
|
31 |
left += overlapping//2
|
|
|
43 |
right += overlapping//2
|
44 |
|
45 |
#get preditions
|
46 |
+
sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right), :] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right, :]
|
47 |
return sr
|
data/srdata.py
CHANGED
@@ -9,13 +9,15 @@ import imageio
|
|
9 |
import torch.utils.data as data
|
10 |
|
11 |
class SRData(data.Dataset):
|
12 |
-
def __init__(self, args, name='', benchmark=True):
|
13 |
self.args = args
|
14 |
self.name = name
|
15 |
self.benchmark = benchmark
|
16 |
self.input_large = False
|
17 |
self.scale = args.scale
|
18 |
self.idx_scale = 0
|
|
|
|
|
19 |
|
20 |
self._set_filesystem(args.dir_data)
|
21 |
if args.ext.find('img') < 0:
|
@@ -87,7 +89,7 @@ class SRData(data.Dataset):
|
|
87 |
lr, hr, filename = self._load_file(idx)
|
88 |
pair = self.get_patch(lr, hr)
|
89 |
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
90 |
-
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
91 |
|
92 |
return pair_t[0], pair_t[1], filename
|
93 |
|
|
|
9 |
import torch.utils.data as data
|
10 |
|
11 |
class SRData(data.Dataset):
|
12 |
+
def __init__(self, args, name='', benchmark=True, input_data_format='NCHW'):
|
13 |
self.args = args
|
14 |
self.name = name
|
15 |
self.benchmark = benchmark
|
16 |
self.input_large = False
|
17 |
self.scale = args.scale
|
18 |
self.idx_scale = 0
|
19 |
+
assert input_data_format in ('NCHW', 'NHWC')
|
20 |
+
self.input_data_format = input_data_format
|
21 |
|
22 |
self._set_filesystem(args.dir_data)
|
23 |
if args.ext.find('img') < 0:
|
|
|
89 |
lr, hr, filename = self._load_file(idx)
|
90 |
pair = self.get_patch(lr, hr)
|
91 |
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
92 |
+
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range, format=self.input_data_format)
|
93 |
|
94 |
return pair_t[0], pair_t[1], filename
|
95 |
|
eval_onnx.py
CHANGED
@@ -26,6 +26,10 @@ def test_model(session, loader):
|
|
26 |
sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
|
27 |
sr = torch.from_numpy(sr)
|
28 |
sr = utility.quantize(sr, 255)
|
|
|
|
|
|
|
|
|
29 |
eval_psnr += utility.calc_psnr(
|
30 |
sr, hr, scale, 255, benchmark=d)
|
31 |
eval_ssim += utility.calc_ssim(
|
|
|
26 |
sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
|
27 |
sr = torch.from_numpy(sr)
|
28 |
sr = utility.quantize(sr, 255)
|
29 |
+
|
30 |
+
# Transform from NHWC to NCHW to calculate metric
|
31 |
+
sr = sr.permute((0, 3, 1, 2))
|
32 |
+
hr = hr.permute((0, 3, 1, 2))
|
33 |
eval_psnr += utility.calc_psnr(
|
34 |
sr, hr, scale, 255, benchmark=d)
|
35 |
eval_ssim += utility.calc_ssim(
|
infer_onnx.py
CHANGED
@@ -22,12 +22,12 @@ def main(args):
|
|
22 |
providers = ['CPUExecutionProvider']
|
23 |
provider_options = None
|
24 |
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
|
25 |
-
lr = cv2.imread(image_path)[np.newaxis,:,:,:].
|
26 |
|
27 |
# Tiled inference
|
28 |
sr = tiling_inference(ort_session, lr, 8, (56, 56))
|
29 |
sr = np.clip(sr, 0, 255)
|
30 |
-
sr = sr.squeeze().
|
31 |
cv2.imwrite(output_path, sr)
|
32 |
|
33 |
|
|
|
22 |
providers = ['CPUExecutionProvider']
|
23 |
provider_options = None
|
24 |
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
|
25 |
+
lr = cv2.imread(image_path)[np.newaxis,:,:,:].astype(np.float32)
|
26 |
|
27 |
# Tiled inference
|
28 |
sr = tiling_inference(ort_session, lr, 8, (56, 56))
|
29 |
sr = np.clip(sr, 0, 255)
|
30 |
+
sr = sr.squeeze().astype(np.uint8)
|
31 |
cv2.imwrite(output_path, sr)
|
32 |
|
33 |
|
utility.py
CHANGED
@@ -8,6 +8,9 @@ def quantize(img, rgb_range):
|
|
8 |
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
|
9 |
|
10 |
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
|
|
|
|
|
|
|
11 |
if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
|
12 |
print("the dimention of sr image is not equal to hr's! ")
|
13 |
sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
|
|
|
8 |
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
|
9 |
|
10 |
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
|
11 |
+
if sr.size(-1) == 3 and sr.size(1) > 3:
|
12 |
+
sr = sr.transpose((0, 3, 1, 2))
|
13 |
+
hr = hr.transpose((0, 3, 1, 2))
|
14 |
if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
|
15 |
print("the dimention of sr image is not equal to hr's! ")
|
16 |
sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
|