zhengrongzhang
commited on
Commit
•
30c4d88
1
Parent(s):
da9195c
change onnx to NHWC (#1)
Browse files- change onnx to NHWC (9f86314792bc63c7d4327617e9dd76b69ed0c1b6)
- RCAN_int8.onnx → RCAN_int8_NHWC.onnx +2 -2
- data/data_tiling.py +1 -1
- eval_onnx.py +1 -1
- infer_onnx.py +1 -1
RCAN_int8.onnx → RCAN_int8_NHWC.onnx
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a1cad5da6396a4c812bb5b0d60c1470a1e2de1b9b4e8fe58ce4132972b164f7
|
3 |
+
size 445692
|
data/data_tiling.py
CHANGED
@@ -18,7 +18,7 @@ def tiling_inference(session, lr, overlapping, patch_size):
|
|
18 |
w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
|
19 |
|
20 |
tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
|
21 |
-
sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
|
22 |
|
23 |
left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
|
24 |
left += overlapping//2
|
|
|
18 |
w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
|
19 |
|
20 |
tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
|
21 |
+
sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr.transpose(0,2,3,1)})[0].transpose(0,3,1,2)
|
22 |
|
23 |
left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
|
24 |
left += overlapping//2
|
eval_onnx.py
CHANGED
@@ -19,7 +19,7 @@ class Configs():
|
|
19 |
# ipu test or cpu, you need to provide onnx path
|
20 |
parser.add_argument('--ipu', action='store_true',
|
21 |
help='use ipu')
|
22 |
-
parser.add_argument('--onnx_path', type=str, default='
|
23 |
help='onnx path')
|
24 |
parser.add_argument('--provider_config', type=str, default=None,
|
25 |
help='provider config path')
|
|
|
19 |
# ipu test or cpu, you need to provide onnx path
|
20 |
parser.add_argument('--ipu', action='store_true',
|
21 |
help='use ipu')
|
22 |
+
parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
|
23 |
help='onnx path')
|
24 |
parser.add_argument('--provider_config', type=str, default=None,
|
25 |
help='provider config path')
|
infer_onnx.py
CHANGED
@@ -31,7 +31,7 @@ def main(args):
|
|
31 |
|
32 |
if __name__ == '__main__':
|
33 |
parser = argparse.ArgumentParser(description='RCAN SISR')
|
34 |
-
parser.add_argument('--onnx_path', type=str, default='
|
35 |
help='onnx path')
|
36 |
parser.add_argument('--image_path', default='test_data/test.png',
|
37 |
help='path of your image')
|
|
|
31 |
|
32 |
if __name__ == '__main__':
|
33 |
parser = argparse.ArgumentParser(description='RCAN SISR')
|
34 |
+
parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
|
35 |
help='onnx path')
|
36 |
parser.add_argument('--image_path', default='test_data/test.png',
|
37 |
help='path of your image')
|