Acly committed
Commit 110a69d
1 Parent(s): b5e8e2f

Export script for image encoder

mobile_sam_encoder_onnx/export_image_encoder.py ADDED
@@ -0,0 +1,200 @@
import torch
import numpy as np

from mobile_sam import sam_model_registry
from .onnx_image_encoder import ImageEncoderOnnxModel

import os
import argparse
import warnings

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_t', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

parser.add_argument(
    "--use-preprocess",
    action="store_true",
    help="Whether to fold preprocessing (pixel normalization, HWC-to-CHW conversion, padding) into the exported model.",
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)


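# Example invocation (checkpoint and output paths are placeholders); run as a module
# so that the relative import above resolves:
#   python -m mobile_sam_encoder_onnx.export_image_encoder \
#       --checkpoint mobile_sam.pt --model-type vit_t \
#       --output mobile_sam_image_encoder.onnx --use-preprocess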
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = ImageEncoderOnnxModel(
        model=sam,
        use_preprocess=use_preprocess,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    if gelu_approximate:
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    # With preprocessing folded in, the model takes an HWC float image with dynamic
    # height/width (at most the encoder resolution); otherwise it takes an already
    # normalized 1x3xSxS tensor at the encoder's fixed input size.
    image_size = sam.image_encoder.img_size
    if use_preprocess:
        dummy_input = {
            "input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
        }
        dynamic_axes = {
            "input_image": {0: "image_height", 1: "image_width"},
        }
    else:
        dummy_input = {
            "input_image": torch.randn(
                (1, 3, image_size, image_size), dtype=torch.float
            )
        }
        dynamic_axes = None

    _ = onnx_model(**dummy_input)

    output_names = ["image_embeddings"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {output}...")
        if model_type == "vit_h":
            # vit_h exceeds the 2 GB protobuf limit, so export directly to a path and
            # let torch.onnx.export store the weights as external data files next to it.
            output_dir, output_file = os.path.split(output)
            os.makedirs(output_dir, mode=0o777, exist_ok=True)
            torch.onnx.export(
                onnx_model,
                tuple(dummy_input.values()),
                output,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            with open(output, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )

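    # Optional sanity check: when onnxruntime is installed, load the exported model
    # and run the dummy input through it once.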
    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
        providers = ["CPUExecutionProvider"]

        if model_type == "vit_h":
            session_option = onnxruntime.SessionOptions()
            ort_session = onnxruntime.InferenceSession(output, providers=providers)
            param_file = os.listdir(output_dir)
            param_file.remove(output_file)
            for i, layer in enumerate(param_file):
                with open(os.path.join(output_dir, layer), "rb") as fp:
                    weights = np.frombuffer(fp.read(), dtype=np.float32)
                    weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
                    session_option.add_initializer(layer, weights)
        else:
            ort_session = onnxruntime.InferenceSession(output, providers=providers)

        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        use_preprocess=args.use_preprocess,
        opset=args.opset,
        gelu_approximate=args.gelu_approximate,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
mobile_sam_encoder_onnx/onnx_image_encoder.py ADDED
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple, List

import mobile_sam
from mobile_sam.modeling import Sam
from mobile_sam.utils.amg import calculate_stability_score


class ImageEncoderOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It wraps the image encoder of Sam and optionally bakes the preprocessing
    (pixel normalization, HWC-to-CHW conversion, padding) into the traced
    graph. See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        use_preprocess: bool,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ):
        super().__init__()
        self.use_preprocess = use_preprocess
        self.pixel_mean = torch.tensor(pixel_mean, dtype=torch.float)
        self.pixel_std = torch.tensor(pixel_std, dtype=torch.float)
        self.image_encoder = model.image_encoder

    @torch.no_grad()
    def forward(self, input_image: torch.Tensor):
        if self.use_preprocess:
            input_image = self.preprocess(input_image)
        image_embeddings = self.image_encoder(input_image)
        return image_embeddings

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Permute HWC -> CHW
        x = torch.permute(x, (2, 0, 1))

        # Pad to the encoder's square input size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))

        # Add batch dimension
        x = torch.unsqueeze(x, 0)
        return x
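
For reference, a minimal sketch of how the exported encoder could be consumed with ONNX Runtime when exported with --use-preprocess. The model file name is a placeholder, the random array stands in for a real RGB image, and the 1024 input resolution and (1, 256, 64, 64) embedding shape are assumed from the standard SAM/MobileSAM encoder:

import numpy as np
import onnxruntime as ort

# Placeholder path: use whatever was passed to --output during export.
session = ort.InferenceSession(
    "mobile_sam_image_encoder.onnx", providers=["CPUExecutionProvider"]
)

# With --use-preprocess the graph expects an HWC float32 RGB image (0-255 values)
# whose sides do not exceed the encoder resolution; normalization, CHW conversion,
# padding and batching happen inside the exported graph.
image = (np.random.rand(768, 1024, 3) * 255).astype(np.float32)

(embeddings,) = session.run(["image_embeddings"], {"input_image": image})
print(embeddings.shape)  # (1, 256, 64, 64) for the standard SAM embedding size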