# yolov3/onnx_inference.py
import argparse
import os
import random

import cv2
import numpy as np
import onnxruntime
import torch

from utils import *  # letterbox, non_max_suppression, scale_coords, plot_one_box


def pre_process(img):
    """
    YOLOv3 preprocessing: letterbox-resize and pad the image, then convert
    it to the network input layout.
    Args:
        img (numpy.ndarray): H x W x C image read with OpenCV (BGR).
    Returns:
        padded_img (numpy.ndarray): 1 x C x H x W float32 array in [0, 1].
    """
    img = letterbox(img, auto=False)[0]  # resize + pad to the input size
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    img = img.astype("float32") / 255.0  # normalize to [0, 1]
    img = img[np.newaxis, :]  # add batch dimension -> 1 x C x H x W
    return img
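

# Example shapes (a sketch; the target size comes from utils.letterbox
# defaults, commonly 416x416 for this YOLOv3 ONNX model):
#   img0 = cv2.imread("demo.jpg")   # e.g. (480, 640, 3) uint8, BGR
#   inp = pre_process(img0)         # (1, 3, 416, 416) float32 in [0, 1]
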
def post_process(x, conf_thres=0.1, iou_thres=0.6, multi_label=True,
                 classes=None, agnostic=False):
    """
    YOLOv3 post-processing: decode the three raw network outputs and run NMS.
    Returns:
        pred (list[torch.Tensor]): one n x 6 tensor per image, where
            dets[:, :4] -> boxes (xyxy), dets[:, 4] -> scores,
            dets[:, 5] -> class indices; entries may be None.
    """
    stride = [32, 16, 8]  # downsampling factor of each detection head
    anchors = [[10, 13, 16, 30, 33, 23],
               [30, 61, 62, 45, 59, 119],
               [116, 90, 156, 198, 373, 326]]  # COCO anchors, small -> large
    res = []
    def create_grids(ng=(13, 13)):
        nx, ny = ng  # x and y grid size
        # build xy offsets; indexing="ij" matches the legacy meshgrid
        # behaviour (requires torch >= 1.10)
        yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx),
                                indexing="ij")
        grid = torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
        return grid
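
    # For ng=(2, 2), create_grids returns each cell's top-left offset:
    #   grid[0, 0] == [[[0., 0.], [1., 0.]],
    #                  [[0., 1.], [1., 1.]]]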
    for i in range(3):
        out = torch.from_numpy(x[i])
        bs, _, ny, nx = out.shape  # e.g. bs, 255, 13, 13 for the stride-32 head
        # heads arrive largest-stride first, so index anchors from the end
        anchor = torch.Tensor(anchors[2 - i]).reshape(3, 2)
        anchor_vec = anchor / stride[i]  # anchor sizes in grid units
        anchor_wh = anchor_vec.view(1, 3, 1, 1, 2)
        grid = create_grids((nx, ny))
        # 85 = 4 box coords + 1 objectness + 80 COCO class scores
        out = out.view(bs, 3, 85, ny, nx).permute(
            0, 1, 3, 4, 2).contiguous()  # bs x 3 x ny x nx x 85
        io = out.clone()
        # decode: xy = sigmoid(txy) + grid, wh = exp(twh) * anchors,
        # then scale boxes from grid units back to input pixels
        io[..., :2] = torch.sigmoid(io[..., :2]) + grid
        io[..., 2:4] = torch.exp(io[..., 2:4]) * anchor_wh
        io[..., :4] *= stride[i]
        torch.sigmoid_(io[..., 4:])  # objectness and class scores
        res.append(io.view(bs, -1, 85))
pred = non_max_suppression(torch.cat(res, 1), conf_thres,
iou_thres, multi_label=multi_label,
classes=classes, agnostic=agnostic)
return pred
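

# Minimal sketch of consuming post_process output ("outs" is a hypothetical
# list of the three raw network outputs, each in NCHW layout):
#   dets = post_process(outs)[0]   # None, or an (n, 6) tensor for the image
#   boxes, scores, cls_ids = dets[:, :4], dets[:, 4], dets[:, 5]
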
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Single-image inference with a YOLOv3 ONNX model')
parser.add_argument(
'--img',
type=str,
help='Path of input image')
    parser.add_argument(
        '--out',
        type=str,
        default='.',
        help='Directory for the output image')
parser.add_argument(
"--ipu",
action="store_true",
help="Use IPU for inference.")
    parser.add_argument(
        "--provider_config",
        type=str,
        default="vaip_config.json",
        help="Path of the config file for setting provider_options.")
parser.add_argument(
"--onnx_path",
type=str,
default="yolov3-8.onnx",
help="Path of the onnx model.")
opt = parser.parse_args()
    with open('coco.names', 'r') as f:
        names = f.read().strip().splitlines()  # one class name per line
if opt.ipu:
providers = ["VitisAIExecutionProvider"]
provider_options = [{"config_file": opt.provider_config}]
else:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
provider_options = None
onnx_path = opt.onnx_path
onnx_model = onnxruntime.InferenceSession(
onnx_path, providers=providers, provider_options=provider_options)
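    # Note: if a requested provider is unavailable in this onnxruntime build,
    # the session falls back to the next one in the list;
    # onnxruntime.get_available_providers() shows what is supported.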
path = opt.img
new_path = os.path.join(opt.out, "demo_infer.jpg")
    conf_thres, iou_thres, classes, agnostic_nms = 0.25, 0.45, None, False
img0 = cv2.imread(path)
img = pre_process(img0)
    # This ONNX export expects NHWC input, so transpose from NCHW;
    # for an NCHW model, feed `img` directly instead.
    onnx_input = {onnx_model.get_inputs()[0].name: np.transpose(img, (0, 2, 3, 1))}
    onnx_output = onnx_model.run(None, onnx_input)
    # the outputs are NHWC as well; convert back to NCHW for post_process
    onnx_output = [np.transpose(out, (0, 3, 1, 2)) for out in onnx_output]
pred = post_process(onnx_output, conf_thres,
iou_thres, multi_label=False,
classes=classes, agnostic=agnostic_nms)
colors = [[random.randint(0, 255) for _ in range(3)]
for _ in range(len(names))]
det = pred[0]
im0 = img0.copy()
if det is None:
print('No objects detected!')
elif len(det):
# Rescale boxes from imgsz to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
# Write results
for *xyxy, conf, cls in reversed(det):
label = '%s %.2f' % (names[int(cls)], conf)
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
    # Save the annotated image
    cv2.imwrite(new_path, im0)
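
# Example invocation (paths are illustrative):
#   python onnx_inference.py --img demo.jpg --out . --onnx_path yolov3-8.onnx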