ryanzhangfan
commited on
add support for batch multimodal understanding
Browse files- processing_emu3.py +48 -5
processing_emu3.py
CHANGED
@@ -14,12 +14,14 @@
|
|
14 |
# limitations under the License.
|
15 |
""" Processor class for Emu3. """
|
16 |
|
|
|
17 |
import re
|
18 |
from typing import List, Optional, Sequence, Union
|
19 |
from functools import partial
|
20 |
|
21 |
from PIL import Image
|
22 |
import torch
|
|
|
23 |
from transformers.feature_extraction_utils import BatchFeature
|
24 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
25 |
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
@@ -73,6 +75,7 @@ class Emu3Processor(ProcessorMixin):
|
|
73 |
self.vision_tokenizer = vision_tokenizer
|
74 |
self.prefix_template = prefix_template
|
75 |
self.visual_template = visual_template
|
|
|
76 |
|
77 |
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
78 |
self.const_helper = self.build_const_helper()
|
@@ -86,6 +89,7 @@ class Emu3Processor(ProcessorMixin):
|
|
86 |
mode: str = "G",
|
87 |
ratio: str | List[str] = "1:1",
|
88 |
image_area: int = 518400,
|
|
|
89 |
**kwargs,
|
90 |
) -> BatchFeature:
|
91 |
"""
|
@@ -106,6 +110,8 @@ class Emu3Processor(ProcessorMixin):
|
|
106 |
the image width-height ratio for generation
|
107 |
image_area (`int`, *optional*):
|
108 |
image area used to calcualte the generated image height and width
|
|
|
|
|
109 |
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
110 |
If set, will return tensors of a particular framework. Acceptable values are:
|
111 |
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
@@ -121,10 +127,13 @@ class Emu3Processor(ProcessorMixin):
|
|
121 |
if isinstance(text, str):
|
122 |
text = [text]
|
123 |
|
|
|
|
|
|
|
124 |
if not isinstance(text[0], str):
|
125 |
raise ValueError("`text` must be string or list of string")
|
126 |
|
127 |
-
|
128 |
if mode == 'G':
|
129 |
if image is not None:
|
130 |
raise ValueError("You have to specify only `text` in generation mode")
|
@@ -144,10 +153,7 @@ class Emu3Processor(ProcessorMixin):
|
|
144 |
if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
|
145 |
raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
|
146 |
|
147 |
-
|
148 |
-
image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
|
149 |
-
image_tokens = self.vision_tokenizer.encode(image_inputs)
|
150 |
-
|
151 |
if len(text) != len(image_tokens):
|
152 |
raise ValueError("number of image must match number of text prompt")
|
153 |
|
@@ -254,6 +260,43 @@ class Emu3Processor(ProcessorMixin):
|
|
254 |
tw = int(round(w * target_ratio / spatial_scale_factor))
|
255 |
return th, tw
|
256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
def build_const_helper(self):
|
258 |
(
|
259 |
img_token,
|
|
|
14 |
# limitations under the License.
|
15 |
""" Processor class for Emu3. """
|
16 |
|
17 |
+
from math import ceil
|
18 |
import re
|
19 |
from typing import List, Optional, Sequence, Union
|
20 |
from functools import partial
|
21 |
|
22 |
from PIL import Image
|
23 |
import torch
|
24 |
+
from torch.nn import functional as F
|
25 |
from transformers.feature_extraction_utils import BatchFeature
|
26 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
27 |
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
|
|
75 |
self.vision_tokenizer = vision_tokenizer
|
76 |
self.prefix_template = prefix_template
|
77 |
self.visual_template = visual_template
|
78 |
+
self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1)
|
79 |
|
80 |
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
81 |
self.const_helper = self.build_const_helper()
|
|
|
89 |
mode: str = "G",
|
90 |
ratio: str | List[str] = "1:1",
|
91 |
image_area: int = 518400,
|
92 |
+
padding_image: bool = False,
|
93 |
**kwargs,
|
94 |
) -> BatchFeature:
|
95 |
"""
|
|
|
110 |
the image width-height ratio for generation
|
111 |
image_area (`int`, *optional*):
|
112 |
image area used to calcualte the generated image height and width
|
113 |
+
padding_image (`bool`, *optional*):
|
114 |
+
whether pad images to same size for fast preprocessing if they have different sizes
|
115 |
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
116 |
If set, will return tensors of a particular framework. Acceptable values are:
|
117 |
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
|
|
127 |
if isinstance(text, str):
|
128 |
text = [text]
|
129 |
|
130 |
+
if isinstance(image, Image.Image):
|
131 |
+
image = [image]
|
132 |
+
|
133 |
if not isinstance(text[0], str):
|
134 |
raise ValueError("`text` must be string or list of string")
|
135 |
|
136 |
+
image_tokens = None
|
137 |
if mode == 'G':
|
138 |
if image is not None:
|
139 |
raise ValueError("You have to specify only `text` in generation mode")
|
|
|
153 |
if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
|
154 |
raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
|
155 |
|
156 |
+
image_tokens = self.tokenize_image(image, padding_image=padding_image)
|
|
|
|
|
|
|
157 |
if len(text) != len(image_tokens):
|
158 |
raise ValueError("number of image must match number of text prompt")
|
159 |
|
|
|
260 |
tw = int(round(w * target_ratio / spatial_scale_factor))
|
261 |
return th, tw
|
262 |
|
263 |
+
def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False):
|
264 |
+
is_all_same_size, prev_size = True, None
|
265 |
+
for im in image:
|
266 |
+
if prev_size is not None:
|
267 |
+
is_all_same_size &= (prev_size == im.size)
|
268 |
+
prev_size = im.size
|
269 |
+
|
270 |
+
if is_all_same_size:
|
271 |
+
image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
|
272 |
+
image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
|
273 |
+
image_tokens = self.vision_tokenizer.encode(image_inputs)
|
274 |
+
elif padding_image:
|
275 |
+
image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image]
|
276 |
+
image_shapes = [im.shape[2:] for im in image_inputs]
|
277 |
+
max_shape = (
|
278 |
+
max([im_shape[0] for im_shape in image_shapes]),
|
279 |
+
max([im_shape[1] for im_shape in image_shapes]),
|
280 |
+
)
|
281 |
+
image_inputs = [
|
282 |
+
F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0]))
|
283 |
+
for im_inp, im_shape in zip(image_inputs, image_shapes)
|
284 |
+
]
|
285 |
+
image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
|
286 |
+
image_tokens = self.vision_tokenizer.encode(image_inputs)
|
287 |
+
image_tokens = [
|
288 |
+
im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)]
|
289 |
+
for im_tok, im_shape in zip(image_tokens, image_shapes)
|
290 |
+
]
|
291 |
+
else:
|
292 |
+
image_tokens = []
|
293 |
+
for im in image:
|
294 |
+
image_input = self.image_processor(im, return_tensors="pt")["pixel_values"]
|
295 |
+
image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
|
296 |
+
image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0))
|
297 |
+
|
298 |
+
return image_tokens
|
299 |
+
|
300 |
def build_const_helper(self):
|
301 |
(
|
302 |
img_token,
|