yixionghuo commited on
Commit
675b749
1 Parent(s): 32865f3

Update onnx_test.py

Browse files
Files changed (1) hide show
  1. onnx_test.py +5 -1
onnx_test.py CHANGED
@@ -962,8 +962,12 @@ def test(data,
962
  whwh = torch.Tensor([width, height, width, height]).to(device)
963
 
964
  if onnx_runtime:
 
 
965
  outputs = onnx_model.run(
966
- None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
 
 
967
  outputs = [torch.tensor(item).to(device) for item in outputs]
968
  inf_out, train_out = post_process(outputs)
969
 
 
962
  whwh = torch.Tensor([width, height, width, height]).to(device)
963
 
964
  if onnx_runtime:
965
+ # outputs = onnx_model.run(
966
+ # None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
967
  outputs = onnx_model.run(
968
+ None, {onnx_model.get_inputs()[0].name: np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))})
969
+ outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
970
+
971
  outputs = [torch.tensor(item).to(device) for item in outputs]
972
  inf_out, train_out = post_process(outputs)
973