"""Core part of LaDeco v2.

Example usage:

>>> from core import Ladeco
>>> from PIL import Image
>>> from pathlib import Path
>>>
>>> # predict
>>> ldc = Ladeco()
>>> imgs = (thing for thing in Path("example").glob("*.jpg"))
>>> out = ldc.predict(imgs)
>>>
>>> # output - visualization
>>> segs = out.visualize(level=2)
>>> segs[0].image.show()
>>>
>>> # output - element area
>>> area = out.area()
>>> area[0]
{"fid": "example/<name>.jpg", ..., "l1_nature": 0.673, "l1_man_made": 0.241, "others": 0.086, "LC_NFI": 0.7363}
"""
|
|
from matplotlib.patches import Rectangle
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
from transformers import AutoModelForUniversalSegmentation, AutoProcessor
|
|
import math
|
|
import matplotlib as mpl
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from functools import lru_cache
|
|
from matplotlib.figure import Figure
|
|
import numpy.typing as npt
|
|
from typing import Iterable, NamedTuple, Generator
|
|
from tqdm import tqdm
|
|
|
|
|
|
class LadecoVisualization(NamedTuple):
    """A single rendered LaDeco segmentation.

    Attributes:
        filename: the source image path, converted to ``str``.
        image: matplotlib figure showing the color-coded label map.
    """

    filename: str
    image: Figure
|
|
|
|
|
|
class Ladeco:
    """Segment images with a OneFormer ADE20K model and map classes to LaDeco labels.

    Args:
        model_name: Hugging Face model id or local path, loaded with
            ``AutoProcessor`` and ``AutoModelForUniversalSegmentation``.
        area_threshold: forwarded to :class:`LadecoOutput`; LaDeco label area
            ratios below it are reported as 0.
        device: torch device used for inference. ``None`` selects ``"cuda"``
            when available, otherwise ``"cpu"``.
    """

    def __init__(
        self,
        model_name: str = "shi-labs/oneformer_ade20k_swin_large",
        area_threshold: float = 0.01,
        device: str | None = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForUniversalSegmentation.from_pretrained(model_name).to(self.device)

        self.area_threshold = area_threshold

        # ADE20K class name -> class index, read from the model config. Names
        # are stripped so exact-name lookups succeed even if the config pads
        # them with whitespace.
        self.ade20k_labels = {
            name.strip(): int(idx)
            for name, idx in self.model.config.label2id.items()
        }
        # LaDeco label (all four levels) -> ADE20K indices it covers.
        self.ladeco2ade20k: dict[str, tuple[int, ...]] = _get_ladeco_labels(self.ade20k_labels)

    def predict(
        self, image_paths: str | Path | Iterable[str | Path], show_progress: bool = False
    ) -> "LadecoOutput":
        """Run semantic segmentation on one or more images.

        Args:
            image_paths: a single path, or an iterable of paths (consumed once).
            show_progress: display a tqdm progress bar while segmenting.

        Returns:
            A :class:`LadecoOutput` holding the paths and one CPU mask per image.
        """
        if isinstance(image_paths, (str, Path)):
            imgpaths = [image_paths]
        else:
            # Materialize so a generator can be iterated here and stored for later use.
            imgpaths = list(image_paths)

        masks: list[torch.Tensor] = []
        for img_path in tqdm(imgpaths, desc="Segmenting", disable=not show_progress):
            # Close the underlying file as soon as the pixels are decoded.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")

            samples = self.processor(
                images=img, task_inputs=["semantic"], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**samples)

            # Post-processing returns one mask per image in the batch (a single
            # image here). Keep it on the CPU so accumulated masks do not hold
            # GPU memory, and so LadecoOutput can use it to index CPU tensors.
            mask = self.processor.post_process_semantic_segmentation(outputs)[0]
            masks.append(mask.cpu())

        return LadecoOutput(imgpaths, masks, self.ladeco2ade20k, self.area_threshold)
|
|
|
|
|
|
class LadecoOutput:
    """Per-image LaDeco segmentation results, with visualization and area helpers.

    Args:
        filenames: image paths, in the same order as ``masks``.
        masks: one ADE20K class-index mask per image, as produced by
            ``Ladeco.predict``.
        ladeco2ade: LaDeco label name -> ADE20K class indices it covers.
        threshold: area ratio below which a LaDeco label is reported as 0.
    """

    def __init__(
        self,
        filenames: list[str | Path],
        masks: list[torch.Tensor],
        ladeco2ade: dict[str, tuple[int, ...]],
        threshold: float,
    ):
        self.filenames = filenames
        self.masks = masks
        self.ladeco2ade: dict[str, tuple[int, ...]] = ladeco2ade
        # Reverse lookup: ADE20K index -> LaDeco label. An index appears under
        # several levels, so the label inserted last into ``ladeco2ade`` wins.
        self.ade2ladeco: dict[int, str] = {
            idx: label
            for label, indices in self.ladeco2ade.items()
            for idx in indices
        }
        self.threshold = threshold

    def visualize(self, level: int) -> list[LadecoVisualization]:
        """Return all :meth:`ivisualize` results as a list."""
        return list(self.ivisualize(level))

    def ivisualize(self, level: int) -> Generator[LadecoVisualization, None, None]:
        """Lazily yield one color-coded label-map figure per image.

        Each pixel gets the :meth:`color_map` color of the level-``level``
        label covering its ADE20K class. Pixels whose class is not covered by
        any such label stay black. The map is resized to the source image size.

        NOTE: figures are created through pyplot (so ``Figure.show()`` works);
        pyplot keeps them alive until ``plt.close(fig)`` is called.
        """
        # uint8 color per label, computed once rather than once per ADE20K index.
        label_colors = {
            name: torch.tensor(rgb * 255, dtype=torch.uint8)
            for name, rgb in self.color_map(level).items()
        }

        for fname, mask in zip(self.filenames, self.masks):
            # The canvas lives on the CPU, so the boolean index must as well:
            # a CUDA mask cannot index a CPU tensor.
            mask = mask.cpu()
            vis = torch.zeros(mask.shape + (3,), dtype=torch.uint8)
            for name, color in label_colors.items():
                for idx in self.ladeco2ade[name]:
                    vis[mask == idx] = color

            with Image.open(fname) as img:
                target_size = img.size
            # NEAREST keeps every pixel one of the legend colors; Pillow's
            # default RGB filter (bicubic) would blend colors at label borders.
            vis_img = Image.fromarray(vis.numpy()).resize(target_size, resample=Image.NEAREST)

            fig, ax = plt.subplots()
            ax.imshow(vis_img)
            ax.axis("off")

            yield LadecoVisualization(filename=str(fname), image=fig)

    def area(self) -> list[dict[str, float | str]]:
        """Return all :meth:`iarea` results as a list."""
        return list(self.iarea())

    def iarea(self) -> Generator[dict[str, float | str], None, None]:
        """Lazily yield per-image area ratios.

        Each dict contains:
            ``"fid"``: the file name as ``str``;
            one key per LaDeco label: its pixel-area ratio rounded to 4
            decimals, or 0 when below ``self.threshold``;
            ``"others"``: ``1 - l1_nature - l1_man_made`` (after thresholding);
            ``"LC_NFI"``: ``l1_nature / (l1_nature + l1_man_made)``, with a
            1e-6 guard against division by zero.
        """
        n_label_ADE20k = 150

        for filename, mask in zip(self.filenames, self.masks):
            mask = mask.cpu()
            # Pixel count per ADE20K class in a single pass. Indices >= 150
            # (if any) are dropped, matching a scan over range(150).
            counts = torch.bincount(mask.flatten(), minlength=n_label_ADE20k)[:n_label_ADE20k]
            ade_ratios = counts / mask.numel()

            raw_ratios = {
                label: round(ade_ratios[list(ade_indices)].sum().item(), 4)
                for label, ade_indices in self.ladeco2ade.items()
            }
            ldc_ratios: dict[str, float] = {
                label: 0 if ratio < self.threshold else ratio
                for label, ratio in raw_ratios.items()
            }

            nature = ldc_ratios["l1_nature"]
            man_made = ldc_ratios["l1_man_made"]
            others = round(1 - nature - man_made, 4)
            nfi = round(nature / (nature + man_made + 1e-6), 4)

            yield {
                "fid": str(filename), **ldc_ratios, "others": others, "LC_NFI": nfi,
            }

    def color_map(self, level: int) -> dict[str, npt.NDArray[np.float64]]:
        """Assign each level-``level`` LaDeco label a color sampled from viridis.

        Returns:
            ``{label_name: array([R, G, B])}`` with components in [0, 1], in the
            order the labels appear in ``ladeco2ade``.

        Raises:
            RuntimeError: if no LaDeco label belongs to ``level``.
        """
        labels = [name for name in self.ladeco2ade if name.startswith(f"l{level}")]
        if not labels:
            raise RuntimeError(
                f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
            )

        # resampled(n) yields an n-entry colormap; drop the alpha column.
        colormap = mpl.colormaps["viridis"].resampled(len(labels)).colors[:, :-1]
        return dict(zip(labels, colormap))

    def color_legend(self, level: int) -> Figure:
        """Draw a swatch + name legend for the level-``level`` colors.

        Raises:
            RuntimeError: if ``level`` is not one of 1, 2, 3, 4.
        """
        colors = self.color_map(level)

        # Legend columns per level (more labels -> more columns).
        ncols_by_level = {1: 1, 2: 1, 3: 2, 4: 5}
        if level not in ncols_by_level:
            raise RuntimeError(
                f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
            )
        ncols = ncols_by_level[level]

        # Layout in pixels at 72 dpi.
        cell_width = 212
        cell_height = 22
        swatch_width = 48
        margin = 12

        nrows = math.ceil(len(colors) / ncols)

        width = cell_width * ncols + 2 * margin
        height = cell_height * nrows + 2 * margin
        dpi = 72

        fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
        fig.subplots_adjust(margin/width, margin/height,
                            (width-margin)/width, (height-margin*2)/height)
        ax.set_xlim(0, cell_width * ncols)
        # Inverted y-range so the first label is drawn at the top.
        ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
        ax.yaxis.set_visible(False)
        ax.xaxis.set_visible(False)
        ax.set_axis_off()

        # Fill column by column.
        for i, name in enumerate(colors):
            row = i % nrows
            col = i // nrows
            y = row * cell_height

            swatch_start_x = cell_width * col
            text_pos_x = cell_width * col + swatch_width + 7

            ax.text(text_pos_x, y, name, fontsize=14,
                    horizontalalignment='left',
                    verticalalignment='center')

            ax.add_patch(
                Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                          height=18, facecolor=colors[name], edgecolor='0.7')
            )

        ax.set_title(f"LaDeco Color Legend - Level {level}")

        return fig
|
|
|
|
|
|
def _get_ladeco_labels(ade20k: dict[str, int]) -> dict[str, tuple[int]]:
|
|
labels = {
|
|
|
|
|
|
"l4_hovel": (ade20k["hovel, hut, hutch, shack, shanty"],),
|
|
"l4_building": (ade20k["building"], ade20k["house"]),
|
|
"l4_skyscraper": (ade20k["skyscraper"],),
|
|
"l4_tower": (ade20k["tower"],),
|
|
|
|
"l4_step": (ade20k["step, stair"],),
|
|
"l4_canopy": (ade20k["awning, sunshade, sunblind"], ade20k["canopy"]),
|
|
"l4_arcade": (ade20k["arcade machine"],),
|
|
"l4_door": (ade20k["door"],),
|
|
"l4_window": (ade20k["window"],),
|
|
"l4_wall": (ade20k["wall"],),
|
|
|
|
"l4_stairway": (ade20k["stairway, staircase"],),
|
|
"l4_sidewalk": (ade20k["sidewalk, pavement"],),
|
|
"l4_road": (ade20k["road, route"],),
|
|
|
|
"l4_sculpture": (ade20k["sculpture"],),
|
|
"l4_flag": (ade20k["flag"],),
|
|
"l4_can": (ade20k["trash can"],),
|
|
"l4_chair": (ade20k["chair"],),
|
|
"l4_pot": (ade20k["pot"],),
|
|
"l4_booth": (ade20k["booth"],),
|
|
"l4_streetlight": (ade20k["street lamp"],),
|
|
"l4_bench": (ade20k["bench"],),
|
|
"l4_fence": (ade20k["fence"],),
|
|
"l4_table": (ade20k["table"],),
|
|
|
|
"l4_bike": (ade20k["bicycle"],),
|
|
"l4_motorbike": (ade20k["minibike, motorbike"],),
|
|
"l4_van": (ade20k["van"],),
|
|
"l4_truck": (ade20k["truck"],),
|
|
"l4_bus": (ade20k["bus"],),
|
|
"l4_car": (ade20k["car"],),
|
|
|
|
"l4_traffic_sign": (ade20k["traffic light"],),
|
|
"l4_poster": (ade20k["poster, posting, placard, notice, bill, card"],),
|
|
"l4_signboard": (ade20k["signboard, sign"],),
|
|
|
|
"l4_rock": (ade20k["rock, stone"],),
|
|
"l4_hill": (ade20k["hill"],),
|
|
"l4_mountain": (ade20k["mountain, mount"],),
|
|
|
|
"l4_ground": (ade20k["earth, ground"], ade20k["land, ground, soil"]),
|
|
"l4_field": (ade20k["field"],),
|
|
"l4_sand": (ade20k["sand"],),
|
|
"l4_dirt": (ade20k["dirt track"],),
|
|
"l4_path": (ade20k["path"],),
|
|
|
|
"l4_flower": (ade20k["flower"],),
|
|
|
|
"l4_grass": (ade20k["grass"],),
|
|
|
|
"l4_flora": (ade20k["plant"],),
|
|
|
|
"l4_tree": (ade20k["tree"],),
|
|
"l4_palm": (ade20k["palm, palm tree"],),
|
|
|
|
"l4_lake": (ade20k["lake"],),
|
|
"l4_pool": (ade20k["pool"],),
|
|
"l4_river": (ade20k["river"],),
|
|
"l4_sea": (ade20k["sea"],),
|
|
"l4_water": (ade20k["water"],),
|
|
|
|
"l4_fountain": (ade20k["fountain"],),
|
|
"l4_waterfall": (ade20k["falls"],),
|
|
|
|
"l4_person": (ade20k["person"],),
|
|
|
|
"l4_animal": (ade20k["animal"],),
|
|
|
|
"l4_sky": (ade20k["sky"],),
|
|
}
|
|
labels = labels | {
|
|
|
|
|
|
"l3_hori_land": labels["l4_ground"] + labels["l4_field"] + labels["l4_sand"] + labels["l4_dirt"] + labels["l4_path"],
|
|
"l3_vert_land": labels["l4_mountain"] + labels["l4_hill"] + labels["l4_rock"],
|
|
|
|
"l3_woody_plant": labels["l4_tree"] + labels["l4_palm"] + labels["l4_flora"],
|
|
"l3_herb_plant": labels["l4_grass"],
|
|
"l3_flower": labels["l4_flower"],
|
|
|
|
"l3_hori_water": labels["l4_water"] + labels["l4_sea"] + labels["l4_river"] + labels["l4_pool"] + labels["l4_lake"],
|
|
"l3_vert_water": labels["l4_fountain"] + labels["l4_waterfall"],
|
|
|
|
"l3_human": labels["l4_person"],
|
|
"l3_animal": labels["l4_animal"],
|
|
|
|
"l3_sky": labels["l4_sky"],
|
|
|
|
"l3_architecture": labels["l4_building"] + labels["l4_hovel"] + labels["l4_tower"] + labels["l4_skyscraper"],
|
|
"l3_archi_parts": labels["l4_wall"] + labels["l4_window"] + labels["l4_door"] + labels["l4_arcade"] + labels["l4_canopy"] + labels["l4_step"],
|
|
|
|
"l3_roadway": labels["l4_road"] + labels["l4_sidewalk"] + labels["l4_stairway"],
|
|
"l3_furniture": labels["l4_table"] + labels["l4_chair"] + labels["l4_fence"] + labels["l4_bench"] + labels["l4_streetlight"] + labels["l4_booth"] + labels["l4_pot"] + labels["l4_can"] + labels["l4_flag"] + labels["l4_sculpture"],
|
|
"l3_vehicle": labels["l4_car"] + labels["l4_bus"] + labels["l4_truck"] + labels["l4_van"] + labels["l4_motorbike"] + labels["l4_bike"],
|
|
"l3_sign": labels["l4_signboard"] + labels["l4_poster"] + labels["l4_traffic_sign"],
|
|
}
|
|
labels = labels | {
|
|
|
|
|
|
"l2_landform": labels["l3_hori_land"] + labels["l3_vert_land"],
|
|
"l2_vegetation": labels["l3_woody_plant"] + labels["l3_herb_plant"] + labels["l3_flower"],
|
|
"l2_water": labels["l3_hori_water"] + labels["l3_vert_water"],
|
|
"l2_bio": labels["l3_human"] + labels["l3_animal"],
|
|
"l2_sky": labels["l3_sky"],
|
|
|
|
"l2_archi": labels["l3_architecture"] + labels["l3_archi_parts"],
|
|
"l2_street": labels["l3_roadway"] + labels["l3_furniture"] + labels["l3_vehicle"] + labels["l3_sign"],
|
|
}
|
|
labels = labels | {
|
|
|
|
"l1_nature": labels["l2_landform"] + labels["l2_vegetation"] + labels["l2_water"] + labels["l2_bio"] + labels["l2_sky"],
|
|
"l1_man_made": labels["l2_archi"] + labels["l2_street"],
|
|
}
|
|
return labels
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke run: segment a single sample image from ./images.
    segmenter = Ladeco()
    sample_path = Path("images") / "canyon_3011_00002354.jpg"
    out = segmenter.predict(sample_path)
|
|
|