Tellll commited on
Commit
90e4acb
1 Parent(s): 3135a01

Update code and model to support NHWC input format

Browse files
PAN_int8.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c5b5e35f9eeaf54988685263e868a1c54cb075a0560d5228af5f423d123af3be
3
- size 1263469
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:138287c52ea8f1a58857397964bf06a77d2e2314d4824796ed560c0dc245990d
3
+ size 1263653
data/benchmark.py CHANGED
@@ -3,9 +3,9 @@ import os
3
  from data import srdata
4
 
5
  class Benchmark(srdata.SRData):
6
- def __init__(self, args, name='', benchmark=True):
7
  super(Benchmark, self).__init__(
8
- args, name=name, benchmark=True
9
  )
10
 
11
  def _set_filesystem(self, dir_data):
 
3
  from data import srdata
4
 
5
  class Benchmark(srdata.SRData):
6
+ def __init__(self, args, name='', benchmark=True, input_data_format ='NHWC'):
7
  super(Benchmark, self).__init__(
8
+ args, name=name, benchmark=True, input_data_format=input_data_format
9
  )
10
 
11
  def _set_filesystem(self, dir_data):
data/common.py CHANGED
@@ -20,12 +20,13 @@ def set_channel(*args, n_channels=3):
20
 
21
  return [_set_channel(a) for a in args]
22
 
23
- def np2Tensor(*args, rgb_range=255):
24
- def _np2Tensor(img):
25
- np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
26
- tensor = torch.from_numpy(np_transpose).float()
 
27
  tensor.mul_(rgb_range / 255)
28
 
29
  return tensor
30
 
31
- return [_np2Tensor(a) for a in args]
 
20
 
21
  return [_set_channel(a) for a in args]
22
 
23
+ def np2Tensor(*args, rgb_range=255, format='NCHW'):
24
+ def _np2Tensor(img, channel_format):
25
+ assert channel_format in ('NCHW', 'NHWC')
26
+ img = np.ascontiguousarray(img.transpose((2, 0, 1))) if channel_format == ('NCHW') else img
27
+ tensor = torch.from_numpy(img).float()
28
  tensor.mul_(rgb_range / 255)
29
 
30
  return tensor
31
 
32
+ return [_np2Tensor(a, format) for a in args]
data/data_tiling.py CHANGED
@@ -11,8 +11,8 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
11
  - patch_size: a tuple of (height, width) that specifies the size of each patch
12
  Returns: - a numpy array that represents the enhanced image
13
  """
14
- _, _, h, w = lr.shape
15
- sr = np.zeros((1, 3, 2*h, 2*w))
16
  n_h = math.ceil(h / float(patch_size[0] - overlapping))
17
  n_w = math.ceil(w / float(patch_size[1] - overlapping))
18
  #every tilling input has same size of patch_size
@@ -23,8 +23,9 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
23
  w_idx = iw * (patch_size[1] - overlapping)
24
  w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
25
 
26
- tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
27
- sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
 
28
 
29
  left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
30
  left += overlapping//2
@@ -42,5 +43,5 @@ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
42
  right += overlapping//2
43
 
44
  #get preditions
45
- sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
46
  return sr
 
11
  - patch_size: a tuple of (height, width) that specifies the size of each patch
12
  Returns: - a numpy array that represents the enhanced image
13
  """
14
+ _, h, w, _ = lr.shape
15
+ sr = np.zeros((1, 2*h, 2*w, 3))
16
  n_h = math.ceil(h / float(patch_size[0] - overlapping))
17
  n_w = math.ceil(w / float(patch_size[1] - overlapping))
18
  #every tilling input has same size of patch_size
 
23
  w_idx = iw * (patch_size[1] - overlapping)
24
  w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
25
 
26
+ tiling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1], :]
27
+ # import pdb; pdb.set_trace()
28
+ sr_tiling = session.run(None, {session.get_inputs()[0].name: tiling_lr})[0]
29
 
30
  left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
31
  left += overlapping//2
 
43
  right += overlapping//2
44
 
45
  #get preditions
46
+ sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right), :] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right, :]
47
  return sr
data/srdata.py CHANGED
@@ -9,13 +9,15 @@ import imageio
9
  import torch.utils.data as data
10
 
11
  class SRData(data.Dataset):
12
- def __init__(self, args, name='', benchmark=True):
13
  self.args = args
14
  self.name = name
15
  self.benchmark = benchmark
16
  self.input_large = False
17
  self.scale = args.scale
18
  self.idx_scale = 0
 
 
19
 
20
  self._set_filesystem(args.dir_data)
21
  if args.ext.find('img') < 0:
@@ -87,7 +89,7 @@ class SRData(data.Dataset):
87
  lr, hr, filename = self._load_file(idx)
88
  pair = self.get_patch(lr, hr)
89
  pair = common.set_channel(*pair, n_channels=self.args.n_colors)
90
- pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
91
 
92
  return pair_t[0], pair_t[1], filename
93
 
 
9
  import torch.utils.data as data
10
 
11
  class SRData(data.Dataset):
12
+ def __init__(self, args, name='', benchmark=True, input_data_format='NCHW'):
13
  self.args = args
14
  self.name = name
15
  self.benchmark = benchmark
16
  self.input_large = False
17
  self.scale = args.scale
18
  self.idx_scale = 0
19
+ assert input_data_format in ('NCHW', 'NHWC')
20
+ self.input_data_format = input_data_format
21
 
22
  self._set_filesystem(args.dir_data)
23
  if args.ext.find('img') < 0:
 
89
  lr, hr, filename = self._load_file(idx)
90
  pair = self.get_patch(lr, hr)
91
  pair = common.set_channel(*pair, n_channels=self.args.n_colors)
92
+ pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range, format=self.input_data_format)
93
 
94
  return pair_t[0], pair_t[1], filename
95
 
eval_onnx.py CHANGED
@@ -26,6 +26,10 @@ def test_model(session, loader):
26
  sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
27
  sr = torch.from_numpy(sr)
28
  sr = utility.quantize(sr, 255)
 
 
 
 
29
  eval_psnr += utility.calc_psnr(
30
  sr, hr, scale, 255, benchmark=d)
31
  eval_ssim += utility.calc_ssim(
 
26
  sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
27
  sr = torch.from_numpy(sr)
28
  sr = utility.quantize(sr, 255)
29
+
30
+ # Transform from NHWC to NCHW to calculate metric
31
+ sr = sr.permute((0, 3, 1, 2))
32
+ hr = hr.permute((0, 3, 1, 2))
33
  eval_psnr += utility.calc_psnr(
34
  sr, hr, scale, 255, benchmark=d)
35
  eval_ssim += utility.calc_ssim(
infer_onnx.py CHANGED
@@ -22,12 +22,12 @@ def main(args):
22
  providers = ['CPUExecutionProvider']
23
  provider_options = None
24
  ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
25
- lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
26
 
27
  # Tiled inference
28
  sr = tiling_inference(ort_session, lr, 8, (56, 56))
29
  sr = np.clip(sr, 0, 255)
30
- sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
31
  cv2.imwrite(output_path, sr)
32
 
33
 
 
22
  providers = ['CPUExecutionProvider']
23
  provider_options = None
24
  ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
25
+ lr = cv2.imread(image_path)[np.newaxis,:,:,:].astype(np.float32)
26
 
27
  # Tiled inference
28
  sr = tiling_inference(ort_session, lr, 8, (56, 56))
29
  sr = np.clip(sr, 0, 255)
30
+ sr = sr.squeeze().astype(np.uint8)
31
  cv2.imwrite(output_path, sr)
32
 
33
 
utility.py CHANGED
@@ -8,6 +8,9 @@ def quantize(img, rgb_range):
8
  return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
9
 
10
  def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
 
 
 
11
  if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
12
  print("the dimention of sr image is not equal to hr's! ")
13
  sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
 
8
  return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
9
 
10
  def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
11
+ if sr.size(-1) == 3 and sr.size(1) > 3:
12
+ sr = sr.transpose((0, 3, 1, 2))
13
+ hr = hr.transpose((0, 3, 1, 2))
14
  if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
15
  print("the dimention of sr image is not equal to hr's! ")
16
  sr = sr[:,:,:hr.size(-2),:hr.size(-1)]