jiliu1 committed
Commit d124dee
Parent: 0c41ef1

Upload 3 files


Update code and model for NHWC format usage.
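For reference, the NCHW-to-NHWC conversion (and its inverse) that this commit applies is a pure axis permutation. A minimal NumPy sketch of the round trip the updated scripts perform (all shapes here are illustrative only):

    import numpy as np

    x_nchw = np.zeros((1, 3, 512, 1024), dtype=np.float32)  # example batch

    # input side: NCHW -> NHWC before ort_session.run()
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
    assert x_nhwc.shape == (1, 512, 1024, 3)

    # output side: NHWC -> NCHW before F.interpolate()
    x_back = x_nhwc.transpose(0, 3, 1, 2)
    assert x_back.shape == x_nchw.shape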

HighResolutionNet_int.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1daf04ac2d732c753c4bd79759d6942d62b55acb4cedc30e22ecbfa7647a5c22
+size 263764409
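The added file is only a Git LFS pointer; the roughly 264 MB model itself lives in LFS storage, and the oid field is the SHA-256 of the real file. A hedged sketch for checking a downloaded copy against the values recorded above (the local path is an assumption):

    import hashlib
    import os

    path = "HighResolutionNet_int.onnx"  # assumed: fetched via git lfs pull

    assert os.path.getsize(path) == 263764409  # "size" from the pointer

    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == (
        "1daf04ac2d732c753c4bd79759d6942d62b55acb4cedc30e22ecbfa7647a5c22"
    )  # "oid" from the pointer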
hrnet_quantized_onnx_eval.py CHANGED
@@ -127,7 +127,11 @@ def run_onnx_inference(ort_session, img):
         ndarray: Model inference result.
     """
     pre_img, pad_h, pad_w = preprocess(img)
+    # transform the CHW input into HWC format
+
     img = np.expand_dims(pre_img, 0)
+    img = np.transpose(img, (0, 2, 3, 1))
+
     ort_inputs = {ort_session.get_inputs()[0].name: img}
     o1 = ort_session.run(None, ort_inputs)[0]
     h, w = o1.shape[-2:]
@@ -160,6 +164,8 @@ def testval(ort_session, root, list_path):
         image = image.numpy()[0]
         out = run_onnx_inference(ort_session, image)
         size = label.size()
+        # transpose the HWC output back to CHW
+        out = out.transpose(0, 3, 1, 2)
         if out.shape[2] != size[1] or out.shape[3] != size[2]:
             out = torch.from_numpy(out).cpu()
             pred = F.interpolate(
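Taken together, the two hunks make the eval path hand the session NHWC input and convert the NHWC output back to NCHW before the existing shape check. A condensed sketch of the resulting flow (the run_nhwc wrapper is mine for illustration; preprocess() and the session setup come from the repo's scripts):

    import numpy as np
    import onnxruntime as ort

    def run_nhwc(ort_session, pre_img):
        # pre_img: preprocessed CHW float array from preprocess() (per the repo)
        img = np.expand_dims(pre_img, 0)        # 1 x C x H x W
        img = np.transpose(img, (0, 2, 3, 1))   # 1 x H x W x C, as the NHWC model expects
        ort_inputs = {ort_session.get_inputs()[0].name: img}
        out = ort_session.run(None, ort_inputs)[0]  # NHWC output
        return out.transpose(0, 3, 1, 2)        # back to NCHW for the F.interpolate path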
hrnet_quantized_onnx_inference.py CHANGED
@@ -37,6 +37,7 @@ def run_onnx_inference(ort_session, img):
     """
     pre_img, pad_h, pad_w = preprocess(img)
     img = np.expand_dims(pre_img, 0)
+    img = np.transpose(img, (0, 2, 3, 1))
     ort_inputs = {ort_session.get_inputs()[0].name: img}
     o1 = ort_session.run(None, ort_inputs)[0]
     h, w = o1.shape[-2:]
@@ -52,6 +53,7 @@ def vis(out, image, save_path='color_.png'):
         220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
         0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32 ]
     # out = out[0]
+    out = out.transpose(0, 3, 1, 2)
     if out.shape[2] != image.shape[0] or out.shape[3] != image.shape[1]:
         out = torch.from_numpy(out).cpu()
         out = F.interpolate(
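In vis() here (and likewise in testval() in the eval script), the shape check and F.interpolate read out.shape[2] and out.shape[3] as height and width, which is only true in NCHW, so the transpose must come before that check. A self-contained sketch of why (NumPy + PyTorch; shapes and the class count are illustrative):

    import numpy as np
    import torch
    import torch.nn.functional as F

    # Illustrative NHWC model output: 1 x H x W x num_classes
    out = np.random.rand(1, 64, 128, 19).astype(np.float32)

    # Without the transpose, out.shape[2] / out.shape[3] would read W and C,
    # so the resize check would compare the wrong axes.
    out = out.transpose(0, 3, 1, 2)  # -> 1 x 19 x 64 x 128 (NCHW)

    # F.interpolate expects NCHW input for 2-D bilinear resizing
    resized = F.interpolate(torch.from_numpy(out), size=(512, 1024),
                            mode='bilinear', align_corners=False)
    assert resized.shape == (1, 19, 512, 1024)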