from enum import Enum
from typing import List, Optional

import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedModel

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_UNK_TOKEN = '<unk>'
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_BBOX_TOKEN = "<bbox>"
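
# Illustrative sketch (assumption, not original code) of how the constants
# above are typically used: the literal "<image>" placeholder in a prompt is
# swapped for the out-of-vocabulary id IMAGE_TOKEN_INDEX, so that
# prepare_inputs_labels_for_multimodal (below) can find the splice points.
def _example_tokenize_with_image_token(prompt, tokenizer):
    # Tokenize the text around each "<image>" placeholder separately, then
    # rejoin the pieces with IMAGE_TOKEN_INDEX in between.
    chunks = [tokenizer(chunk, add_special_tokens=False).input_ids
              for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            input_ids.append(IMAGE_TOKEN_INDEX)
        input_ids.extend(chunk)
    return torch.tensor(input_ids, dtype=torch.long)
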
# Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99  # noqa: E501
def prepare_inputs_labels_for_multimodal(
        llm: PreTrainedModel,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs):
    # Text-only batch: pass everything through unchanged.
    if pixel_values is None:
        kwargs.update({
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': None,
            'labels': labels
        })
        return kwargs

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- TODO: double check
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_embeds = []
    new_labels = []
    new_input_ids = []
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # No image token in this sample: keep the text embeddings and
            # concatenate an empty slice of the visual features so the
            # projector output stays connected to the graph.
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat(
                [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            new_input_ids.append(cur_input_ids)
            cur_image_idx += 1
            continue

        # Split the sequence at every image token.
        image_token_indices = [-1] + torch.where(
            cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
                cur_input_ids.shape[0]
            ]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(
                cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_inputs_embeds = llm.get_input_embeddings()(
            torch.cat(cur_input_ids_noim))
        cur_inputs_embeds_no_im = torch.split(
            cur_inputs_embeds, split_sizes, dim=0)

        # Interleave the text embeddings with the visual features, masking the
        # visual positions out of the loss with IGNORE_INDEX.
        cur_new_inputs_embeds = []
        cur_new_labels = []
        cur_new_input_ids = []
        for i in range(num_images + 1):
            cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            cur_new_input_ids.append(cur_input_ids_noim[i])
            if i < num_images:
                cur_pixel_values = pixel_values[cur_image_idx]
                cur_image_idx += 1
                cur_new_inputs_embeds.append(cur_pixel_values)
                cur_new_labels.append(
                    torch.full((cur_pixel_values.shape[0], ),
                               IGNORE_INDEX,
                               device=cur_labels.device,
                               dtype=cur_labels.dtype))
                cur_new_input_ids.append(
                    torch.full((cur_pixel_values.shape[0], ),
                               IMAGE_TOKEN_INDEX,
                               device=cur_input_ids.device,
                               dtype=cur_input_ids.dtype))

        cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
        cur_new_labels = torch.cat(cur_new_labels)
        cur_new_input_ids = torch.cat(cur_new_input_ids)

        new_inputs_embeds.append(cur_new_inputs_embeds)
        new_labels.append(cur_new_labels)
        new_input_ids.append(cur_new_input_ids)

    # Combine them: right-pad every sample to the longest sequence.
    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len),
                                   IGNORE_INDEX,
                                   dtype=new_labels[0].dtype,
                                   device=new_labels[0].device)
    new_input_ids_padded = torch.full((batch_size, max_len),
                                      IGNORE_INDEX,
                                      dtype=new_input_ids[0].dtype,
                                      device=new_input_ids[0].device)
    attention_mask = torch.zeros((batch_size, max_len),
                                 dtype=attention_mask.dtype,
                                 device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len),
                               dtype=position_ids.dtype,
                               device=position_ids.device)

    for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(
            zip(new_inputs_embeds, new_labels, new_input_ids)):
        cur_len = cur_new_embed.shape[0]
        new_inputs_embeds_padded.append(
            torch.cat((cur_new_embed,
                       torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
                                   dtype=cur_new_embed.dtype,
                                   device=cur_new_embed.device)),
                      dim=0))
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            new_input_ids_padded[i, :cur_len] = cur_new_input_ids
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded
    new_input_ids = new_input_ids_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    kwargs.update({
        'input_ids': None,
        'position_ids': position_ids,
        'attention_mask': attention_mask,
        'past_key_values': past_key_values,
        'inputs_embeds': new_inputs_embeds,
        'labels': new_labels,
        'new_input_ids': new_input_ids
    })
    return kwargs
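
# Illustrative usage sketch (assumption, not part of the original module): how
# the helper above is typically driven in a LLaVA-style pipeline, assuming
# `pixel_values` already holds the projected visual features for each image.
# The argument names and call pattern here are for demonstration only.
def _example_prepare_multimodal_inputs(llm, input_ids, labels, image_features):
    # `input_ids` contains IMAGE_TOKEN_INDEX (-200) wherever an <image>
    # placeholder appeared; `image_features` holds one feature sequence per
    # image, with the LLM's hidden size as its last dimension.
    mm_inputs = prepare_inputs_labels_for_multimodal(
        llm=llm,
        input_ids=input_ids,
        labels=labels,
        pixel_values=image_features)
    # The token ids are consumed into `inputs_embeds`, so the LLM is called
    # with embeddings rather than ids. `new_input_ids` is extra bookkeeping
    # and is not a standard forward() argument, so it is dropped here.
    forward_kwargs = {k: v for k, v in mm_inputs.items() if k != 'new_input_ids'}
    return llm(**forward_kwargs)
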
class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(self.sum, np.ndarray):
            total = torch.tensor(
                self.sum.tolist() + [self.count],
                dtype=torch.float32,
                device=device,
            )
        else:
            total = torch.tensor(
                [self.sum, self.count], dtype=torch.float32, device=device
            )

        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        if total.shape[0] > 2:
            self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
        else:
            self.sum, self.count = total.tolist()
        self.avg = self.sum / (self.count + 1e-5)

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ""
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)

        return fmtstr.format(**self.__dict__)
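
# Illustrative usage sketch (assumption, not original code): tracking a loss
# with AverageMeter during training and synchronising it across workers when
# torch.distributed has been initialised.
def _example_track_loss(losses_per_step, distributed=False):
    loss_meter = AverageMeter("loss", ":.4f", Summary.AVERAGE)
    for loss in losses_per_step:
        # `n` would normally be the number of samples contributing to `loss`.
        loss_meter.update(float(loss), n=1)
    if distributed:
        # Requires an initialised process group
        # (e.g. torch.distributed.init_process_group).
        loss_meter.all_reduce()
    print(loss_meter)            # e.g. "loss 0.5231 (0.6104)"
    print(loss_meter.summary())  # e.g. "loss 0.610"
    return loss_meter.avg
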
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W,
    # each value in range 0 to K - 1.
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
    area_output = torch.histc(output, bins=K, min=0, max=K - 1)
    area_target = torch.histc(target, bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
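
# Illustrative usage sketch (assumption, not original code): turning the
# per-class areas returned above into per-class IoU and mean IoU for a batch
# of binary masks (K=2: background/foreground). Masks are cast to float
# because torch.histc expects a floating-point input.
def _example_compute_miou(pred_masks, gt_masks, K=2, ignore_index=255):
    intersection_meter = AverageMeter("Intersec", ":.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":.3f", Summary.SUM)
    for pred, gt in zip(pred_masks, gt_masks):
        inter, union, _ = intersectionAndUnionGPU(
            pred.float().contiguous(), gt.float().contiguous(), K, ignore_index)
        intersection_meter.update(inter.cpu().numpy())
        union_meter.update(union.cpu().numpy())
    iou_per_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    return iou_per_class, iou_per_class.mean()
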
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def dict_to_cuda(input_dict):
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0:
            input_dict[k] = [
                ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele
                for ele in v
            ]
    return input_dict
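
# Illustrative usage sketch (assumption, not original code): moving a collated
# batch onto the GPU before the forward pass. Tensors, and tensors nested in
# lists, are transferred; other values (e.g. strings) are left untouched.
def _example_move_batch_to_gpu(batch):
    # e.g. batch = {'input_ids': LongTensor, 'pixel_values': [Tensor, ...],
    #               'image_paths': ['a.jpg', ...]}
    if torch.cuda.is_available():
        batch = dict_to_cuda(batch)
    return batch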