|
import pprint |
|
import numpy as np |
|
import torch |
|
from third_party.cgdetr.utils.basic_utils import load_jsonl |
|
from third_party.cgdetr.standalone_eval.eval import eval_submission |
|
from tqdm import tqdm |
|
|
|
|
|
class PostProcessorDETR: |
|
def __init__(self, clip_length=2, min_ts_val=0, max_ts_val=150, |
|
min_w_l=2, max_w_l=70, move_window_method="center", |
|
process_func_names=("clip_window_l", "clip_ts", "round_multiple")): |
|
self.clip_length = clip_length |
|
self.min_ts_val = min_ts_val |
|
self.max_ts_val = max_ts_val |
|
self.min_w_l = min_w_l |
|
self.max_w_l = max_w_l |
|
self.move_window_method = move_window_method |
|
self.process_func_names = process_func_names |
|
self.name2func = dict( |
|
clip_ts=self.clip_min_max_timestamps, |
|
round_multiple=self.round_to_multiple_clip_lengths, |
|
clip_window_l=self.clip_window_lengths |
|
) |
|
|
|
def __call__(self, lines): |
|
processed_lines = [] |
|
for line in tqdm(lines, desc=f"convert to multiples of clip_length={self.clip_length}"): |
|
windows_and_scores = torch.tensor(line["pred_relevant_windows"]) |
|
windows = windows_and_scores[:, :2] |
|
for func_name in self.process_func_names: |
|
windows = self.name2func[func_name](windows) |
|
line["pred_relevant_windows"] = torch.cat( |
|
[windows, windows_and_scores[:, 2:3]], dim=1).tolist() |
|
line["pred_relevant_windows"] = [e[:2] + [float(f"{e[2]:.4f}")] for e in line["pred_relevant_windows"]] |
|
processed_lines.append(line) |
|
return processed_lines |
|
|
|
def clip_min_max_timestamps(self, windows): |
|
""" |
|
windows: (#windows, 2) torch.Tensor |
|
ensure timestamps for all windows is within [min_val, max_val], clip is out of boundaries. |
|
""" |
|
return torch.clamp(windows, min=self.min_ts_val, max=self.max_ts_val) |
|
|
|
def round_to_multiple_clip_lengths(self, windows): |
|
""" |
|
windows: (#windows, 2) torch.Tensor |
|
ensure the final window timestamps are multiples of `clip_length` |
|
""" |
|
return torch.round(windows / self.clip_length) * self.clip_length |
|
|
|
def clip_window_lengths(self, windows): |
|
""" |
|
windows: (#windows, 2) np.ndarray |
|
ensure the final window duration are within [self.min_w_l, self.max_w_l] |
|
""" |
|
window_lengths = windows[:, 1] - windows[:, 0] |
|
small_rows = window_lengths < self.min_w_l |
|
if torch.sum(small_rows) > 0: |
|
windows = self.move_windows( |
|
windows, small_rows, self.min_w_l, move_method=self.move_window_method) |
|
large_rows = window_lengths > self.max_w_l |
|
if torch.sum(large_rows) > 0: |
|
windows = self.move_windows( |
|
windows, large_rows, self.max_w_l, move_method=self.move_window_method) |
|
return windows |
|
|
|
@classmethod |
|
def move_windows(cls, windows, row_selector, new_length, move_method="left"): |
|
""" |
|
Args: |
|
windows: |
|
row_selector: |
|
new_length: |
|
move_method: str, |
|
left: keep left unchanged |
|
center: keep center unchanged |
|
right: keep right unchanged |
|
|
|
Returns: |
|
|
|
""" |
|
|
|
|
|
if move_method == "left": |
|
windows[row_selector, 1] = windows[row_selector, 0] + new_length |
|
elif move_method == "right": |
|
windows[row_selector, 0] = windows[row_selector, 1] - new_length |
|
elif move_method == "center": |
|
center = (windows[row_selector, 1] + windows[row_selector, 0]) / 2. |
|
windows[row_selector, 0] = center - new_length / 2. |
|
windows[row_selector, 1] = center + new_length / 2. |
|
return windows |
|
|
|
|