HarborYuan committed
Commit d45e0b2 · 1 Parent(s): 1ea9622
Files changed (2):
  1. modeling_sa2va_chat.py +56 -23
  2. sam2.py +2 -2
modeling_sa2va_chat.py CHANGED
@@ -485,6 +485,7 @@ class Sa2VAChatModel(PreTrainedModel):
                    objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                    vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                    i_vp_img += 1
+
            vp_embeds = torch.cat(vp_embeds, dim=0)
        else:
            vp_embeds = None
@@ -583,6 +584,7 @@
    def predict_forward(
            self,
            image=None,
+            video=None,
            text=None,
            past_text='',
            mask_prompts=None,
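Note: with the new `video` argument, `predict_forward` accepts either a single image or a list of frames. A minimal usage sketch, assuming a Sa2VA checkpoint that ships this modeling file; the checkpoint name and the `<image>` prompt style follow the Sa2VA README and are assumptions here, not part of this commit:

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = "ByteDance/Sa2VA-4B"  # assumption: any Sa2VA checkpoint using this remote code
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Single-image call (existing path).
img = Image.open("example.jpg").convert("RGB")
out = model.predict_forward(image=img, text="<image>Please segment the person.",
                            tokenizer=tokenizer)

# Video call (new path): a list of equally sized PIL frames.
frames = [Image.open(f"frame_{i:04d}.jpg").convert("RGB") for i in range(8)]
out = model.predict_forward(video=frames, text="<image>Please segment the person.",
                            tokenizer=tokenizer)
print(out["prediction"])          # generated text
masks = out["prediction_masks"]   # one mask array per [SEG] token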
@@ -593,29 +595,57 @@
            self.preparing_for_generation(tokenizer=tokenizer)

        input_dict = {}
+        if video is not None:
+            pixel_values = []
+            extra_pixel_values = []
+            ori_image_size = video[0].size
+            for frame_idx, frame_image in enumerate(video):
+                assert ori_image_size == frame_image.size
+                g_image = np.array(frame_image)  # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                extra_pixel_values.append(g_image)
+                if frame_idx < 5:
+                    img = self.transformer(frame_image)
+                    pixel_values.append(img)
+
+            pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)
+            num_image_tokens = self.patch_token
+            num_frames = 5

-        ori_image_size = image.size
+            input_dict['vp_overall_mask'] = None
+        else:
+            ori_image_size = image.size

-        # prepare grounding images
-        g_image = np.array(image)  # for grounding
-        g_image = self.extra_image_processor.apply_image(g_image)
-        g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
-        input_dict['g_pixel_values'] = g_pixel_values
+            # prepare grounding images
+            g_image = np.array(image)  # for grounding
+            g_image = self.extra_image_processor.apply_image(g_image)
+            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+            extra_pixel_values = [g_pixel_values]
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)

-        images = dynamic_preprocess(image, self.min_dynamic_patch,
-                                    self.max_dynamic_patch,
-                                    self.image_size, self.use_thumbnail)
+            images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                        self.max_dynamic_patch,
+                                        self.image_size, self.use_thumbnail)

-        if mask_prompts is not None:
-            vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
-            input_dict['vp_overall_mask'] = vp_overall_mask
-        else:
-            input_dict['vp_overall_mask'] = None
+            if mask_prompts is not None:
+                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                input_dict['vp_overall_mask'] = vp_overall_mask
+            else:
+                input_dict['vp_overall_mask'] = None

-        pixel_values = [self.transformer(image) for image in images]
-        pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+            pixel_values = [self.transformer(image) for image in images]
+            pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+            num_image_tokens = pixel_values.shape[0] * self.patch_token
+            num_frames = 1
+        input_dict['g_pixel_values'] = g_pixel_values
        input_dict['pixel_values'] = pixel_values
-        num_image_tokens = pixel_values.shape[0] * self.patch_token
+

        if mask_prompts is not None:
            # reshape mask prompts to feature size
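Note: the video branch above routes every frame through the SAM-2 grounding preprocessor but caps the LLM vision path at the first 5 frames. A standalone sketch of that routing with dummy tensors (`split_video_frames` is an illustrative helper, not code from this file):

import torch

def split_video_frames(frames, llm_frame_cap=5):
    """Mimics the new video branch: all frames feed the grounding encoder,
    only the first `llm_frame_cap` frames feed the LLM vision tower."""
    llm_inputs, grounding_inputs = [], []
    for frame_idx, frame in enumerate(frames):
        grounding_inputs.append(frame)      # every frame -> SAM-2 path
        if frame_idx < llm_frame_cap:
            llm_inputs.append(frame)        # first frames only -> LLM path
    return torch.stack(llm_inputs), torch.stack(grounding_inputs)

frames = [torch.rand(3, 224, 224) for _ in range(12)]   # dummy (C, H, W) frames
pixel_values, g_pixel_values = split_video_frames(frames)
print(pixel_values.shape)    # torch.Size([5, 3, 224, 224])
print(g_pixel_values.shape)  # torch.Size([12, 3, 224, 224])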
@@ -627,7 +657,7 @@
                                          mode='nearest').squeeze(0) for item in mask_prompts]
            region_pixels = []
            for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.to(torch.int64).sum())
+                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())

            vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
            for i in range(len(mask_prompts[0])):
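Note: the added `.bool()` makes the region-pixel count robust when the prompt mask is not strictly binary (for example 0/255 values); without it the sum counts raw values rather than foreground pixels. A small illustration:

import torch

mask = torch.tensor([[0, 255], [255, 0]])      # a 0/255 prompt mask
print(mask.to(torch.int64).sum())              # tensor(510): sums raw values
print(mask.bool().to(torch.int64).sum())       # tensor(2): counts foreground pixels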
@@ -645,6 +675,9 @@
        image_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'
+        image_token_str = image_token_str + '\n'
+        image_token_str = image_token_str * num_frames
+        image_token_str = image_token_str.strip()

        ret_masks = []
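Note: the three added lines repeat the per-frame image-token block `num_frames` times, newline separated, then strip the trailing newline. A quick illustration with placeholder tokens (the real values come from self.IMG_START_TOKEN etc.):

IMG_START, IMG_CONTEXT, IMG_END = '<img>', '<IMG_CONTEXT>', '</img>'
num_image_tokens, num_frames = 2, 3

image_token_str = f'{IMG_START}{IMG_CONTEXT * num_image_tokens}{IMG_END}'
image_token_str = (image_token_str + '\n') * num_frames
image_token_str = image_token_str.strip()
print(image_token_str)
# <img><IMG_CONTEXT><IMG_CONTEXT></img>
# <img><IMG_CONTEXT><IMG_CONTEXT></img>
# <img><IMG_CONTEXT><IMG_CONTEXT></img>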
 
@@ -695,16 +728,14 @@

        for seg_hidden_states in all_seg_hidden_states:
            seg_hidden_states = seg_hidden_states.unsqueeze(0)
-            g_pixel_values = torch.stack([
-                self.grounding_encoder.preprocess_image(pixel, dtype=self.torch_dtype)
-                for pixel in [input_dict['g_pixel_values']]])
+            g_pixel_values = input_dict['g_pixel_values']
            sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values)
-            pred_masks = self.grounding_encoder.inject_language_embd(sam_states, [seg_hidden_states])
+            pred_masks = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * num_frames)
            w, h = ori_image_size
            masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
            masks = masks[:, 0]
            masks = masks.sigmoid() > 0.5
-            masks = masks.int().cpu()
+            masks = masks.cpu().numpy()
            ret_masks.append(masks)

        return {'prediction': predict, 'prediction_masks': ret_masks,}
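Note: after this change each per-[SEG] prediction is resized to the original image size, thresholded at 0.5, and returned as a NumPy boolean array instead of an int tensor. A minimal post-processing sketch on a dummy logit map (shapes are illustrative):

import torch
import torch.nn.functional as F

pred_masks = torch.randn(1, 1, 256, 256)   # dummy decoder logits: (n_frames, n_obj, H, W)
w, h = 640, 480                            # original image size (PIL convention: w, h)
masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
masks = masks[:, 0]                        # first object channel
masks = masks.sigmoid() > 0.5              # boolean tensor
masks = masks.cpu().numpy()                # NumPy bool array, shape (n_frames, h, w)
print(masks.shape, masks.dtype)            # (1, 480, 640) bool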
@@ -712,6 +743,8 @@
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
+    if n_out == 0:
+        return hidden_states[0:0]
    return hidden_states[-n_out:][seg_mask]

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
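Note: the added guard protects against an empty generation. With `n_out == 0`, slicing `hidden_states[-0:]` returns the whole tensor, and indexing it with a zero-length boolean mask then raises a shape-mismatch error; the guard returns an empty slice instead. A tiny reproduction (the seg token id is illustrative):

import torch

def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    if n_out == 0:
        return hidden_states[0:0]          # empty slice instead of an IndexError
    return hidden_states[-n_out:][seg_mask]

hidden_states = torch.randn(10, 4)
empty_ids = torch.empty(0, dtype=torch.long)
print(get_seg_hidden_states(hidden_states, empty_ids, seg_id=7).shape)  # torch.Size([0, 4])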
 
sam2.py CHANGED
@@ -623,8 +623,8 @@ class CXBlock(nn.Module):
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
-        if self.gamma is not None:
-            x = self.gamma * x
+        if self.g_weight is not None:
+            x = self.g_weight * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
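Note: the only change in sam2.py is a rename of CXBlock's layer-scale parameter from `gamma` to `g_weight`; the multiply itself is unchanged. Presumably this avoids Transformers' legacy checkpoint-loading rule that rewrites parameter keys containing "gamma" to "weight". A generic sketch of the layer-scale pattern with an illustrative module (not the repo's CXBlock):

import torch
import torch.nn as nn

class LayerScaleDemo(nn.Module):
    """Illustrative ConvNeXt-style layer scale, analogous to CXBlock's g_weight."""
    def __init__(self, dim, init_value=1e-6):
        super().__init__()
        # Named g_weight rather than gamma: HF Transformers rewrites "gamma" keys
        # to "weight" when loading state dicts, which can break such parameters.
        self.g_weight = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):              # x: (N, H, W, C)
        if self.g_weight is not None:
            x = self.g_weight * x      # per-channel scaling
        return x

x = torch.randn(2, 8, 8, 64)
print(LayerScaleDemo(64)(x).shape)     # torch.Size([2, 8, 8, 64])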
 