# Spaces: Running on Zero
import os | |
from typing import Optional | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from lpips import LPIPS | |
from PIL import Image | |
from torchvision.transforms import Normalize | |
def show_images_horizontally(
    list_of_files: np.ndarray, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Visualize the list of images side by side and show or save the figure.

    Args:
        list_of_files: The images to draw, as a numpy array with shape
            (N, H, W, C) or a sequence of N arrays with shape (H, W, C).
        output_file: The output file path to save the figure to. Required
            when ``interact`` is False.
        interact: If True, display the figure with ``plt.show()`` instead of
            saving it.

    Raises:
        ValueError: If there are no images, or if ``interact`` is False and
            no ``output_file`` is given.
    """
    number_of_files = len(list_of_files)
    if number_of_files == 0:
        raise ValueError("The image list is empty.")
    if not interact and output_file is None:
        raise ValueError("output_file is required when interact is False.")

    # Height is axis 0 and width is axis 1 of every (H, W, C) image.
    heights = [np.shape(image)[0] for image in list_of_files]
    widths = [np.shape(image)[1] for image in list_of_files]
    fig_width = 8.0  # inches, per image
    fig_height = fig_width * sum(heights) / sum(widths)

    # squeeze=False keeps `axs` two-dimensional, so a single image does not
    # collapse it into a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(
        1,
        number_of_files,
        figsize=(fig_width * number_of_files, fig_height),
        squeeze=False,
    )
    plt.tight_layout()
    for axis, image in zip(axs[0], list_of_files):
        axis.imshow(image)
        axis.axis("off")

    if interact:
        plt.show()
        return

    try:
        fig.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def image_grids(images, rows=None, cols=None):
    """
    Paste a sequence of PIL images into a single grid image.

    Images are placed row by row, left to right. The cell size is taken from
    the first image.

    Args:
        images: Non-empty sequence of PIL images.
        rows: Number of grid rows. Derived from ``cols`` when omitted.
        cols: Number of grid columns. When omitted it is derived from ``rows``
            if that is given, otherwise it is ``int(sqrt(len(images)))``.

    Returns:
        A new RGB PIL image of size (cols * width, rows * height).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("The image list is empty.")
    n_images = len(images)
    if cols is None:
        if rows is None:
            cols = int(n_images**0.5)
        else:
            # Honour an explicit row count: use just enough columns to hold
            # every image instead of silently pasting some off-canvas.
            cols = (n_images + rows - 1) // rows
    if rows is None:
        rows = (n_images + cols - 1) // cols  # ceil(n_images / cols)

    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        grid_image.paste(image, (col * width, row * height))
    return grid_image
def save_image(image: np.array, file_name: str) -> None:
    """
    Convert a numpy image to a PIL image and write it to disk.

    Args:
        image: The input image as numpy array with shape (H, W, C).
        file_name: The path the image is saved under.
    """
    Image.fromarray(image).save(file_name)
def load_and_process_images(load_dir: str) -> list:
    """
    Load the ``.jpg`` images of a directory as arrays scaled by 1/255.

    Files are read in ascending order of the integer that precedes the first
    ``.`` in their name (``0.jpg``, ``1.jpg``, ...). Entries that do not end
    with ``.jpg`` are ignored.

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: A list with one float array per image, each computed as
            ``np.asarray(img) / 255.0``.

    Raises:
        ValueError: If the name of a ``.jpg`` file does not start with an
            integer.
    """
    print(load_dir)
    # Filter before sorting so unrelated directory entries (e.g. hidden files)
    # never reach the numeric sort key.
    filenames = sorted(
        (name for name in os.listdir(load_dir) if name.endswith(".jpg")),
        key=lambda name: int(name.split(".")[0]),
    )
    images = []
    for filename in filenames:
        # The division creates a new array, so the file can be closed as soon
        # as the pixels have been read.
        with Image.open(os.path.join(load_dir, filename)) as img:
            images.append(np.asarray(img) / 255.0)
    return images
def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
    """
    Compute the LPIPS distance between each pair of adjacent images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        distances: Numpy array with the N-1 distances between images ``i``
            and ``i + 1``; empty when fewer than two images are given.
    """
    # Run on whatever device the model's parameters live on.
    device = next(lpips_model.parameters()).device

    # Collapse a list of arrays into one ndarray first: building a tensor
    # directly from a list of ndarrays is very slow.
    if not torch.is_tensor(images):
        images = np.asarray(images)
    batch = torch.as_tensor(images).to(device).float()
    # (N, H, W, C) -> (N, C, H, W)
    batch = torch.permute(batch, (0, 3, 1, 2))

    # NOTE(review): the reference LPIPS implementation appears to expect
    # inputs scaled to [-1, 1]; ImageNet mean/std is applied here instead.
    # Left unchanged so reported metric values stay comparable - confirm.
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch = normalize(batch)

    distances = []
    # Pure inference: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for i in range(batch.shape[0] - 1):
            first = batch[i].unsqueeze(0)
            second = batch[i + 1].unsqueeze(0)
            distances.append(lpips_model(first, second).item())
    return np.array(distances)
def compute_gini(distances: np.ndarray) -> float:
    """
    Compute the Gini index of the input distances.

    The index is the mean absolute difference over all ordered pairs divided
    by twice the mean. It is evaluated with the equivalent closed form on the
    sorted values, ``sum((2 * i - n - 1) * x_i) / (n * sum(x))`` with
    ``i = 1..n``, which avoids the quadratic pairwise loop.

    Args:
        distances: The input distances as numpy array (or any sequence of
            numbers).

    Returns:
        gini: The Gini index of the input distances. 0.0 is returned for
            fewer than two elements and when the values sum to zero.
    """
    values = np.sort(np.asarray(distances, dtype=float).ravel())
    n = values.size
    if n < 2:
        return 0.0  # Gini index is 0 for less than two elements

    total = values.sum()
    if total == 0:
        # Zero mean would divide by zero; treat it as perfectly equal.
        return 0.0

    ranks = np.arange(1, n + 1)
    return float(np.sum((2 * ranks - n - 1) * values) / (n * total))
def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Summarize the LPIPS distances between consecutive images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        A 3-tuple ``(smoothness, consistency, max_inception_distance)``:
        smoothness: One minus the Gini index of the consecutive LPIPS values.
        consistency: The mean of the consecutive LPIPS values.
        max_inception_distance: The largest consecutive LPIPS value.
    """
    pairwise = compute_lpips(images, lpips_model)
    return (
        1 - compute_gini(pairwise),
        np.mean(pairwise),
        np.max(pairwise),
    )
def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Split a sequence of images into its two endpoints and the frames between.

    The first and last images are the sources; everything in between is
    treated as interpolated.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).

    Returns:
        source: The first and last image stacked into a numpy array with
            shape (2, H, W, C).
        interpolation: The slice ``images[1:-1]`` holding the N-2 images in
            between.

    Raises:
        ValueError: If fewer than two images are given.
    """
    if not len(images) >= 2:
        raise ValueError("The input array should have at least two elements.")

    first, last = images[0], images[-1]
    endpoints = np.array([first, last])
    return endpoints, images[1:-1]