File size: 1,884 Bytes
faac7d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
import onnxruntime
from PIL import Image
import argparse
from datasets.utils import colorize_mask, build_img


def _parse_args(argv=None):
    """Parse command-line options for the SemanticFPN ONNX demo.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``onnx_path``, ``save_path``, ``input_path``,
        ``ipu`` and ``provider_config`` attributes.
    """
    parser = argparse.ArgumentParser(description='SemanticFPN model')
    parser.add_argument('--onnx_path', type=str, default='FPN_int.onnx')
    # NOTE(review): "senmatic" is a typo in the default file name; kept as-is so
    # existing output locations do not change.
    parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
    parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
    parser.add_argument('--ipu', action='store_true', help='use ipu')
    parser.add_argument('--provider_config', type=str, default=None,
                        help='provider config path')
    return parser.parse_args(argv)


def _select_providers(use_ipu, provider_config=None):
    """Return ``(providers, provider_options)`` for ``onnxruntime.InferenceSession``.

    Args:
        use_ipu: Select the VitisAI execution provider instead of the CPU one.
        provider_config: Optional path passed to VitisAI as ``config_file``.

    Returns:
        A ``(providers, provider_options)`` pair; ``provider_options`` is
        ``None`` for the CPU provider.
    """
    if not use_ipu:
        return ['CPUExecutionProvider'], None
    options = {}
    # Only forward a real path: ONNX Runtime's Python layer converts option
    # values with str(), so passing None would yield the bogus path "None".
    if provider_config:
        options["config_file"] = provider_config
    return ["VitisAIExecutionProvider"], [options]


def main(argv=None):
    """Run the SemanticFPN ONNX model on one image and save a colorized mask.

    Args:
        argv: Optional argument list forwarded to the parser (defaults to
            ``sys.argv[1:]``).
    """
    args = _parse_args(argv)
    providers, provider_options = _select_providers(args.ipu, args.provider_config)

    # build_img derives the input from args; it returns an object exposing
    # .cpu().numpy() -- presumably a batched torch tensor, TODO confirm.
    input_img = build_img(args)
    session = onnxruntime.InferenceSession(args.onnx_path, providers=providers,
                                           provider_options=provider_options)
    ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy()}
    # session.run returns one entry per model output; keep only the first.
    ort_output = session.run(None, ort_input)[0]
    # A sequence-typed output comes back as a Python list; unwrap it.
    if isinstance(ort_output, (tuple, list)):
        ort_output = ort_output[0]

    # Take the first batch item and pick the best class per pixel. Axis 0 of
    # that item is the axis the original transpose(1, 2, 0) + argmax(axis=2)
    # reduced over (presumably the class axis of an (N, C, H, W) output).
    seg_pred = np.asarray(np.argmax(ort_output[0], axis=0), dtype=np.uint8)
    color_mask = colorize_mask(seg_pred)

    # os.makedirs('') raises, so only create a directory when save_path has one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    color_mask.save(args.save_path)


if __name__ == "__main__":
    main()