HarborYuan committed · d45e0b2 · Parent(s): 1ea9622

add video

Files changed:
- modeling_sa2va_chat.py +56 -23
- sam2.py +2 -2
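Summary of the change: predict_forward gains a video= argument that accepts a list of same-sized frames; every frame is preprocessed for the SAM-2 grounding encoder, while only the first five frames are encoded for the LLM (num_frames = 5) and the image-token string is repeated once per frame. Mask decoding now calls language_embd_inference with the seg-token hidden state replicated per frame, get_seg_hidden_states gains a guard for empty outputs, and CXBlock in sam2.py reads its layer-scale parameter from g_weight.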
modeling_sa2va_chat.py
CHANGED
@@ -485,6 +485,7 @@ class Sa2VAChatModel(PreTrainedModel):
                 objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                 vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                 i_vp_img += 1
+
             vp_embeds = torch.cat(vp_embeds, dim=0)
         else:
             vp_embeds = None
@@ -583,6 +584,7 @@ class Sa2VAChatModel(PreTrainedModel):
     def predict_forward(
             self,
             image=None,
+            video=None,
             text=None,
             past_text='',
             mask_prompts=None,
@@ -593,29 +595,57 @@ class Sa2VAChatModel(PreTrainedModel):
         self.preparing_for_generation(tokenizer=tokenizer)
 
         input_dict = {}
+        if video is not None:
+            pixel_values = []
+            extra_pixel_values = []
+            ori_image_size = video[0].size
+            for frame_idx, frame_image in enumerate(video):
+                assert ori_image_size == frame_image.size
+                g_image = np.array(frame_image)  # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                extra_pixel_values.append(g_image)
+                if frame_idx < 5:
+                    img = self.transformer(frame_image)
+                    pixel_values.append(img)
+
+            pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)
+            num_image_tokens = self.patch_token
+            num_frames = 5
 
-        ori_image_size = image.size
+            input_dict['vp_overall_mask'] = None
+        else:
+            ori_image_size = image.size
 
-        # prepare grounding images
-        g_image = np.array(image)  # for grounding
-        g_image = self.extra_image_processor.apply_image(g_image)
-        g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
-        input_dict['g_pixel_values'] = g_pixel_values
+            # prepare grounding images
+            g_image = np.array(image)  # for grounding
+            g_image = self.extra_image_processor.apply_image(g_image)
+            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+            extra_pixel_values = [g_pixel_values]
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)
 
-        images = dynamic_preprocess(image, self.min_dynamic_patch,
-                                    self.max_dynamic_patch,
-                                    self.image_size, self.use_thumbnail)
+            images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                        self.max_dynamic_patch,
+                                        self.image_size, self.use_thumbnail)
 
-        if mask_prompts is not None:
-            vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
-            input_dict['vp_overall_mask'] = vp_overall_mask
-        else:
-            input_dict['vp_overall_mask'] = None
+            if mask_prompts is not None:
+                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                input_dict['vp_overall_mask'] = vp_overall_mask
+            else:
+                input_dict['vp_overall_mask'] = None
 
-        pixel_values = [self.transformer(image) for image in images]
-        pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+            pixel_values = [self.transformer(image) for image in images]
+            pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+            num_image_tokens = pixel_values.shape[0] * self.patch_token
+            num_frames = 1
+        input_dict['g_pixel_values'] = g_pixel_values
         input_dict['pixel_values'] = pixel_values
-        num_image_tokens = pixel_values.shape[0] * self.patch_token
+
 
         if mask_prompts is not None:
             # reshape mask prompts to feature size
@@ -627,7 +657,7 @@ class Sa2VAChatModel(PreTrainedModel):
                 mode='nearest').squeeze(0) for item in mask_prompts]
             region_pixels = []
             for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.to(torch.int64).sum())
+                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())
 
             vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
             for i in range(len(mask_prompts[0])):
@@ -645,6 +675,9 @@ class Sa2VAChatModel(PreTrainedModel):
         image_token_str = f'{self.IMG_START_TOKEN}' \
                           f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                           f'{self.IMG_END_TOKEN}'
+        image_token_str = image_token_str + '\n'
+        image_token_str = image_token_str * num_frames
+        image_token_str = image_token_str.strip()
 
         ret_masks = []
 
@@ -695,16 +728,14 @@ class Sa2VAChatModel(PreTrainedModel):
 
         for seg_hidden_states in all_seg_hidden_states:
             seg_hidden_states = seg_hidden_states.unsqueeze(0)
-            g_pixel_values = torch.stack([
-                self.grounding_encoder.preprocess_image(pixel, dtype=self.torch_dtype)
-                for pixel in [input_dict['g_pixel_values']]])
+            g_pixel_values = input_dict['g_pixel_values']
             sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values)
-            pred_masks = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states])
+            pred_masks = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * num_frames)
             w, h = ori_image_size
             masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
             masks = masks[:, 0]
             masks = masks.sigmoid() > 0.5
-            masks = masks.
+            masks = masks.cpu().numpy()
             ret_masks.append(masks)
 
         return {'prediction': predict, 'prediction_masks': ret_masks,}
@@ -712,6 +743,8 @@ class Sa2VAChatModel(PreTrainedModel):
 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
     n_out = len(seg_mask)
+    if n_out == 0:
+        return hidden_states[0:0]
     return hidden_states[-n_out:][seg_mask]
 
 def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
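For orientation, here is a minimal usage sketch of the new video path (not part of the commit). The checkpoint id and frame paths are placeholders, and the loading boilerplate follows the Sa2VA model card; only predict_forward(video=..., text=..., tokenizer=...) and the returned prediction / prediction_masks keys come from the diff above.

# Hypothetical usage sketch; "ByteDance/Sa2VA-4B" stands in for any Sa2VA
# checkpoint that ships this modeling file, and the frame paths are made up.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# predict_forward asserts that all frames share one size; every frame goes to
# the SAM-2 grounding encoder, but only the first five feed the LLM.
frames = [Image.open(f"frames/{i:05d}.jpg").convert("RGB") for i in range(8)]

result = model.predict_forward(
    video=frames,
    text="<image>Please segment the person in the video.",
    tokenizer=tokenizer,
)
print(result["prediction"])              # generated answer text
for mask in result["prediction_masks"]:  # one boolean array per seg token,
    print(mask.shape)                    # stacked over the video frames

Note that num_frames is hard-coded to 5 on this path, so the LLM sees at most the first five frames even when more are passed, while masks are decoded for every frame.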
sam2.py
CHANGED
@@ -623,8 +623,8 @@ class CXBlock(nn.Module):
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.pwconv2(x)
-        if self.gamma is not None:
-            x = self.gamma * x
+        if self.g_weight is not None:
+            x = self.g_weight * x
         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
 
         x = input + self.drop_path(x)
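The sam2.py hunk renames CXBlock's layer-scale parameter (called gamma in the upstream SAM-2 ConvNeXt block) to g_weight, presumably because Transformers rewrites gamma/beta substrings in state-dict keys when loading. Below is a self-contained sketch of the guarded scale step; dim and init_value are assumed hyperparameters, not values taken from this repo.

# Standalone sketch of the channels-last layer-scale step guarded above.
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    def __init__(self, dim: int, init_value: float = 1e-6):
        super().__init__()
        # Learnable per-channel scale; disabled when init_value <= 0, which is
        # why the forward pass checks `if self.g_weight is not None`.
        self.g_weight = (
            nn.Parameter(init_value * torch.ones(dim)) if init_value > 0 else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is (N, H, W, C); the (C,)-shaped scale broadcasts over channels.
        if self.g_weight is not None:
            x = self.g_weight * x
        return x

# usage: LayerScale(256)(torch.randn(2, 8, 8, 256)) scales each channel.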