change onnx to NHWC
#1
by
wensongc
- opened
- FPN_int.onnx → FPN_int_NHWC.onnx +2 -2
- infer_onnx.py +3 -3
- test_onnx.py +3 -3
FPN_int.onnx → FPN_int_NHWC.onnx
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1cca42fe8ac3429c55aa1a28a3ecdf08cd279b6193e6487d9b18fa28cfc2c7e
|
3 |
+
size 45595686
|
infer_onnx.py
CHANGED
@@ -19,7 +19,7 @@ from datasets.utils import colorize_mask, build_img
|
|
19 |
|
20 |
if __name__ == "__main__":
|
21 |
parser = argparse.ArgumentParser(description='SemanticFPN model')
|
22 |
-
parser.add_argument('--onnx_path', type=str, default='
|
23 |
parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
|
24 |
parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
|
25 |
parser.add_argument('--ipu', action='store_true', help='use ipu')
|
@@ -37,8 +37,8 @@ if __name__ == "__main__":
|
|
37 |
onnx_path = args.onnx_path
|
38 |
input_img = build_img(args)
|
39 |
session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
|
40 |
-
ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
|
41 |
-
ort_output = session.run(None, ort_input)[0]
|
42 |
if isinstance(ort_output, (tuple, list)):
|
43 |
ort_output = ort_output[0]
|
44 |
|
|
|
19 |
|
20 |
if __name__ == "__main__":
|
21 |
parser = argparse.ArgumentParser(description='SemanticFPN model')
|
22 |
+
parser.add_argument('--onnx_path', type=str, default='FPN_int_NHWC.onnx')
|
23 |
parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
|
24 |
parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
|
25 |
parser.add_argument('--ipu', action='store_true', help='use ipu')
|
|
|
37 |
onnx_path = args.onnx_path
|
38 |
input_img = build_img(args)
|
39 |
session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
|
40 |
+
ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy().transpose(0,2,3,1)}
|
41 |
+
ort_output = session.run(None, ort_input)[0].transpose(0,3,1,2)
|
42 |
if isinstance(ort_output, (tuple, list)):
|
43 |
ort_output = ort_output[0]
|
44 |
|
test_onnx.py
CHANGED
@@ -22,7 +22,7 @@ class Configs():
|
|
22 |
# dataset
|
23 |
|
24 |
parser.add_argument('--dataset', type=str, default='citys', help='dataset name (default: citys)')
|
25 |
-
parser.add_argument('--onnx_path', type=str, default='
|
26 |
parser.add_argument('--num-classes', type=int, default=19,
|
27 |
help='the classes numbers (default: 19 for cityscapes)')
|
28 |
parser.add_argument('--test-folder', type=str, default='./data/cityscapes',
|
@@ -78,8 +78,8 @@ def eval_miou(data,path="FPN_int.onnx", device='cpu'):
|
|
78 |
|
79 |
for i, (image, target) in enumerate(tbar):
|
80 |
image, target = image.to(device), target.to(device)
|
81 |
-
ort_input = {session.get_inputs()[0].name: image.cpu().numpy()}
|
82 |
-
ort_output = session.run(None, ort_input)[0]
|
83 |
if isinstance(ort_output, (tuple, list)):
|
84 |
ort_output = ort_output[0]
|
85 |
ort_output = torch.from_numpy(ort_output).to(device)
|
|
|
22 |
# dataset
|
23 |
|
24 |
parser.add_argument('--dataset', type=str, default='citys', help='dataset name (default: citys)')
|
25 |
+
parser.add_argument('--onnx_path', type=str, default='FPN_int_NHWC.onnx', help='onnx path')
|
26 |
parser.add_argument('--num-classes', type=int, default=19,
|
27 |
help='the classes numbers (default: 19 for cityscapes)')
|
28 |
parser.add_argument('--test-folder', type=str, default='./data/cityscapes',
|
|
|
78 |
|
79 |
for i, (image, target) in enumerate(tbar):
|
80 |
image, target = image.to(device), target.to(device)
|
81 |
+
ort_input = {session.get_inputs()[0].name: image.cpu().numpy().transpose(0,2,3,1)}
|
82 |
+
ort_output = session.run(None, ort_input)[0].transpose(0,3,1,2)
|
83 |
if isinstance(ort_output, (tuple, list)):
|
84 |
ort_output = ort_output[0]
|
85 |
ort_output = torch.from_numpy(ort_output).to(device)
|