amd
/

Image Classification
ONNX
RyzenAI
hangyang-amd commited on
Commit
1332ee4
1 Parent(s): 0b5f4ac

Update infer_onnx.py

Browse files
Files changed (1) hide show
  1. infer_onnx.py +3 -1
infer_onnx.py CHANGED
@@ -36,7 +36,7 @@ parser.add_argument(
36
  default="vaip_config.json",
37
  help="Path of the config file for seting provider_options.",
38
  )
39
-
40
  args = parser.parse_args()
41
 
42
 
@@ -51,6 +51,8 @@ def read_image():
51
  normalize,
52
  ])
53
  img_tensor = transform(image).unsqueeze(0)
 
 
54
  return img_tensor.numpy()
55
 
56
 
 
36
  default="vaip_config.json",
37
  help="Path of the config file for seting provider_options.",
38
  )
39
+ parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
40
  args = parser.parse_args()
41
 
42
 
 
51
  normalize,
52
  ])
53
  img_tensor = transform(image).unsqueeze(0)
54
+ if args.data_format == "nhwc":
55
+ img_tensor = transform(image).unsqueeze(0).permute((0, 2, 3, 1))
56
  return img_tensor.numpy()
57
 
58