Xiangtai commited on
Commit
e7e9303
·
verified ·
1 Parent(s): 30715cb

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_sa2va_chat.py +4 -3
modeling_sa2va_chat.py CHANGED
@@ -689,6 +689,7 @@ class Sa2VAChatModel(PreTrainedModel):
689
  input=text, round=1, bot_name=self.bot_name)
690
  input_text = past_text + input_text
691
  ids = self.tokenizer.encode(input_text)
 
692
  ids = torch.tensor(ids).cuda().unsqueeze(0)
693
 
694
  attention_mask = torch.ones_like(ids, dtype=torch.bool)
@@ -715,7 +716,8 @@ class Sa2VAChatModel(PreTrainedModel):
715
  )
716
  predict = self.tokenizer.decode(
717
  generate_output.sequences[0], skip_special_tokens=False).strip()
718
-
 
719
  # if have seg result, find the seg hidden states
720
  hidden_states = generate_output.hidden_states
721
  last_hidden_states = [item[-1][0] for item in hidden_states]
@@ -737,8 +739,7 @@ class Sa2VAChatModel(PreTrainedModel):
737
  masks = masks.sigmoid() > 0.5
738
  masks = masks.cpu().numpy()
739
  ret_masks.append(masks)
740
-
741
- return {'prediction': predict, 'prediction_masks': ret_masks,}
742
 
743
  def get_seg_hidden_states(hidden_states, output_ids, seg_id):
744
  seg_mask = output_ids == seg_id
 
689
  input=text, round=1, bot_name=self.bot_name)
690
  input_text = past_text + input_text
691
  ids = self.tokenizer.encode(input_text)
692
+ ret_past_text = self.tokenizer.decode(ids)
693
  ids = torch.tensor(ids).cuda().unsqueeze(0)
694
 
695
  attention_mask = torch.ones_like(ids, dtype=torch.bool)
 
716
  )
717
  predict = self.tokenizer.decode(
718
  generate_output.sequences[0], skip_special_tokens=False).strip()
719
+ ret_past_text = ret_past_text + self.tokenizer.decode(
720
+ generate_output.sequences[0], skip_special_tokens=False)
721
  # if have seg result, find the seg hidden states
722
  hidden_states = generate_output.hidden_states
723
  last_hidden_states = [item[-1][0] for item in hidden_states]
 
739
  masks = masks.sigmoid() > 0.5
740
  masks = masks.cpu().numpy()
741
  ret_masks.append(masks)
742
+ return {'prediction': predict, 'prediction_masks': ret_masks, "past_text": ret_past_text}
 
743
 
744
  def get_seg_hidden_states(hidden_states, output_ids, seg_id):
745
  seg_mask = output_ids == seg_id