amd
/

timm
ONNX
PyTorch
RyzenAI
vision
classification
Files changed (2) hide show
  1. eval_onnx.py +1 -1
  2. mnasnet_b1_int.onnx +2 -2
eval_onnx.py CHANGED
@@ -144,7 +144,7 @@ def val_imagenet():
144
  val_loader = tqdm(val_loader, file=sys.stdout)
145
  with torch.no_grad():
146
  for batch_idx, (images, targets) in enumerate(val_loader):
147
- inputs, targets = images.numpy(), targets
148
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
149
 
150
  outputs = ort_session.run(None, ort_inputs)
 
144
  val_loader = tqdm(val_loader, file=sys.stdout)
145
  with torch.no_grad():
146
  for batch_idx, (images, targets) in enumerate(val_loader):
147
+ inputs, targets = images.permute([0,2,3,1]).numpy(), targets
148
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
149
 
150
  outputs = ort_session.run(None, ort_inputs)
mnasnet_b1_int.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4784fa94c18ad2cb8c91081663cbdc825ba7bc61d1d913abd50d0fe0ff84949
3
- size 17571696
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:573e80009e1d0c21858d507ee6702ea73d092f9bf8be09ac7d901f0259e4bc1f
3
+ size 17571818