Add code
- app.py +252 -0
- concepts/merged_2560.pkl +3 -0
- models/tap_vit_l_03f8ec.pkl +3 -0
- requirements.txt +5 -0
- tokenize_anything/__init__.py +19 -0
- tokenize_anything/build_model.py +114 -0
- tokenize_anything/modeling/__init__.py +24 -0
- tokenize_anything/modeling/concept_projector.py +74 -0
- tokenize_anything/modeling/image_decoder.py +224 -0
- tokenize_anything/modeling/image_encoder.py +254 -0
- tokenize_anything/modeling/image_tokenizer.py +201 -0
- tokenize_anything/modeling/prompt_encoder.py +100 -0
- tokenize_anything/modeling/text_decoder.py +206 -0
- tokenize_anything/modeling/text_tokenizer.model +3 -0
- tokenize_anything/modeling/text_tokenizer.py +127 -0
- tokenize_anything/test_engine.py +81 -0
- tokenize_anything/utils/__init__.py +15 -0
- tokenize_anything/utils/image.py +73 -0
- tokenize_anything/utils/mask.py +45 -0
- tokenize_anything/utils/timer.py +51 -0
- tokenize_anything/version.py +3 -0
app.py
ADDED
@@ -0,0 +1,252 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Gradio application."""

import argparse
import multiprocessing as mp
import os
import time

import numpy as np
import torch

from tokenize_anything import test_engine
from tokenize_anything.utils.image import im_rescale
from tokenize_anything.utils.image import im_vstack


def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Launch gradio app.")
    parser.add_argument("--model-type", type=str, default="tap_vit_l")
    parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_03f8ec.pkl")
    parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
    parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices.")
    return parser.parse_args()


class Predictor(object):
    """Predictor."""

    def __init__(self, model, kwargs):
        self.model = model
        self.kwargs = kwargs
        self.batch_size = kwargs.get("batch_size", 256)
        self.model.concept_projector.reset_weights(kwargs["concept_weights"])
        self.model.text_decoder.reset_cache(max_batch_size=self.batch_size)

    def preprocess_images(self, imgs):
        """Preprocess the inference images."""
        im_batch, im_shapes, im_scales = [], [], []
        for img in imgs:
            scaled_imgs, scales = im_rescale(img, scales=[1024])
            im_batch += scaled_imgs
            im_scales += scales
            im_shapes += [x.shape[:2] for x in scaled_imgs]
        im_batch = im_vstack(im_batch, self.model.pixel_mean_value, size=(1024, 1024))
        im_shapes = np.array(im_shapes)
        im_scales = np.array(im_scales).reshape((len(im_batch), -1))
        im_info = np.hstack([im_shapes, im_scales]).astype("float32")
        return im_batch, im_info

    @torch.inference_mode()
    def get_results(self, examples):
        """Return the results."""
        # Preprocess images and prompts.
        imgs = [example["img"] for example in examples]
        points = np.concatenate([example["points"] for example in examples])
        im_batch, im_info = self.preprocess_images(imgs)
        num_prompts = points.shape[0] if len(points.shape) > 2 else 1
        batch_shape = im_batch.shape[0], num_prompts // im_batch.shape[0]
        batch_points = points.reshape(batch_shape + (-1, 3))
        batch_points[:, :, :, :2] *= im_info[:, None, None, 2:4]
        batch_points = batch_points.reshape(points.shape)
        # Predict tokens and masks.
        inputs = self.model.get_inputs({"img": im_batch})
        inputs.update(self.model.get_features(inputs))
        outputs = self.model.get_outputs(dict(**inputs, **{"points": batch_points}))
        # Select final mask.
        iou_pred = outputs["iou_pred"].cpu().numpy()
        point_score = batch_points[:, 0, 2].__eq__(2).__sub__(0.5)[:, None]
        rank_scores = iou_pred + point_score * ([1000] + [0] * (iou_pred.shape[1] - 1))
        mask_index = np.arange(rank_scores.shape[0]), rank_scores.argmax(1)
        iou_scores = outputs["iou_pred"][mask_index].cpu().numpy().reshape(batch_shape)
        # Upscale masks to the original image resolution.
        mask_pred = outputs["mask_pred"][mask_index][:, None]
        mask_pred = self.model.upscale_masks(mask_pred, im_batch.shape[1:-1])
        mask_pred = mask_pred.view(batch_shape + mask_pred.shape[2:])
        # Predict concepts.
        concepts, scores = self.model.predict_concept(outputs["sem_embeds"][mask_index])
        concepts, scores = [x.reshape(batch_shape) for x in (concepts, scores)]
        # Generate captions.
        sem_tokens = outputs["sem_tokens"][mask_index][:, None, :]
        captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
        # Postprocess results.
        results = []
        for i in range(batch_shape[0]):
            pred_h, pred_w = im_info[i, :2].astype("int")
            masks = mask_pred[i : i + 1, :, :pred_h, :pred_w]
            masks = self.model.upscale_masks(masks, imgs[i].shape[:2])[0]
            results.append(
                {
                    "scores": np.stack([iou_scores[i], scores[i]], axis=-1),
                    "masks": masks.gt(0).cpu().numpy().astype("uint8"),
                    "concepts": concepts[i],
                    "captions": captions[i],
                }
            )
        return results


class ServingCommand(object):
    """Command to run serving."""

    def __init__(self, output_queue):
        self.output_queue = output_queue
        self.output_dict = mp.Manager().dict()
        self.output_index = mp.Value("i", 0)

    def postprocess_outputs(self, outputs):
        """Make the detection objects."""
        scores, masks = outputs["scores"], outputs["masks"]
        concepts, captions = outputs["concepts"], outputs["captions"]
        text_template = "{} ({:.2f}, {:.2f}): {}"
        text_contents = concepts, scores[:, 0], scores[:, 1], captions
        texts = np.array([text_template.format(*vals) for vals in zip(*text_contents)])
        return masks, texts

    def run(self):
        """Main loop to make the serving outputs."""
        while True:
            img_id, outputs = self.output_queue.get()
            self.output_dict[img_id] = self.postprocess_outputs(outputs)


def build_gradio_app(queues, command):
    """Build the gradio application."""
    import cv2
    import gradio as gr
    import gradio_image_prompter as gr_ext

    title = "Tokenize Anything"
    header = (
        "<div align='center'>"
        f"<h1>{title}</h1>"
        "<h3>A promptable model capable of simultaneously segmenting, recognizing and captioning</h3>"
        "</div>"
    )
    theme = "soft"
    css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
    #anno-img .mask.active {opacity: 0.7}"""

    def get_examples():
        assets_dir = os.path.join(os.path.dirname(__file__), "../assets")
        app_images = list(filter(lambda x: x.startswith("app_image"), os.listdir(assets_dir)))
        app_images.sort()
        return [{"image": os.path.join(assets_dir, x)} for x in app_images]

    def on_prompt_opt(index):
        click_img = gr.Image(None, visible=index == 0)
        draw_img = gr.ImageEditor(None, visible=index != 0)
        anno_img = gr.AnnotatedImage(None)
        return click_img, draw_img, anno_img

    def on_reset_btn():
        click_img, draw_img = gr.Image(None), gr.ImageEditor(None)
        anno_img = gr.AnnotatedImage(None)
        return click_img, draw_img, anno_img

    def on_submit_btn(click_img, mask_img, prompt, multipoint):
        if prompt == 0:
            img = cv2.imread(click_img["image"])
            points = np.array(click_img["points"]).reshape((-1, 2, 3))
            if multipoint == 1:
                points = points.reshape((-1, 3))
                lt = points[np.where(points[:, 2] == 2)[0]][None, :, :]
                rb = points[np.where(points[:, 2] == 3)[0]][None, :, :]
                poly = points[np.where(points[:, 2] <= 1)[0]][None, :, :]
                points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
                points = np.concatenate(points, axis=1)
            points = (np.array([[[0, 0, 4]]]) if len(points) == 0 else points).astype("float32")
        elif prompt == 1:
            img, points = mask_img["background"][:, :, (2, 1, 0)], []
            for layer in mask_img["layers"]:
                ys, xs = np.nonzero(layer[:, :, 0])
                keep = np.linspace(0, ys.shape[0], 11, dtype="int64")[1:-1]
                points.append(np.stack([xs[keep][None, :], ys[keep][None, :]], 2))
            points = np.concatenate(points).astype("float32")
            points = np.pad(points, [(0, 0), (0, 0), (0, 1)], constant_values=1)
            pad_points = np.array([[[0, 0, 4]]], "float32").repeat(points.shape[0], 0)
            points = np.concatenate([points, pad_points], axis=1)
        inputs = {"img": img, "points": points}
        with command.output_index.get_lock():
            command.output_index.value += 1
            img_id = command.output_index.value
        queues[img_id % len(queues)].put((img_id, inputs))
        while img_id not in command.output_dict:
            time.sleep(0.005)
        masks, texts = command.output_dict.pop(img_id)
        annotations = [(x, y) for x, y in zip(masks, texts)]
        return inputs["img"][:, :, ::-1], annotations

    app = gr.Blocks(title=title, theme=theme, css=css).__enter__()
    gr.Markdown(header)
    container, column = gr.Row().__enter__(), gr.Column().__enter__()
    click_img = gr_ext.ImagePrompter(type="filepath", show_label=False)
    draw_img = gr.ImageEditor(type="numpy", show_label=False, visible=False)
    interactions = "LeftClick (FG) | MiddleClick (BG) | PressMove (Box) | Draw (Sketch)"
    gr.Markdown("<h3 style='text-align: center'>[🖱️ | 🖐️]: 🌟🌟 {} 🌟🌟 </h3>".format(interactions))
    row = gr.Row().__enter__()
    prompt_opt = gr.Radio(["Point+Box", "Sketch"], label="Prompt", type="index", value="Point+Box")
    point_opt = gr.Radio(["Batch", "Ensemble"], label="Multipoint", type="index", value="Batch")
    _, row = row.__exit__(), gr.Row().__enter__()
    reset_btn, submit_btn = gr.Button("Reset"), gr.Button("Execute")
    _, row = row.__exit__(), gr.Row().__enter__()
    gr.Examples(get_examples(), inputs=[click_img], label="Examples (for Point+Box only)")
    _, _, column = row.__exit__(), column.__exit__(), gr.Column().__enter__()
    anno_img = gr.AnnotatedImage(elem_id="anno-img", show_label=False)
    reset_btn.click(on_reset_btn, [], [click_img, draw_img, anno_img])
    submit_btn.click(on_submit_btn, [click_img, draw_img, prompt_opt, point_opt], [anno_img])
    prompt_opt.change(on_prompt_opt, [prompt_opt], [click_img, draw_img, anno_img])
    column.__exit__(), container.__exit__(), app.__exit__()
    return app


if __name__ == "__main__":
    args = parse_args()
    queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
    commands = [
        test_engine.InferenceCommand(
            queues[i],
            queues[-1],
            kwargs={
                "model_type": args.model_type,
                "weights": args.checkpoint,
                "concept_weights": args.concept,
                "device": args.device[i],
                "predictor_type": Predictor,
                "verbose": i == 0,
            },
        )
        for i in range(len(args.device))
    ]
    commands += [ServingCommand(queues[-1])]
    actors = [mp.Process(target=command.run, daemon=True) for command in commands]
    for actor in actors:
        actor.start()
    app = build_gradio_app(queues[:-1], commands[-1])
    app.queue()
    app.launch()
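
Usage note: a minimal sketch of the prompt layout that Predictor.get_results expects, assuming a `predictor` instance built through test_engine.InferenceCommand as above (test_engine.py is listed in this commit but not shown here; the image array below is a placeholder):

    import numpy as np

    # Each prompt is a list of [x, y, label] points; per prompt_encoder.py the labels are
    # 0=background, 1=foreground click, 2/3=box corners (top-left/bottom-right), 4=padding.
    example = {
        "img": np.zeros((480, 640, 3), dtype="uint8"),  # BGR image, as read by cv2.imread
        "points": np.array([[[320, 240, 1], [0, 0, 4]]], "float32"),  # one click + one pad point
    }
    # results = predictor.get_results([example])  # -> scores, masks, concepts, captions
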
concepts/merged_2560.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7a17403190a7a44669136d0ab278b1bb1e095bb68eff178c3e2617b2744bbb7
size 10514948
models/tap_vit_l_03f8ec.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d63a5aba993c34bf29c0466026136e18e25d2bd4ac9e51b8fc407b76c431707d
size 811637521
requirements.txt
ADDED
@@ -0,0 +1,5 @@
opencv-python
Pillow
gradio-image-prompter
torch>=2.0.0
flash-attn>=2.3.3
tokenize_anything/__init__.py
ADDED
@@ -0,0 +1,19 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Simultaneously Segment, Recognize, and Caption Anything with Promptable Tokenization."""

from tokenize_anything.build_model import model_registry
from tokenize_anything.version import __version__
tokenize_anything/build_model.py
ADDED
@@ -0,0 +1,114 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Build model."""

from functools import partial
import pickle

import numpy as np
import torch

from tokenize_anything.modeling import ConceptProjector
from tokenize_anything.modeling import ImageDecoder
from tokenize_anything.modeling import ImageEncoderViT
from tokenize_anything.modeling import ImageTokenizer
from tokenize_anything.modeling import PromptEncoder
from tokenize_anything.modeling import TextDecoder
from tokenize_anything.modeling import TextTokenizer


def get_device(device_index):
    """Create an available device object."""
    if torch.cuda.is_available():
        return torch.device("cuda", device_index)
    return torch.device("cpu")


def load_weights(module, weights_file, strict=True):
    """Load a weights file."""
    if not weights_file:
        return module._IncompatibleKeys([], [])
    if weights_file.endswith(".pkl"):
        with open(weights_file, "rb") as f:
            state_dict = pickle.load(f)
            for k, v in state_dict.items():
                state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
    else:
        state_dict = torch.load(weights_file)
    return module.load_state_dict(state_dict, strict=strict)


def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
    """Build an image encoder with ViT."""
    return ImageEncoderViT(
        depth=depth,
        embed_dim=embed_dim,
        num_heads=num_heads,
        mlp_ratio=4,
        patch_size=16,
        window_size=16,
        image_size=image_size,
        out_dim=out_dim,
    )


def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
    """Build an image tokenizer."""
    image_size = kwargs.get("image_size", 1024)
    prompt_embed_dim = kwargs.get("prompt_embed_dim", 256)
    sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
    text_embed_dim = kwargs.get("text_embed_dim", 512)
    text_decoder_depth = kwargs.get("text_decoder_depth", 12)
    text_seq_len = kwargs.get("text_seq_len", 40)
    text_tokenizer = TextTokenizer()
    model = ImageTokenizer(
        image_encoder=image_encoder(out_dim=prompt_embed_dim, image_size=image_size),
        prompt_encoder=PromptEncoder(embed_dim=prompt_embed_dim, image_size=image_size),
        image_decoder=ImageDecoder(
            depth=2,
            embed_dim=prompt_embed_dim,
            num_heads=prompt_embed_dim // 32,
            num_mask_tokens=4,
            sem_embed_dim=sem_embed_dim,
        ),
        text_tokenizer=text_tokenizer,
        concept_projector=ConceptProjector(),
        text_decoder=TextDecoder(
            depth=text_decoder_depth,
            embed_dim=text_embed_dim,
            num_heads=text_embed_dim // 64,
            mlp_ratio=4,
            prompt_embed_dim=prompt_embed_dim,
            max_seq_len=text_seq_len,
            vocab_size=text_tokenizer.n_words,
        ),
    )
    load_weights(model, checkpoint)
    model = model.to(device=get_device(device))
    model = model.eval() if not kwargs.get("training", False) else model
    model = model.half() if dtype == "float16" else model
    model = model.bfloat16() if dtype == "bfloat16" else model
    model = model.float() if dtype == "float32" else model
    return model


vit_b_encoder = partial(vit_encoder, depth=12, embed_dim=768, num_heads=12)
vit_l_encoder = partial(vit_encoder, depth=24, embed_dim=1024, num_heads=16)

model_registry = {
    "tap_vit_b": partial(image_tokenizer, image_encoder=vit_b_encoder),
    "tap_vit_l": partial(image_tokenizer, image_encoder=vit_l_encoder),
}
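
Usage note: the `model_registry` above is the entry point the rest of this commit relies on. A minimal sketch (checkpoint path taken from this commit's file list; `device` and `dtype` follow the `image_tokenizer` defaults):

    from tokenize_anything import model_registry

    # Builds the ViT-L tokenizer, loads the pickled weights, and moves it to CUDA device 0
    # in float16 (falls back to CPU if CUDA is unavailable, per get_device above).
    model = model_registry["tap_vit_l"](
        checkpoint="models/tap_vit_l_03f8ec.pkl",
        device=0,
        dtype="float16",
    )
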
tokenize_anything/modeling/__init__.py
ADDED
@@ -0,0 +1,24 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Modeling components."""

from tokenize_anything.modeling.concept_projector import ConceptProjector
from tokenize_anything.modeling.image_decoder import ImageDecoder
from tokenize_anything.modeling.image_encoder import ImageEncoderViT
from tokenize_anything.modeling.image_tokenizer import ImageTokenizer
from tokenize_anything.modeling.prompt_encoder import PromptEncoder
from tokenize_anything.modeling.text_decoder import TextDecoder
from tokenize_anything.modeling.text_tokenizer import TextTokenizer
tokenize_anything/modeling/concept_projector.py
ADDED
@@ -0,0 +1,74 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Concept projector."""

import pickle

import numpy as np
import torch
from torch import nn


class ConceptProjector(nn.Module):
    """Encode and decode concept using CLIP."""

    def __init__(self, src_weights=None, tgt_weights=None):
        super(ConceptProjector, self).__init__()
        self.reset_weights(src_weights, tgt_weights)

    def reset_weights(self, src_weights=None, tgt_weights=None):
        """Reset the normalized projection weights."""
        if src_weights is not None:
            with open(src_weights, "rb") as f:
                self.src_weights, self.concepts = pickle.load(f)
            self.src_weights = torch.from_numpy(self.src_weights)
            self.concepts = np.array(self.concepts)
        if tgt_weights is not None:
            with open(tgt_weights, "rb") as f:
                self.tgt_weights, self.concepts = pickle.load(f)
            self.tgt_weights = torch.from_numpy(self.tgt_weights)
            self.concepts = np.array(self.concepts)

    @staticmethod
    def maybe_convert(embeds, proj):
        """Convert inputs for safe projection."""
        if embeds.dtype != torch.float32:
            embeds = embeds.float()
        if embeds.device != proj.device:
            proj = proj.to(device=embeds.device)
        return embeds, proj

    def encode_src(self, src_embeds):
        """Encode source visual embedding via concept projection."""
        src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
        logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
        return nn.functional.log_softmax(logits, dim=-1)

    def encode_tgt(self, tgt_embeds):
        """Encode target visual embedding via concept projection."""
        tgt_embeds, self.tgt_weights = self.maybe_convert(tgt_embeds, self.tgt_weights)
        logits = nn.functional.normalize(tgt_embeds, dim=-1) @ self.tgt_weights
        return nn.functional.log_softmax(logits, dim=-1)

    def decode(self, src_embeds, k=1, return_index=False, return_prob=False):
        """Return the top-k concepts of source visual embedding."""
        src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
        logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
        probs = nn.functional.softmax(logits, dim=-1)
        if return_prob:
            return probs.cpu().numpy()
        score, index = [x.cpu().numpy() for x in probs.topk(k, -1)]
        return (index if return_index else self.concepts[index]), score
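
Usage note: a short, illustrative sketch of how this projector is driven elsewhere in the commit (the weights path comes from the file list; `sem_embeds` is a stand-in for the image decoder's 1024-d semantic embeddings):

    import torch
    from tokenize_anything.modeling import ConceptProjector

    projector = ConceptProjector()
    projector.reset_weights("concepts/merged_2560.pkl")  # pickled (projection weights, concept names)
    sem_embeds = torch.randn(4, 1024)                     # placeholder embeddings
    concepts, scores = projector.decode(sem_embeds, k=1)  # top-1 concept name and probability per row
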
tokenize_anything/modeling/image_decoder.py
ADDED
@@ -0,0 +1,224 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image decoder."""

try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None

import torch
from torch import nn


class TransposedLayerNorm(nn.LayerNorm):
    """LayerNorm with pre-transposed spatial axes."""

    def forward(self, input):
        return super().forward(input.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class MLP(nn.Module):
    """Two layers MLP."""

    def __init__(self, dim, mlp_dim, activation_type="ReLU"):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.activation = getattr(nn, activation_type)()
        self.activation.inplace = True

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multi-head attention."""

    def __init__(self, dim=256, num_heads=8, attn_ratio=1):
        super(Attention, self).__init__()
        qkv_dim = int(dim * attn_ratio)
        self.num_heads = num_heads
        self.head_dim = qkv_dim // num_heads
        self.q_proj = nn.Linear(dim, qkv_dim)
        self.k_proj = nn.Linear(dim, qkv_dim)
        self.v_proj = nn.Linear(dim, qkv_dim)
        self.proj = nn.Linear(qkv_dim, dim)
        self.scale = self.head_dim**-0.5

    def forward(self, q, k, v):
        q = self.q_proj(q).view((-1, q.size(1), self.num_heads, self.head_dim))
        k = self.k_proj(k).view((-1, k.size(1), self.num_heads, self.head_dim))
        v = self.v_proj(v).view((-1, v.size(1), self.num_heads, self.head_dim))
        o = flash_attn_func(q, k, v, softmax_scale=self.scale)
        return self.proj(o.flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(
        self,
        dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        dropout=0.1,
        activation_type="ReLU",
        skip_first_query_pos=False,
    ):
        super(Block, self).__init__()
        self.self_attn = Attention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn_token_to_image = Attention(dim, num_heads, attn_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, activation_type)
        self.norm3 = nn.LayerNorm(dim)
        self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
        self.norm4 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.skip_first_query_pos = skip_first_query_pos

    def forward(self, query, key, query_pos, key_pos):
        if self.skip_first_query_pos:
            query = self.norm1(self.self_attn(query, query, query))
        else:
            q = query + query_pos
            query = self.norm1(self.dropout(self.self_attn(q, q, query)).add_(query))
        q, k = query + query_pos, key + key_pos
        query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
        query = self.norm3(self.dropout(self.mlp(query)).add_(query))
        q = query + query_pos
        key = self.norm4(self.cross_attn_image_to_token(k, q, query).add_(key))
        return query, key


class Transformer(nn.Module):
    """Two-way transformer decoder."""

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        dropout=0.1,
        activation_type="ReLU",
        depth=2,
    ):
        super(Transformer, self).__init__()
        self.blocks = nn.ModuleList(
            Block(
                embed_dim,
                num_heads,
                attn_ratio=attn_ratio,
                mlp_dim=mlp_dim,
                dropout=dropout,
                activation_type=activation_type,
                skip_first_query_pos=i == 0,
            )
            for i in range(depth)
        )
        self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, query, key, query_pos, key_pos):
        for blk in self.blocks:
            query, key = blk(query, key, query_pos, key_pos)
        q, k = query + query_pos, key + key_pos
        query = self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query)
        query = self.norm(query)
        return query, key


class Predictor(nn.Module):
    """MLP predictor."""

    def __init__(self, in_dim, out_dim, mlp_dim=None, depth=3):
        super(Predictor, self).__init__()
        mlp_dims = [mlp_dim or in_dim] * (depth - 1)
        in_dims, out_dims = [in_dim] + mlp_dims, mlp_dims + [out_dim]
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip(in_dims, out_dims))

    def forward(self, x):
        for fc in self.layers[:-1]:
            x = nn.functional.relu(fc(x), inplace=True)
        return self.layers[-1](x)


class ImageDecoder(nn.Module):
    """Module to decode region tokens and masks."""

    def __init__(self, depth, embed_dim, num_heads, num_mask_tokens=4, sem_embed_dim=1024):
        super(ImageDecoder, self).__init__()
        self.embed_dim = embed_dim
        self.num_mask_tokens = num_mask_tokens
        self.transformer = Transformer(embed_dim, num_heads=num_heads, depth=depth)
        self.iou_token = nn.Embedding(1, embed_dim)
        self.sem_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
        self.output_conv = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
            TransposedLayerNorm(embed_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, 2, 2),
            nn.GELU(),
        )
        self.mask_pred = nn.ModuleList(
            Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
        )
        self.iou_pred = Predictor(embed_dim, self.num_mask_tokens)
        self.sem_pred = Predictor(embed_dim, sem_embed_dim, 1024)

    def get_outputs(self, inputs):
        img_embeds = inputs["img_embeds"]
        sparse_embeds = inputs["sparse_embeds"]
        ims_per_batch = img_embeds.size(0)
        prompts_per_batch = sparse_embeds.size(0)
        img_embed_size = img_embeds.shape[2:-1]
        # Prepare query.
        tokens = [self.sem_tokens.weight, self.iou_token.weight, self.mask_tokens.weight]
        query = torch.cat(tokens).unsqueeze_(0).expand(prompts_per_batch, -1, -1)
        query = torch.cat((query, sparse_embeds), dim=1)
        num_tokens = query.shape[1] - sparse_embeds.shape[1]
        # Prepare key.
        key = img_embeds.expand(-1, prompts_per_batch // ims_per_batch, -1, -1, -1)
        key = key.flatten(0, 1).flatten(1, 2)
        # Decode.
        query, key = self.transformer(query, key, query, inputs["img_pos"])
        # Upscale key.
        key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
        output_masks = self.output_conv(key).flatten(2)
        # Unpack query.
        tokens = query[:, :num_tokens].unbind(dim=1)
        iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
        mask_tokens = tokens[num_tokens - self.num_mask_tokens :]
        sem_tokens = tokens[: self.num_mask_tokens]
        # Predict.
        mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
        mask_pred = torch.stack(mask_pred, dim=1) @ output_masks
        mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
        mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
        outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
        outputs["sem_tokens"] = torch.stack(sem_tokens, dim=1)
        outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"])
        return outputs

    def forward(self, inputs):
        outputs = self.get_outputs(inputs)
        outputs["iou_pred"] = outputs["iou_pred"].float()
        return outputs
tokenize_anything/modeling/image_encoder.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Image encoder."""

import torch
from torch import nn


def space_to_depth(input, block_size):
    """Rearrange blocks of spatial data into depth."""
    if input.dim() == 3:
        hXw, c = input.size()[1:]
        h = w = int(hXw**0.5)
    else:
        h, w, c = input.size()[1:]
    h1, w1 = h // block_size, w // block_size
    c1 = (block_size**2) * c
    input = input.reshape((-1, h1, block_size, w1, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h1, w1, c1))


def depth_to_space(input, block_size):
    """Rearrange blocks of depth data into spatial."""
    h1, w1, c1 = input.size()[1:]
    h, w = h1 * block_size, w1 * block_size
    c = c1 // (block_size**2)
    input = input.reshape((-1, h1, w1, block_size, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h, w, c))


class MLP(nn.Module):
    """Two layers MLP."""

    def __init__(self, dim, mlp_ratio=4):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multihead attention."""

    def __init__(self, dim, num_heads, qkv_bias=True):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.rel_pos_embed = nn.Identity()

    def forward(self, x):
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)
        attn = q @ k.transpose(-2, -1).mul(self.scale)
        attn = self.rel_pos_embed(attn)
        o = nn.functional.softmax(attn, dim=-1) @ v
        return self.proj(o.transpose(1, 2).flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super(Block, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)

    def forward(self, x):
        x = self.attn(self.norm1(x)).add_(x)
        return self.mlp(self.norm2(x)).add_(x)


class Bottleneck(nn.Module):
    """The bottleneck block."""

    def __init__(self, dim, expansion=2, width=None):
        super(Bottleneck, self).__init__()
        width = width or dim // expansion
        self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
        self.norm1 = nn.SyncBatchNorm(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
        self.norm2 = nn.SyncBatchNorm(width)
        self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
        self.norm3 = nn.SyncBatchNorm(dim)
        self.activation = nn.GELU()

    def forward(self, x):
        shortcut = x
        x = self.activation(self.norm1(self.conv1(x)))
        x = self.activation(self.norm2(self.conv2(x)))
        return self.norm3(self.conv3(x)).add_(shortcut)


class PatchEmbed(nn.Module):
    """Patch embedding layer."""

    def __init__(self, dim=768, patch_size=16, bias=True):
        super(PatchEmbed, self).__init__()
        self.proj = nn.Conv2d(3, dim, patch_size, patch_size, bias=bias)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)


class PosEmbed(nn.Module):
    """Position embedding layer."""

    def __init__(self, dim, num_patches):
        super(PosEmbed, self).__init__()
        self.dim = dim
        self.num_patches = num_patches
        self.weight = nn.Parameter(torch.zeros(num_patches, dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        return x.add_(self.weight)


class RelPosEmbed(nn.Module):
    """Relative position embedding layer."""

    def __init__(self, num_heads, size):
        super(RelPosEmbed, self).__init__()
        self.register_buffer("index", self.get_index(size))
        self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))

    @staticmethod
    def get_index(size):
        """Return the relative index."""
        grid = torch.arange(size)
        grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
        coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
        coords[0] *= 2 * size - 1
        return coords.sum(0)

    def get_bias(self):
        return self.weight[:, self.index]

    def forward(self, x):
        return x.add_(self.get_bias())


class SimpleFeaturePyramid(nn.Module):
    """Module to create pyramid features."""

    def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
        super(SimpleFeaturePyramid, self).__init__()
        self.min_lvl, self.max_lvl = min_lvl, max_lvl
        self.input_conv = nn.ModuleList()
        self.lateral_conv = nn.ModuleList()
        self.output_conv = nn.ModuleList()
        patch_lvl = dict((2**i, i) for i in range(6))[patch_size]
        for lvl in [min(i + 2, self.max_lvl) for i in range(4)]:
            if lvl == patch_lvl or lvl < self.min_lvl:
                self.input_conv += [nn.Identity()]
            elif lvl < patch_lvl:
                stride, layers = 2 ** (patch_lvl - lvl), []
                while stride > 1:
                    layers += [nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)]
                    layers += [nn.SyncBatchNorm(embed_dim), nn.GELU()] if stride > 2 else []
                    stride /= 2
                self.input_conv.append(nn.Sequential(*layers))
            elif lvl > patch_lvl:
                stride = 2 ** (lvl - patch_lvl)
                self.input_conv += [nn.MaxPool2d(stride, stride)]
        for _ in range(min_lvl, max_lvl + 1):
            self.lateral_conv.append(
                nn.Sequential(
                    nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )
            self.output_conv.append(
                nn.Sequential(
                    nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )

    def forward(self, inputs):
        inputs = inputs + [inputs[-1]] * (4 - len(inputs))
        inputs = [conv(x) for conv, x in zip(self.input_conv, inputs)]
        features = inputs[self.min_lvl - 1 : self.max_lvl]
        laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
        return [conv(x) for conv, x in zip(self.output_conv, laterals)]


class ImageEncoderViT(nn.Module):
    """ViT image encoder."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=16,
        window_size=16,
        image_size=1024,
        out_dim=256,
    ):
        super(ImageEncoderViT, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.window_size = window_size or image_size // patch_size
        self.patch_embed = PatchEmbed(embed_dim, patch_size)
        self.pos_embed = PosEmbed(embed_dim, (image_size // patch_size) ** 2)
        self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
        for blk in self.blocks:
            blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
        self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
        self.cross_indices = list(range(depth // 4 - 1, depth, depth // 4))

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_embed(x)
        x = space_to_depth(x, self.window_size)
        wmsa_shape = (-1,) + x.shape[1:]
        msa_shape = (-1, self.window_size**2, self.embed_dim)
        x = x.reshape(msa_shape)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.cross_indices or i == len(self.blocks) - 1:
                x = self.norm(x) if i == len(self.blocks) - 1 else x
                x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
                x = x.permute(0, 3, 1, 2)
            if i in self.cross_indices:
                x = self.cross_conv[self.cross_indices.index(i)](x)
            if i in self.cross_indices and i < len(self.blocks) - 1:
                x = x.permute(0, 2, 3, 1)
                x = space_to_depth(x, self.window_size).reshape(msa_shape)
        return self.neck([x])
tokenize_anything/modeling/image_tokenizer.py
ADDED
@@ -0,0 +1,201 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image tokenizer."""

import numpy as np
import torch
from torch import nn


class ImageTokenizer(nn.Module):
    """Tokenize image regions with visual prompts."""

    def __init__(
        self,
        image_encoder,
        prompt_encoder,
        image_decoder,
        concept_projector=None,
        text_tokenizer=None,
        text_decoder=None,
        pixel_mean=(103.53, 116.28, 123.675),
        pixel_std=(57.375, 57.12, 58.395),
    ):
        super(ImageTokenizer, self).__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.image_decoder = image_decoder
        self.concept_projector = concept_projector
        self.text_tokenizer = text_tokenizer
        self.text_decoder = text_decoder
        self.pixel_mean_value = pixel_mean  # BGR order.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
        self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())

    def get_inputs(self, inputs):
        """Return the model inputs.

        Parameters
        ----------
        inputs : dict
            The initial inputs.

        Returns
        -------
        dict
            The model inputs.

        """
        if not isinstance(inputs["img"], torch.Tensor):
            inputs["img"] = torch.from_numpy(inputs["img"])
        if inputs["img"].device != self.pixel_mean.device:
            inputs["img"] = inputs["img"].to(device=self.pixel_mean.device)
        inputs["img"] = inputs["img"].to(dtype=self.pixel_mean.dtype)
        inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig)
        inputs["img"] = inputs["img"].permute(0, 3, 1, 2)
        return inputs

    def get_features(self, inputs):
        """Return the image features.

        Parameters
        ----------
        inputs : dict
            The inputs.

        Returns
        -------
        dict
            The image features.

        """
        features = self.image_encoder(inputs["img"])
        img_embeds = features[0].permute(0, 2, 3, 1).unsqueeze_(1)
        return {"features": features, "img_embeds": img_embeds}

    def get_outputs(self, inputs):
        """Return the model outputs.

        Parameters
        ----------
        inputs : dict
            The model inputs.

        Returns
        -------
        dict
            The model outputs.

        """
        inputs.update(self.prompt_encoder(inputs))
        return self.image_decoder(inputs)

    def forward(self, inputs):
        """Define the computation performed at every call.

        Parameters
        ----------
        inputs : dict
            The initial inputs.

        Returns
        -------
        dict
            The model outputs.

        """
        inputs = self.get_inputs(inputs)
        inputs.update(self.get_features(inputs))
        return self.get_outputs(inputs)

    def upscale_masks(self, masks, size):
        """Upscale masks using bilinear interpolation.

        Parameters
        ----------
        masks : torch.Tensor
            The input masks.
        size : Union[int, Tuple[int]]
            The output size.

        Returns
        -------
        torch.Tensor
            The output masks.

        """
        return nn.functional.interpolate(masks, size, mode="bilinear", align_corners=False)

    @torch.inference_mode()
    def predict_concept(self, visual_embeds, k=1):
        """Predict top-k concepts based on visual embeddings.

        Parameters
        ----------
        visual_embeds : torch.Tensor
            The embeddings to predict visual content.
        k : int, optional, default=1
            The k value.

        Returns
        -------
        Tuple[numpy.ndarray, numpy.ndarray]
            The concept scores and indices.

        """
        return self.concept_projector.decode(visual_embeds, k)

    @torch.inference_mode()
    def generate_text(self, visual_tokens, max_gen_len=None, temperature=0):
        """Generate text sequences based on visual tokens.

        Parameters
        ----------
        visual_tokens : torch.Tensor
            The tokens to prompt visual context.
        max_gen_len : int, optional
            The maximum length of the generated text sequences.
        temperature : float, optional
            The temperature for controlling randomness in sampling.

        Returns
        -------
        np.ndarray
            An array of generated texts.

        """
        max_gen_len = max_gen_len or self.text_decoder.max_seq_len
        prompts = self.text_decoder.get_prompts(visual_tokens)
        out_shape = (prompts.size(0), self.text_decoder.max_text_len)
        tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
        tokens[:, 0], prev_pos = self.text_tokenizer.bos_id, 0
        eos_reached = np.array([False] * tokens.shape[0])
        for cur_pos in range(1, max_gen_len):
            decode_seq_len = cur_pos - prev_pos
            x = torch.from_numpy(tokens[:, prev_pos:cur_pos]).to(device=prompts.device)
            logits = self.text_decoder.transformer(prompts, x, prev_pos)
            next_logits = logits[: x.size(0), decode_seq_len - 1]
            if temperature > 0:
                p = nn.functional.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(p, 1).cpu().numpy().flatten()
            else:
                next_token = next_logits.argmax(-1).cpu().numpy()
            tokens[:, cur_pos] = next_token
            eos_reached |= next_token == self.text_tokenizer.eos_id
            prev_pos, logits, next_logits = cur_pos, None, None
            if eos_reached.all():
                break
        return np.array(self.text_tokenizer.detokenize(tokens))
tokenize_anything/modeling/prompt_encoder.py
ADDED
@@ -0,0 +1,100 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Prompt encoder."""

import torch
from torch import nn


class PromptEncoder(nn.Module):
    """Module to encode geometric prompts."""

    def __init__(self, embed_dim, image_size):
        super(PromptEncoder, self).__init__()
        self.img_size = [image_size] * 2
        self.point_embed = nn.Embedding(5, embed_dim)  # [bg, fg, lt, rb, pad]
        self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
        self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
        self.img_pos = None

    def to_tensor(self, input):
        """Convert input to tensor."""
        if input is None:
            return input
        if not isinstance(input, torch.Tensor):
            input = torch.from_numpy(input)
        if input.device != self.coord_matrix.device:
            input = input.to(device=self.coord_matrix.device)
        return input

    def to_points(self, points=None, boxes=None):
        """Convert points or boxes to point prompts."""
        if points is not None:
            if isinstance(points, (tuple, list)):
                coords, labels = points
            else:
                coords, labels = points[:, :, :2], points[:, :, 2]
            coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
            coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
            labels = self.to_tensor(labels.astype("int64"))
            return coords, labels
        if boxes is not None:
            coords = boxes.reshape((-1, 2, 2))
            coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
            coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
            labels = self.to_tensor(self.corner_labels)
            return coords, labels
        return None

    def encode_coords(self, coords):
        """Return the embedding for given coords."""
        pi4, pi2 = 4 * 3.1415926, 2 * 3.1415926
        if self.coord_matrix.dtype != torch.float32:
            self.coord_matrix = self.coord_matrix.float()
        rad = coords.mul(pi4).sub_(pi2) @ self.coord_matrix
        dtype = self.point_embed.weight.dtype
        return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype)

    def encode_points(self, coords, labels):
        """Return the embedding for given points."""
        embed = self.encode_coords(coords)
        embed.mul_(labels.ne(4).unsqueeze_(-1).float().to(dtype=embed.dtype))
        return embed.add_(self.point_embed(labels))

    def encode_grid(self, grid_size):
        """Return the embedding for a grid of specified size."""
        grid = torch.ones(*grid_size, dtype=torch.float32)
        y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
        x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
        coords = self.to_tensor(torch.stack([x, y], dim=-1))
        return self.encode_coords(coords)

    def forward(self, inputs):
        sparse_embeds = []
        if inputs.get("boxes", None) is not None:
            coords, labels = self.to_points(boxes=inputs["boxes"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if inputs.get("points", None) is not None:
            coords, labels = self.to_points(points=inputs["points"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if len(sparse_embeds) > 1:
            sparse_embeds = [torch.cat(sparse_embeds, dim=1)]
        elif len(sparse_embeds) == 0:
            raise ValueError("Expected ``points`` or ``boxes`` prompts.")
        img_embed_size = torch.Size(inputs["img_embeds"].shape[2:-1])
        if self.img_pos is None or self.img_pos.shape[0] != img_embed_size.numel():
            self.img_pos = self.encode_grid(img_embed_size).flatten(0, 1)
        return {"sparse_embeds": sparse_embeds[0], "img_pos": self.img_pos}

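A minimal sketch of exercising the prompt encoder on its own, assuming the class is importable from tokenize_anything.modeling.prompt_encoder; the embedding width, image size, click coordinates, and labels below are illustrative, not values taken from the released model:

import numpy as np
from tokenize_anything.modeling.prompt_encoder import PromptEncoder

enc = PromptEncoder(embed_dim=256, image_size=1024)
# Two clicks on a 1024x1024 image as (x, y, label): label 1 = foreground, 0 = background.
points = np.array([[[400.0, 320.0, 1], [80.0, 64.0, 0]]], dtype="float32")
coords, labels = enc.to_points(points=points)
embeds = enc.encode_points(coords, labels)
print(embeds.shape)  # torch.Size([1, 2, 256])
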
tokenize_anything/modeling/text_decoder.py
ADDED
@@ -0,0 +1,206 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Text decoder."""

try:
    from flash_attn import flash_attn_func
    from flash_attn import flash_attn_with_kvcache
    from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
    flash_attn_func = None
    flash_attn_with_kvcache = None
    apply_rotary_emb = None

import torch
from torch import nn


class TransformerCache(nn.Module):
    """Transformer cache module."""

    def __init__(self, device=None, dtype=None):
        super(TransformerCache, self).__init__()
        self.device = device
        self.dtype = dtype
        self.start_pos = 0
        self.cache_dict = {}

    def init_seq(self, max_batch_size):
        seq_lens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
        self.cache_dict["seq_lens"] = seq_lens

    def init_rotary(self, seq_len, dim, theta=10000.0):
        grid = torch.arange(seq_len, dtype=torch.float32).unsqueeze_(-1)
        freq = torch.pow(theta, torch.arange(0, dim, 2)[: dim // 2].float().div_(dim))
        broadcast_freq = grid.mul(freq.reciprocal_().unsqueeze_(0))
        cache_cos = broadcast_freq.cos().view((-1, dim // 2))
        cache_sin = broadcast_freq.sin().view((-1, dim // 2))
        self.cache_dict["cos"] = cache_cos.to(self.device, self.dtype)
        self.cache_dict["sin"] = cache_sin.to(self.device, self.dtype)

    def init_kv(self, mixer, kv_size):
        cache_k = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        cache_v = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        self.cache_dict[f"{id(mixer)}_k"] = cache_k
        self.cache_dict[f"{id(mixer)}_v"] = cache_v

    def set_seq(self, start_pos=0, end_pos=None):
        self.start_pos = start_pos
        if "seq_lens" in self.cache_dict:
            self.cache_dict["seq_lens"].fill_(start_pos)
        if "cos" in self.cache_dict and end_pos is not None:
            self.cache_dict["seq_cos"] = self.cache_dict["cos"][self.start_pos : end_pos]
            self.cache_dict["seq_sin"] = self.cache_dict["sin"][self.start_pos : end_pos]

    def forward_rotary(self, q, k, inplace=False):
        cos = self.cache_dict.get("seq_cos", self.cache_dict.get("cos", None))
        sin = self.cache_dict.get("seq_sin", self.cache_dict.get("sin", None))
        if cos is None or sin is None:
            return q, k
        q = apply_rotary_emb(q, cos, sin, interleaved=True, inplace=inplace)
        k = apply_rotary_emb(k, cos, sin, interleaved=True, inplace=inplace)
        return q, k

    def forward_flash(self, mixer, q, k, v):
        cache_k = self.cache_dict.get(f"{id(mixer)}_k", None)
        cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
        flash_args = {"softmax_scale": mixer.scale, "causal": True}
        if cache_k is None or cache_v is None:
            return flash_attn_func(q, k, v, **flash_args)
        flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
        return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)


class Attention(nn.Module):
    """Self-Attention layer."""

    def __init__(self, dim, num_heads, bias=True):
        super(Attention, self).__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim**-0.5
        self.cache = nn.Module()

    def forward(self, x):
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        q, k, v = self.qkv(x).view(qkv_shape).unbind(dim=2)
        q, k = self.cache.forward_rotary(q, k, inplace=True)
        o = self.cache.forward_flash(self, q, k, v)
        return self.proj(o.flatten(2))


class MLP(nn.Module):
    """Two layers MLP."""

    def __init__(self, dim, mlp_dim, bias=True):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim, bias=bias)
        self.fc2 = nn.Linear(mlp_dim, dim, bias=bias)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_dim, bias=True):
        super(Block, self).__init__()
        self.attn = Attention(dim, num_heads, bias=bias)
        self.mlp = MLP(dim, mlp_dim, bias=bias)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.attn(self.norm1(x)).add_(x)
        return self.mlp(self.norm2(x)).add_(x)


class Transformer(nn.Module):
    """Causal transformer decoder."""

    def __init__(self, depth, dim, num_heads, mlp_dim, vocab_size):
        super(Transformer, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.vocab_size = vocab_size
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(Block(dim, num_heads, mlp_dim) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)
        self.text_proj = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, prompts, tokens, start_pos=0):
        prompt_len = prompts.size(1)
        start_pos = start_pos + (prompt_len if start_pos > 0 else 0)
        end_pos = start_pos + tokens.size(1) + (0 if start_pos > 0 else prompt_len)
        self.cache.set_seq(start_pos, end_pos)
        x = self.tok_embeddings(tokens)
        x = x if start_pos > 0 else torch.cat([prompts, x], dim=1)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0 if start_pos > 0 else prompt_len :])
        return self.text_proj(x).float()


class TextDecoder(nn.Module):
    """Module to decode texts."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio,
        prompt_embed_dim,
        max_seq_len,
        vocab_size,
    ):
        super(TextDecoder, self).__init__()
        self.max_seq_len = max_seq_len
        self.max_text_len = self.max_seq_len - 1
        self.encoder = nn.Linear(prompt_embed_dim, embed_dim, bias=False)
        self.transformer = Transformer(
            depth=depth,
            dim=embed_dim,
            mlp_dim=embed_dim * mlp_ratio,
            num_heads=num_heads,
            vocab_size=vocab_size,
        )

    def reset_cache(self, max_batch_size=1, max_seq_len=None):
        device, dtype = self.encoder.weight.device, self.encoder.weight.dtype
        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
        num_heads, head_dim = self.transformer.num_heads, self.transformer.head_dim
        self.transformer.cache = TransformerCache(device=device, dtype=dtype)
        self.transformer.cache.init_seq(max_batch_size)
        self.transformer.cache.init_rotary(max_seq_len, head_dim, theta=10000.0)
        kv_cache_size = (max_batch_size, max_seq_len, num_heads, head_dim)
        for blk in self.transformer.blocks:
            blk.attn.__dict__["cache"] = self.transformer.cache
            self.transformer.cache.init_kv(blk.attn, kv_cache_size) if not self.training else None

    def get_prompts(self, prompt_tokens):
        return self.encoder(prompt_tokens)

    def get_outputs(self, inputs, start_pos=0):
        return {"text_pred": self.transformer(inputs["prompts"], inputs["tokens"], start_pos)}

    def forward(self, inputs, start_pos=0):
        return self.get_outputs(inputs, start_pos)

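init_rotary above builds a standard rotary-position-embedding (RoPE) table that set_seq later slices per decoding step. A small self-contained check of that construction, using only torch; the function name and sizes are illustrative:

import torch

def rope_tables(seq_len, dim, theta=10000.0):
    # Mirrors TransformerCache.init_rotary: one angle per (position, frequency) pair,
    # cached once as cos/sin tables of shape (seq_len, dim // 2).
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1)            # (seq_len, 1)
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float() / dim)  # (dim // 2,)
    angles = pos * inv_freq
    return angles.cos(), angles.sin()

cos, sin = rope_tables(seq_len=40, dim=64)
print(cos.shape, sin.shape)  # torch.Size([40, 32]) torch.Size([40, 32])
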
tokenize_anything/modeling/text_tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723

tokenize_anything/modeling/text_tokenizer.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Source from: https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py

import os
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()


class TextTokenizer:
    """Tokenizing and encoding/decoding text using SentencePiece."""

    def __init__(self, model_path=None):
        """
        Initializes the Tokenizer with a SentencePiece model.

        Args:
            model_path (str): The path to the SentencePiece model file.
        """
        if model_path is None:
            model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "text_tokenizer.model"
            )
        # Reload tokenizer.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")
        # BOS / EOS token IDs.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        self.pad_id += self.n_words if self.pad_id < 0 else 0
        logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.

        Returns:
            List[int]: A list of token IDs.
        """
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.sp_model.decode(t)

    def tokenize(self, texts, context_length=None):
        """Encode a list of strings.

        Parameters
        ----------
        texts : Union[str, List[str]]
            The input text(s).
        context_length : int, optional
            The max token length.

        Returns
        -------
        List[List[int]]
            The encoded token indices.

        """
        if isinstance(texts, str):
            texts = [texts]
        tokens = [self.encode(text, bos=True, eos=True) for text in texts]
        if context_length is None:
            return tokens
        truncated_tokens = []
        for k, t in enumerate(tokens):
            if len(t) > context_length:
                t = t[:context_length]
                t[-1] = self.eos_id
            truncated_tokens.append(t)
        return truncated_tokens

    def detokenize(self, tokens):
        """Decode a list of token sequences into strings.

        Parameters
        ----------
        tokens : Union[List[List[int]], numpy.ndarray]
            The input tokens.

        Returns
        -------
        List[str]
            The decoded text strings.

        """
        if hasattr(tokens, "tolist"):
            tokens = tokens.tolist()
        texts = []
        for i in range(len(tokens)):
            t = tokens[i][1:]
            try:
                eot_idx = t.index(self.eos_id)
                t = t[:eot_idx]
            except ValueError:
                pass
            texts.append(self.decode(t))
        return texts

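A hedged round-trip example for the tokenizer, assuming the bundled text_tokenizer.model has been fetched via git-lfs; the sample captions and context length are arbitrary:

from tokenize_anything.modeling.text_tokenizer import TextTokenizer

tokenizer = TextTokenizer()  # loads the bundled text_tokenizer.model by default
ids = tokenizer.tokenize(["a sleeping cat", "a red bicycle"], context_length=40)
print(ids[0][0] == tokenizer.bos_id)   # True: sequences start with BOS
print(tokenizer.detokenize(ids))       # ['a sleeping cat', 'a red bicycle']
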
tokenize_anything/test_engine.py
ADDED
@@ -0,0 +1,81 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Engine for testing."""

import time

from tokenize_anything.build_model import model_registry


class InferenceCommand(object):
    """Command to run batched inference."""

    def __init__(self, input_queue, output_queue, kwargs):
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.kwargs = kwargs

    def build_env(self):
        """Build the environment."""
        self.batch_size = self.kwargs.get("batch_size", 1)
        self.batch_timeout = self.kwargs.get("batch_timeout", None)

    def build_model(self):
        """Build and return the model."""
        builder = model_registry[self.kwargs["model_type"]]
        return builder(device=self.kwargs["device"], checkpoint=self.kwargs["weights"])

    def build_predictor(self, model):
        """Build and return the predictor."""
        return self.kwargs["predictor_type"](model, self.kwargs)

    def send_results(self, predictor, indices, examples):
        """Send the inference results."""
        results = predictor.get_results(examples)
        if hasattr(predictor, "timers"):
            time_diffs = dict((k, v.average_time) for k, v in predictor.timers.items())
            for i, outputs in enumerate(results):
                self.output_queue.put((indices[i], time_diffs, outputs))
        else:
            for i, outputs in enumerate(results):
                self.output_queue.put((indices[i], outputs))

    def run(self):
        """Main loop to make the inference outputs."""
        self.build_env()
        model = self.build_model()
        predictor = self.build_predictor(model)
        must_stop = False
        while not must_stop:
            indices, examples = [], []
            deadline, timeout = None, None
            for i in range(self.batch_size):
                if self.batch_timeout and i == 1:
                    deadline = time.monotonic() + self.batch_timeout
                if self.batch_timeout and i >= 1:
                    timeout = deadline - time.monotonic()
                try:
                    index, example = self.input_queue.get(timeout=timeout)
                    if index < 0:
                        must_stop = True
                        break
                    indices.append(index)
                    examples.append(example)
                except Exception:
                    pass
            if len(examples) == 0:
                continue
            self.send_results(predictor, indices, examples)

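run() consumes (index, example) pairs from input_queue, batches them under an optional timeout, and treats a negative index as the stop signal. A sketch of the feeding side only; the example payloads are placeholders, and the real wiring with worker processes lives in app.py:

import multiprocessing as mp

input_queue, output_queue = mp.Queue(1024), mp.Queue(1024)

# Each work item is (index, example); the example payload is whatever the
# configured predictor's get_results() expects (placeholders here).
for idx, example in enumerate([{"img": "frame_0"}, {"img": "frame_1"}]):
    input_queue.put((idx, example))

# A negative index tells InferenceCommand.run() to stop after the current batch.
input_queue.put((-1, None))
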
tokenize_anything/utils/__init__.py
ADDED
@@ -0,0 +1,15 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------

tokenize_anything/utils/image.py
ADDED
@@ -0,0 +1,73 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image utilities."""

import numpy as np
import PIL.Image


def im_resize(img, size=None, scale=None, mode="linear"):
    """Resize image by the scale or size."""
    if size is None:
        if not isinstance(scale, (tuple, list)):
            scale = (scale, scale)
        h, w = img.shape[:2]
        size = int(h * scale[0] + 0.5), int(w * scale[1] + 0.5)
    else:
        if not isinstance(size, (tuple, list)):
            size = (size, size)
    resize_modes = {"linear": PIL.Image.BILINEAR}
    img = PIL.Image.fromarray(img)
    return np.array(img.resize(size[::-1], resize_modes[mode]))


def im_rescale(img, scales, max_size=0):
    """Rescale image to match the detecting scales."""
    im_shape = img.shape
    img_list, img_scales = [], []
    size_min = np.min(im_shape[:2])
    size_max = np.max(im_shape[:2])
    for target_size in scales:
        im_scale = float(target_size) / float(size_min)
        target_size_max = max_size if max_size > 0 else target_size
        if np.round(im_scale * size_max) > target_size_max:
            im_scale = float(target_size_max) / float(size_max)
        img_list.append(im_resize(img, scale=im_scale))
        img_scales.append((im_scale, im_scale))
    return img_list, img_scales


def im_vstack(arrays, fill_value=None, dtype=None, size=None, align=None):
    """Stack image arrays in sequence vertically."""
    if fill_value is None:
        return np.vstack(arrays)
    # Compute the max stack shape.
    max_shape = np.max(np.stack([arr.shape for arr in arrays]), 0)
    if size is not None and min(size) > 0:
        max_shape[: len(size)] = size
    if align is not None and min(align) > 0:
        align_size = np.ceil(max_shape[: len(align)] / align)
        max_shape[: len(align)] = align_size.astype("int64") * align
    # Fill output with the given value.
    output_dtype = dtype or arrays[0].dtype
    output_shape = [len(arrays)] + list(max_shape)
    output = np.empty(output_shape, output_dtype)
    output[:] = fill_value
    # Copy arrays.
    for i, arr in enumerate(arrays):
        copy_slices = (slice(0, d) for d in arr.shape)
        output[(i,) + tuple(copy_slices)] = arr
    return output

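A short usage sketch for the two helpers the app imports, im_rescale and im_vstack; the random image, the 1024 target size, and the zero fill value are illustrative (a real pipeline would typically fill with the image encoder's pixel mean):

import numpy as np
from tokenize_anything.utils.image import im_rescale, im_vstack

img = (np.random.rand(480, 640, 3) * 255).astype("uint8")
# Rescale the shorter side toward 1024 while capping the longer side at 1024,
# then pad the result into a fixed-size batch array.
img_list, img_scales = im_rescale(img, scales=[1024], max_size=1024)
batch = im_vstack(img_list, fill_value=0, size=(1024, 1024))
print(img_list[0].shape, img_scales[0], batch.shape)  # (768, 1024, 3) (1.6, 1.6) (1, 1024, 1024, 3)
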
tokenize_anything/utils/mask.py
ADDED
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Mask utilities."""

import numpy as np
from pycocotools.mask import encode


def mask_to_box(mask):
    """Convert binary masks to boxes."""
    shape, (h, w) = mask.shape, mask.shape[-2:]
    masks = mask.reshape((-1, h, w)).astype("bool")
    in_height = np.max(masks, axis=-1)
    in_width = np.max(masks, axis=-2)
    in_height_coords = in_height * np.arange(h, dtype="int32")
    in_width_coords = in_width * np.arange(w, dtype="int32")
    bottom_edges = np.max(in_height_coords, axis=-1)
    top_edges = np.min(in_height_coords + h * (~in_height), axis=-1)
    right_edges = np.max(in_width_coords, axis=-1)
    left_edges = np.min(in_width_coords + w * (~in_width), axis=-1)
    is_empty = (right_edges < left_edges) | (bottom_edges < top_edges)
    boxes = np.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
    boxes = boxes.astype("float32") * ((~is_empty)[:, None])
    return boxes.reshape(*shape[:-2], 4) if len(shape) > 2 else boxes[0]


def encode_masks(masks):
    """Encode a set of masks to RLEs."""
    rles = encode(np.asfortranarray(masks))
    for rle in rles:
        rle["counts"] = rle["counts"].decode()
    return rles

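A quick check of the mask utilities; the toy mask is arbitrary, and note that encode_masks expects pycocotools' HxWxN layout:

import numpy as np
from tokenize_anything.utils.mask import mask_to_box, encode_masks

mask = np.zeros((1, 64, 64), dtype="uint8")
mask[0, 10:20, 30:50] = 1
print(mask_to_box(mask))                       # [[30. 10. 49. 19.]]
rles = encode_masks(mask.transpose(1, 2, 0))   # pycocotools wants HxWxN, Fortran order
print(rles[0]["size"], rles[0]["counts"][:12])
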
tokenize_anything/utils/timer.py
ADDED
@@ -0,0 +1,51 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Timing functions."""

import contextlib
import time


class Timer(object):
    """Simple timer."""

    def __init__(self):
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0
        self.average_time = 0.0

    def add_diff(self, diff, n=1, average=True):
        self.total_time += diff
        self.calls += n
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff

    @contextlib.contextmanager
    def tic_and_toc(self, n=1, average=True):
        try:
            yield self.tic()
        finally:
            self.toc(n, average)

    def tic(self):
        self.start_time = time.time()
        return self

    def toc(self, n=1, average=True):
        self.diff = time.time() - self.start_time
        return self.add_diff(self.diff, n, average)

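A tiny usage sketch for Timer; the sleep stands in for an inference step:

import time
from tokenize_anything.utils.timer import Timer

timer = Timer()
for _ in range(3):
    with timer.tic_and_toc():   # times each pass and updates the running average
        time.sleep(0.01)
print(timer.calls, round(timer.average_time, 3))
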
tokenize_anything/version.py
ADDED
@@ -0,0 +1,3 @@
version = "0.1.0a0"
git_version = "None"
__version__ = version