BAAI
/

ryanzhangfan commited on
Commit
c059b33
·
verified ·
1 Parent(s): dabb3aa

add support for batch multimodal understanding

Browse files
Files changed (1) hide show
  1. processing_emu3.py +48 -5
processing_emu3.py CHANGED
@@ -14,12 +14,14 @@
14
  # limitations under the License.
15
  """ Processor class for Emu3. """
16
 
 
17
  import re
18
  from typing import List, Optional, Sequence, Union
19
  from functools import partial
20
 
21
  from PIL import Image
22
  import torch
 
23
  from transformers.feature_extraction_utils import BatchFeature
24
  from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
25
  from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
@@ -73,6 +75,7 @@ class Emu3Processor(ProcessorMixin):
73
  self.vision_tokenizer = vision_tokenizer
74
  self.prefix_template = prefix_template
75
  self.visual_template = visual_template
 
76
 
77
  super().__init__(image_processor, tokenizer, chat_template=chat_template)
78
  self.const_helper = self.build_const_helper()
@@ -86,6 +89,7 @@ class Emu3Processor(ProcessorMixin):
86
  mode: str = "G",
87
  ratio: str | List[str] = "1:1",
88
  image_area: int = 518400,
 
89
  **kwargs,
90
  ) -> BatchFeature:
91
  """
@@ -106,6 +110,8 @@ class Emu3Processor(ProcessorMixin):
106
  the image width-height ratio for generation
107
  image_area (`int`, *optional*):
108
  image area used to calcualte the generated image height and width
 
 
109
  return_tensors (`str` or [`~utils.TensorType`], *optional*):
110
  If set, will return tensors of a particular framework. Acceptable values are:
111
  - `'pt'`: Return PyTorch `torch.Tensor` objects.
@@ -121,10 +127,13 @@ class Emu3Processor(ProcessorMixin):
121
  if isinstance(text, str):
122
  text = [text]
123
 
 
 
 
124
  if not isinstance(text[0], str):
125
  raise ValueError("`text` must be string or list of string")
126
 
127
- image_inputs = None
128
  if mode == 'G':
129
  if image is not None:
130
  raise ValueError("You have to specify only `text` in generation mode")
@@ -144,10 +153,7 @@ class Emu3Processor(ProcessorMixin):
144
  if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
145
  raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
146
 
147
- image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
148
- image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
149
- image_tokens = self.vision_tokenizer.encode(image_inputs)
150
-
151
  if len(text) != len(image_tokens):
152
  raise ValueError("number of image must match number of text prompt")
153
 
@@ -254,6 +260,43 @@ class Emu3Processor(ProcessorMixin):
254
  tw = int(round(w * target_ratio / spatial_scale_factor))
255
  return th, tw
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def build_const_helper(self):
258
  (
259
  img_token,
 
14
  # limitations under the License.
15
  """ Processor class for Emu3. """
16
 
17
+ from math import ceil
18
  import re
19
  from typing import List, Optional, Sequence, Union
20
  from functools import partial
21
 
22
  from PIL import Image
23
  import torch
24
+ from torch.nn import functional as F
25
  from transformers.feature_extraction_utils import BatchFeature
26
  from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
27
  from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
 
75
  self.vision_tokenizer = vision_tokenizer
76
  self.prefix_template = prefix_template
77
  self.visual_template = visual_template
78
+ self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1)
79
 
80
  super().__init__(image_processor, tokenizer, chat_template=chat_template)
81
  self.const_helper = self.build_const_helper()
 
89
  mode: str = "G",
90
  ratio: str | List[str] = "1:1",
91
  image_area: int = 518400,
92
+ padding_image: bool = False,
93
  **kwargs,
94
  ) -> BatchFeature:
95
  """
 
110
  the image width-height ratio for generation
111
  image_area (`int`, *optional*):
112
  image area used to calcualte the generated image height and width
113
+ padding_image (`bool`, *optional*):
114
+ whether pad images to same size for fast preprocessing if they have different sizes
115
  return_tensors (`str` or [`~utils.TensorType`], *optional*):
116
  If set, will return tensors of a particular framework. Acceptable values are:
117
  - `'pt'`: Return PyTorch `torch.Tensor` objects.
 
127
  if isinstance(text, str):
128
  text = [text]
129
 
130
+ if isinstance(image, Image.Image):
131
+ image = [image]
132
+
133
  if not isinstance(text[0], str):
134
  raise ValueError("`text` must be string or list of string")
135
 
136
+ image_tokens = None
137
  if mode == 'G':
138
  if image is not None:
139
  raise ValueError("You have to specify only `text` in generation mode")
 
153
  if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
154
  raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
155
 
156
+ image_tokens = self.tokenize_image(image, padding_image=padding_image)
 
 
 
157
  if len(text) != len(image_tokens):
158
  raise ValueError("number of image must match number of text prompt")
159
 
 
260
  tw = int(round(w * target_ratio / spatial_scale_factor))
261
  return th, tw
262
 
263
+ def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False):
264
+ is_all_same_size, prev_size = True, None
265
+ for im in image:
266
+ if prev_size is not None:
267
+ is_all_same_size &= (prev_size == im.size)
268
+ prev_size = im.size
269
+
270
+ if is_all_same_size:
271
+ image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
272
+ image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
273
+ image_tokens = self.vision_tokenizer.encode(image_inputs)
274
+ elif padding_image:
275
+ image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image]
276
+ image_shapes = [im.shape[2:] for im in image_inputs]
277
+ max_shape = (
278
+ max([im_shape[0] for im_shape in image_shapes]),
279
+ max([im_shape[1] for im_shape in image_shapes]),
280
+ )
281
+ image_inputs = [
282
+ F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0]))
283
+ for im_inp, im_shape in zip(image_inputs, image_shapes)
284
+ ]
285
+ image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
286
+ image_tokens = self.vision_tokenizer.encode(image_inputs)
287
+ image_tokens = [
288
+ im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)]
289
+ for im_tok, im_shape in zip(image_tokens, image_shapes)
290
+ ]
291
+ else:
292
+ image_tokens = []
293
+ for im in image:
294
+ image_input = self.image_processor(im, return_tensors="pt")["pixel_values"]
295
+ image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
296
+ image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0))
297
+
298
+ return image_tokens
299
+
300
  def build_const_helper(self):
301
  (
302
  img_token,