fffiloni commited on
Commit
691af46
·
verified ·
1 Parent(s): 1772326

Migrated from GitHub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 000000000285.jpg +0 -0
  2. 000000000724.jpg +0 -0
  3. 000000007991.jpg +0 -0
  4. 000000018837.jpg +0 -0
  5. 000000122962.jpg +0 -0
  6. 000000295478.jpg +0 -0
  7. ORIGINAL_README.md +128 -0
  8. eval_controlnet.py +148 -0
  9. eval_controlnet.sh +19 -0
  10. eval_controlnet_sdxl_light.py +284 -0
  11. eval_controlnet_sdxl_light.sh +44 -0
  12. eval_controlnet_sdxl_light_single.py +390 -0
  13. eval_controlnet_sdxl_light_single.sh +20 -0
  14. example/UUColor_results/Hollywood-Sign.jpeg +0 -0
  15. example/legacy_images/Big-Ben-vintage.jpg +0 -0
  16. example/legacy_images/Central-Park.jpg +0 -0
  17. example/legacy_images/Hollywood-Sign.jpg +0 -0
  18. example/legacy_images/Little-Mermaid.jpg +0 -0
  19. example/legacy_images/Migrant-Mother.jpg +0 -0
  20. example/legacy_images/Mount-Everest.jpg +0 -0
  21. example/legacy_images/Tower-of-Pisa.jpg +0 -0
  22. example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg +0 -0
  23. gradio_ui.py +356 -0
  24. images/000000022935_gray.jpg +0 -0
  25. images/000000022935_green_shirt_on_right_girl.jpeg +0 -0
  26. images/000000022935_purple_shirt_on_right_girl.jpeg +0 -0
  27. images/000000022935_red_shirt_on_right_girl.jpeg +0 -0
  28. images/000000025560_color.jpg +0 -0
  29. images/000000025560_gray.jpg +0 -0
  30. images/000000025560_gt.jpg +0 -0
  31. images/000000041633_black_car.jpeg +0 -0
  32. images/000000041633_bright_red_car.jpeg +0 -0
  33. images/000000041633_dark_blue_car.jpeg +0 -0
  34. images/000000041633_gray.jpg +0 -0
  35. images/000000065736_color.jpg +0 -0
  36. images/000000065736_gray.jpg +0 -0
  37. images/000000065736_gt.jpg +0 -0
  38. images/000000091779_color.jpg +0 -0
  39. images/000000091779_gray.jpg +0 -0
  40. images/000000091779_gt.jpg +0 -0
  41. images/000000092177_color.jpg +0 -0
  42. images/000000092177_gray.jpg +0 -0
  43. images/000000092177_gt.jpg +0 -0
  44. images/000000166426_color.jpg +0 -0
  45. images/000000166426_gray.jpg +0 -0
  46. images/000000166426_gt.jpg +0 -0
  47. images/000000286708_gray.jpg +0 -0
  48. images/000000286708_orange_hat.jpeg +0 -0
  49. images/000000286708_pink_hat.jpeg +0 -0
  50. images/000000286708_yellow_hat.jpeg +0 -0
000000000285.jpg ADDED
000000000724.jpg ADDED
000000007991.jpg ADDED
000000018837.jpg ADDED
000000122962.jpg ADDED
000000295478.jpg ADDED
ORIGINAL_README.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text-Guided-Image-Colorization
2
+
3
+ This project utilizes the power of **Stable Diffusion (SDXL/SDXL-Light)** and the **BLIP (Bootstrapping Language-Image Pre-training)** captioning model to provide an interactive image colorization experience. Users can influence the generated colors of objects within images, making the colorization process more personalized and creative.
4
+
5
+ ## Table of Contents
6
+ - [Features](#features)
7
+ - [Installation](#installation)
8
+ - [Quick Start](#quick-start)
9
+ - [Dataset Usage](#dataset-usage)
10
+ - [Training](#training)
11
+ - [Evaluation](#evaluation)
12
+ - [Results](#results)
13
+ - [License](#license)
14
+
15
+ ## Features
16
+
17
+ - **Interactive Colorization**: Users can specify desired colors for different objects in the image.
18
+ - **ControlNet Approach**: Enhanced colorization capabilities through retraining with ControlNet, allowing SDXL to better adapt to the image colorization task.
19
+ - **High-Quality Outputs**: Leverage the latest advancements in diffusion models to generate vibrant and realistic colorizations.
20
+ - **User-Friendly Interface**: Easy-to-use interface for seamless interaction with the model.
21
+
22
+ ## Installation
23
+
24
+ To set up the project locally, follow these steps:
25
+
26
+ 1. **Clone the Repository**:
27
+
28
+ ```bash
29
+ git clone https://github.com/nick8592/text-guided-image-colorization.git
30
+ cd text-guided-image-colorization
31
+ ```
32
+
33
+ 2. **Install Dependencies**:
34
+ Make sure you have Python 3.7 or higher installed. Then, install the required packages:
35
+
36
+ ```bash
37
+ pip install -r requirements.txt
38
+ ```
39
+ Install `torch` and `torchvision` matching your CUDA version:
40
+ ```bash
41
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cuXXX
42
+ ```
43
+ Replace `XXX` with your CUDA version (e.g., `118` for CUDA 11.8). For more info, see [PyTorch Get Started](https://pytorch.org/get-started/locally/).
44
+
45
+
46
+ 3. **Download Pre-trained Models**:
47
+ | Models | Hugging Face (Recommand) | Other |
48
+ |:---:|:---:|:---:|
49
+ |SDXL-Lightning Caption|[link](https://huggingface.co/nickpai/sdxl_light_caption_output)|[link](https://gofile.me/7uE8s/FlEhfpWPw) (2kNJfV)|
50
+ |SDXL-Lightning Custom Caption (Recommand)|[link](https://huggingface.co/nickpai/sdxl_light_custom_caption_output)|[link](https://gofile.me/7uE8s/AKmRq5sLR) (KW7Fpi)|
51
+
52
+
53
+ ```bash
54
+ text-guided-image-colorization/sdxl_light_caption_output
55
+ └── checkpoint-30000
56
+ ├── controlnet
57
+ │ ├── diffusion_pytorch_model.safetensors
58
+ │ └── config.json
59
+ ├── optimizer.bin
60
+ ├── random_states_0.pkl
61
+ ├── scaler.pt
62
+ └── scheduler.bin
63
+ ```
64
+
65
+ ## Quick Start
66
+
67
+ 1. Run the `gradio_ui.py` script:
68
+
69
+ ```bash
70
+ python gradio_ui.py
71
+ ```
72
+
73
+ 2. Open the provided URL in your web browser to access the Gradio-based user interface.
74
+
75
+ 3. Upload an image and use the interface to control the colors of specific objects in the image. But still the model can generate images without a specific prompt.
76
+
77
+ 4. The model will generate a colorized version of the image based on your input (or automatic). See the [demo video](https://x.com/weichenpai/status/1829513077588631987).
78
+ ![Gradio UI](images/gradio_ui.png)
79
+
80
+
81
+ ## Dataset Usage
82
+
83
+ You can find more details about the dataset usage in the [Dataset-for-Image-Colorization](https://github.com/nick8592/Dataset-for-Image-Colorization).
84
+
85
+ ## Training
86
+
87
+ For training, you can use one of the following scripts:
88
+
89
+ - `train_controlnet.sh`: Trains a model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1)
90
+ - `train_controlnet_sdxl.sh`: Trains a model using [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
91
+ - `train_controlnet_sdxl_light.sh`: Trains a model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
92
+
93
+ Although the training code for SDXL is provided, due to a lack of GPU resources, I wasn't able to train the model by myself. Therefore, there might be some errors when you try to train the model.
94
+
95
+ ## Evaluation
96
+
97
+ For evaluation, you can use one of the following scripts:
98
+
99
+ - `eval_controlnet.sh`: Evaluates the model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) for a folder of images.
100
+ - `eval_controlnet_sdxl_light.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a folder of images.
101
+ - `eval_controlnet_sdxl_light_single.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a single image.
102
+
103
+ ## Results
104
+ ### Prompt-Guided
105
+ | Caption | Condition 1 | Condition 2 | Condition 3 |
106
+ |:---:|:---:|:---:|:---:|
107
+ | ![000000022935_gray.jpg](images/000000022935_gray.jpg) | ![000000022935_green_shirt_on_right_girl.jpeg](images/000000022935_green_shirt_on_right_girl.jpeg) | ![000000022935_purple_shirt_on_right_girl.jpeg](images/000000022935_purple_shirt_on_right_girl.jpeg) |![000000022935_red_shirt_on_right_girl.jpeg](images/000000022935_red_shirt_on_right_girl.jpeg) |
108
+ | a photography of a woman in a soccer uniform kicking a soccer ball | + "green shirt"| + "purple shirt" | + "red shirt" |
109
+ | ![000000041633_gray.jpg](images/000000041633_gray.jpg) | ![000000041633_bright_red_car.jpeg](images/000000041633_bright_red_car.jpeg) | ![000000041633_dark_blue_car.jpeg](images/000000041633_dark_blue_car.jpeg) |![000000041633_black_car.jpeg](images/000000041633_black_car.jpeg) |
110
+ | a photography of a photo of a truck | + "bright red car"| + "dark blue car" | + "black car" |
111
+ | ![000000286708_gray.jpg](images/000000286708_gray.jpg) | ![000000286708_orange_hat.jpeg](images/000000286708_orange_hat.jpeg) | ![000000286708_pink_hat.jpeg](images/000000286708_pink_hat.jpeg) |![000000286708_yellow_hat.jpeg](images/000000286708_yellow_hat.jpeg) |
112
+ | a photography of a cat wearing a hat on his head | + "orange hat"| + "pink hat" | + "yellow hat" |
113
+
114
+ ### Prompt-Free
115
+ Ground truth images are provided solely for reference purpose in the image colorization task.
116
+ | Grayscale Image | Colorized Result | Ground Truth |
117
+ |:---:|:---:|:---:|
118
+ | ![000000025560_gray.jpg](images/000000025560_gray.jpg) | ![000000025560_color.jpg](images/000000025560_color.jpg) | ![000000025560_gt.jpg](images/000000025560_gt.jpg) |
119
+ | ![000000065736_gray.jpg](images/000000065736_gray.jpg) | ![000000065736_color.jpg](images/000000065736_color.jpg) | ![000000065736_gt.jpg](images/000000065736_gt.jpg) |
120
+ | ![000000091779_gray.jpg](images/000000091779_gray.jpg) | ![000000091779_color.jpg](images/000000091779_color.jpg) | ![000000091779_gt.jpg](images/000000091779_gt.jpg) |
121
+ | ![000000092177_gray.jpg](images/000000092177_gray.jpg) | ![000000092177_color.jpg](images/000000092177_color.jpg) | ![000000092177_gt.jpg](images/000000092177_gt.jpg) |
122
+ | ![000000166426_gray.jpg](images/000000166426_gray.jpg) | ![000000166426_color.jpg](images/000000166426_color.jpg) | ![000000025560_gt.jpg](images/000000166426_gt.jpg) |
123
+
124
+
125
+
126
+ ## License
127
+
128
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
eval_controlnet.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import shutil
5
+ import argparse
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from datasets import load_dataset
11
+ from diffusers.utils import load_image
12
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
13
+
14
+ # Define the function to parse arguments
15
+ def parse_args(input_args=None):
16
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
17
+
18
+ parser.add_argument("--model_dir", type=str, default="sd_v2_caption_free_output/checkpoint-22500",
19
+ help="Directory of the model checkpoint")
20
+ parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-base",
21
+ help="ID of the model (Tested with runwayml/stable-diffusion-v1-5 and stabilityai/stable-diffusion-2-base)")
22
+ parser.add_argument("--dataset", type=str, default="nickpai/coco2017-colorization",
23
+ help="Dataset used")
24
+ parser.add_argument("--revision", type=str, default="caption-free",
25
+ choices=["main", "caption-free"],
26
+ help="Revision option (main/caption-free)")
27
+
28
+ if input_args is not None:
29
+ args = parser.parse_args(input_args)
30
+ else:
31
+ args = parser.parse_args()
32
+
33
+ return args
34
+
35
+ def apply_color(image, color_map):
36
+ # Convert input images to LAB color space
37
+ image_lab = image.convert('LAB')
38
+ color_map_lab = color_map.convert('LAB')
39
+
40
+ # Split LAB channels
41
+ l, a, b = image_lab.split()
42
+ _, a_map, b_map = color_map_lab.split()
43
+
44
+ # Merge LAB channels with color map
45
+ merged_lab = Image.merge('LAB', (l, a_map, b_map))
46
+
47
+ # Convert merged LAB image back to RGB color space
48
+ result_rgb = merged_lab.convert('RGB')
49
+
50
+ return result_rgb
51
+
52
+ def main(args):
53
+ generator = torch.manual_seed(0)
54
+
55
+ # MODEL_DIR = "sd_v2_caption_free_output/checkpoint-22500"
56
+ # # MODEL_ID="runwayml/stable-diffusion-v1-5"
57
+ # MODEL_ID="stabilityai/stable-diffusion-2-base"
58
+ # DATASET = "nickpai/coco2017-colorization"
59
+ # REVISION = "caption-free" # option: main/caption-free
60
+
61
+ # Path to the eval_results folder
62
+ eval_results_folder = os.path.join(args.model_dir, "results")
63
+
64
+ # Remove eval_results folder if it exists
65
+ if os.path.exists(eval_results_folder):
66
+ shutil.rmtree(eval_results_folder)
67
+
68
+ # Create directory for eval_results
69
+ os.makedirs(eval_results_folder)
70
+
71
+ # Create subfolders for compare and colorized images
72
+ compare_folder = os.path.join(eval_results_folder, "compare")
73
+ colorized_folder = os.path.join(eval_results_folder, "colorized")
74
+ os.makedirs(compare_folder)
75
+ os.makedirs(colorized_folder)
76
+
77
+ # Load the validation split of the colorization dataset
78
+ val_dataset = load_dataset(args.dataset, split="validation", revision=args.revision)
79
+
80
+ controlnet = ControlNetModel.from_pretrained(f"{args.model_dir}/controlnet", torch_dtype=torch.float16)
81
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
82
+ args.model_id, controlnet=controlnet, torch_dtype=torch.float16
83
+ ).to("cuda")
84
+
85
+ pipe.safety_checker = None
86
+
87
+ # Counter for processed images
88
+ processed_images = 0
89
+
90
+ # Record start time
91
+ start_time = time.time()
92
+
93
+ # Iterate through the validation dataset
94
+ for example in tqdm(val_dataset, desc="Processing Images"):
95
+ image_path = example["file_name"]
96
+
97
+ prompt = []
98
+ for caption in example["captions"]:
99
+ if isinstance(caption, str):
100
+ prompt.append(caption)
101
+ elif isinstance(caption, (list, np.ndarray)):
102
+ # take a random caption if there are multiple
103
+ prompt.append(caption[0])
104
+ else:
105
+ raise ValueError(
106
+ f"Caption column `captions` should contain either strings or lists of strings."
107
+ )
108
+
109
+ # Generate image
110
+ ground_truth_image = load_image(image_path).resize((512, 512))
111
+ control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
112
+ image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
113
+
114
+ # Apply color mapping
115
+ image = apply_color(ground_truth_image, image)
116
+
117
+ # Concatenate images into a row
118
+ row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
119
+ row_image = Image.fromarray(row_image)
120
+
121
+ # Save row image in the compare folder
122
+ compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}")
123
+ row_image.save(compare_output_path)
124
+
125
+ # Save colorized image in the colorized folder
126
+ colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}")
127
+ image.save(colorized_output_path)
128
+
129
+ # Increment processed images counter
130
+ processed_images += 1
131
+
132
+ # Record end time
133
+ end_time = time.time()
134
+
135
+ # Calculate total time taken
136
+ total_time = end_time - start_time
137
+
138
+ # Calculate FPS
139
+ fps = processed_images / total_time
140
+
141
+ print("All images processed.")
142
+ print(f"Total time taken: {total_time:.2f} seconds")
143
+ print(f"FPS: {fps:.2f}")
144
+
145
+ # Entry point of the script
146
+ if __name__ == "__main__":
147
+ args = parse_args()
148
+ main(args)
eval_controlnet.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Define default values for parameters
2
+
3
+ # # sdv2 with BCE loss
4
+ # MODEL_DIR="sd_v2_caption_bce_output/checkpoint-22500"
5
+ # MODEL_ID="stabilityai/stable-diffusion-2-base"
6
+ # DATASET="nickpai/coco2017-colorization"
7
+ # REVISION="main"
8
+
9
+ # sdv2 with kl loss
10
+ MODEL_DIR="sd_v2_caption_kl_output/checkpoint-22500"
11
+ MODEL_ID="stabilityai/stable-diffusion-2-base"
12
+ DATASET="nickpai/coco2017-colorization"
13
+ REVISION="main"
14
+
15
+ accelerate launch eval_controlnet.py \
16
+ --model_dir=$MODEL_DIR \
17
+ --model_id=$MODEL_ID \
18
+ --dataset=$DATASET \
19
+ --revision=$REVISION
eval_controlnet_sdxl_light.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import shutil
5
+ import argparse
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from datasets import load_dataset
11
+ from accelerate import Accelerator
12
+ from diffusers.utils import load_image
13
+ from diffusers import (
14
+ AutoencoderKL,
15
+ StableDiffusionXLControlNetPipeline,
16
+ ControlNetModel,
17
+ UNet2DConditionModel,
18
+ )
19
+ from huggingface_hub import hf_hub_download
20
+ from safetensors.torch import load_file
21
+
22
+ # Define the function to parse arguments
23
+ def parse_args(input_args=None):
24
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
25
+
26
+ parser.add_argument(
27
+ "--pretrained_model_name_or_path",
28
+ type=str,
29
+ default=None,
30
+ required=True,
31
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
32
+ )
33
+ parser.add_argument(
34
+ "--pretrained_vae_model_name_or_path",
35
+ type=str,
36
+ default=None,
37
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
38
+ )
39
+ parser.add_argument(
40
+ "--controlnet_model_name_or_path",
41
+ type=str,
42
+ default=None,
43
+ required=True,
44
+ help="Path to pretrained controlnet model.",
45
+ )
46
+ parser.add_argument(
47
+ "--output_dir",
48
+ type=str,
49
+ default=None,
50
+ required=True,
51
+ help="Path to output results.",
52
+ )
53
+ parser.add_argument(
54
+ "--dataset",
55
+ type=str,
56
+ default="nickpai/coco2017-colorization",
57
+ help="Dataset used"
58
+ )
59
+ parser.add_argument(
60
+ "--dataset_revision",
61
+ type=str,
62
+ default="caption-free",
63
+ choices=["main", "caption-free", "custom-caption"],
64
+ help="Revision option (main/caption-free/custom-caption)"
65
+ )
66
+ parser.add_argument(
67
+ "--mixed_precision",
68
+ type=str,
69
+ default=None,
70
+ choices=["no", "fp16", "bf16"],
71
+ help=(
72
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
73
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
74
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
75
+ ),
76
+ )
77
+ parser.add_argument(
78
+ "--variant",
79
+ type=str,
80
+ default=None,
81
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
82
+ )
83
+ parser.add_argument(
84
+ "--revision",
85
+ type=str,
86
+ default=None,
87
+ required=False,
88
+ help="Revision of pretrained model identifier from huggingface.co/models.",
89
+ )
90
+ parser.add_argument(
91
+ "--num_inference_steps",
92
+ type=int,
93
+ default=8,
94
+ help="1-step, 2-step, 4-step, or 8-step distilled models"
95
+ )
96
+ parser.add_argument(
97
+ "--repo",
98
+ type=str,
99
+ default="ByteDance/SDXL-Lightning",
100
+ required=True,
101
+ help="Repository from huggingface.co",
102
+ )
103
+ parser.add_argument(
104
+ "--ckpt",
105
+ type=str,
106
+ default="sdxl_lightning_4step_unet.safetensors",
107
+ required=True,
108
+ help="Available checkpoints from the repository",
109
+ )
110
+ parser.add_argument(
111
+ "--negative_prompt",
112
+ action="store_true",
113
+ help="The prompt or prompts not to guide the image generation",
114
+ )
115
+
116
+ if input_args is not None:
117
+ args = parser.parse_args(input_args)
118
+ else:
119
+ args = parser.parse_args()
120
+
121
+ return args
122
+
123
+ def apply_color(image, color_map):
124
+ # Convert input images to LAB color space
125
+ image_lab = image.convert('LAB')
126
+ color_map_lab = color_map.convert('LAB')
127
+
128
+ # Split LAB channels
129
+ l, a, b = image_lab.split()
130
+ _, a_map, b_map = color_map_lab.split()
131
+
132
+ # Merge LAB channels with color map
133
+ merged_lab = Image.merge('LAB', (l, a_map, b_map))
134
+
135
+ # Convert merged LAB image back to RGB color space
136
+ result_rgb = merged_lab.convert('RGB')
137
+
138
+ return result_rgb
139
+
140
+ def main(args):
141
+ generator = torch.manual_seed(0)
142
+
143
+ # Path to the eval_results folder
144
+ eval_results_folder = os.path.join(args.output_dir, "results")
145
+
146
+ # Remove eval_results folder if it exists
147
+ if os.path.exists(eval_results_folder):
148
+ shutil.rmtree(eval_results_folder)
149
+
150
+ # Create directory for eval_results
151
+ os.makedirs(eval_results_folder)
152
+
153
+ # Create subfolders for compare and colorized images
154
+ compare_folder = os.path.join(eval_results_folder, "compare")
155
+ colorized_folder = os.path.join(eval_results_folder, "colorized")
156
+ os.makedirs(compare_folder)
157
+ os.makedirs(colorized_folder)
158
+
159
+ # Load the validation split of the colorization dataset
160
+ val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision)
161
+
162
+ accelerator = Accelerator(
163
+ mixed_precision=args.mixed_precision,
164
+ )
165
+
166
+ weight_dtype = torch.float32
167
+ if accelerator.mixed_precision == "fp16":
168
+ weight_dtype = torch.float16
169
+ elif accelerator.mixed_precision == "bf16":
170
+ weight_dtype = torch.bfloat16
171
+
172
+ vae_path = (
173
+ args.pretrained_model_name_or_path
174
+ if args.pretrained_vae_model_name_or_path is None
175
+ else args.pretrained_vae_model_name_or_path
176
+ )
177
+ vae = AutoencoderKL.from_pretrained(
178
+ vae_path,
179
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
180
+ revision=args.revision,
181
+ variant=args.variant,
182
+ )
183
+ unet = UNet2DConditionModel.from_config(
184
+ args.pretrained_model_name_or_path,
185
+ subfolder="unet",
186
+ revision=args.revision,
187
+ variant=args.variant,
188
+ )
189
+ unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt)))
190
+
191
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
192
+ # The VAE is in float32 to avoid NaN losses.
193
+ if args.pretrained_vae_model_name_or_path is not None:
194
+ vae.to(accelerator.device, dtype=weight_dtype)
195
+ else:
196
+ vae.to(accelerator.device, dtype=torch.float32)
197
+ unet.to(accelerator.device, dtype=weight_dtype)
198
+
199
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype)
200
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
201
+ args.pretrained_model_name_or_path,
202
+ vae=vae,
203
+ unet=unet,
204
+ controlnet=controlnet,
205
+ )
206
+ pipe.to(accelerator.device, dtype=weight_dtype)
207
+
208
+ # Prepare everything with our `accelerator`.
209
+ pipe, val_dataset = accelerator.prepare(pipe, val_dataset)
210
+
211
+ pipe.safety_checker = None
212
+
213
+ # Counter for processed images
214
+ processed_images = 0
215
+
216
+ # Record start time
217
+ start_time = time.time()
218
+
219
+ # Iterate through the validation dataset
220
+ for example in tqdm(val_dataset, desc="Processing Images"):
221
+ image_path = example["file_name"]
222
+
223
+ prompt = []
224
+ for caption in example["captions"]:
225
+ if isinstance(caption, str):
226
+ prompt.append(caption)
227
+ elif isinstance(caption, (list, np.ndarray)):
228
+ # take a random caption if there are multiple
229
+ prompt.append(caption[0])
230
+ else:
231
+ raise ValueError(
232
+ f"Caption column `captions` should contain either strings or lists of strings."
233
+ )
234
+
235
+ negative_prompt = None
236
+ if args.negative_prompt:
237
+ negative_prompt = [
238
+ "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"
239
+ ]
240
+
241
+ # Generate image
242
+ ground_truth_image = load_image(image_path).resize((512, 512))
243
+ control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
244
+ image = pipe(prompt=prompt,
245
+ negative_prompt=negative_prompt,
246
+ num_inference_steps=args.num_inference_steps,
247
+ generator=generator,
248
+ image=control_image).images[0]
249
+
250
+ # Apply color mapping
251
+ image = apply_color(ground_truth_image, image)
252
+
253
+ # Concatenate images into a row
254
+ row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
255
+ row_image = Image.fromarray(row_image)
256
+
257
+ # Save row image in the compare folder
258
+ compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}")
259
+ row_image.save(compare_output_path)
260
+
261
+ # Save colorized image in the colorized folder
262
+ colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}")
263
+ image.save(colorized_output_path)
264
+
265
+ # Increment processed images counter
266
+ processed_images += 1
267
+
268
+ # Record end time
269
+ end_time = time.time()
270
+
271
+ # Calculate total time taken
272
+ total_time = end_time - start_time
273
+
274
+ # Calculate FPS
275
+ fps = processed_images / total_time
276
+
277
+ print("All images processed.")
278
+ print(f"Total time taken: {total_time:.2f} seconds")
279
+ print(f"FPS: {fps:.2f}")
280
+
281
+ # Entry point of the script
282
+ if __name__ == "__main__":
283
+ args = parse_args()
284
+ main(args)
eval_controlnet_sdxl_light.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Define default values for parameters
2
+
3
+ # # sdxl light without negative prompt
4
+ # export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
5
+ # export REPO="ByteDance/SDXL-Lightning"
6
+ # export INFERENCE_STEP=8
7
+ # export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step
8
+ # export CONTROLNET_MODEL="sdxl_light_custom_caption_output/checkpoint-12500/controlnet"
9
+ # export DATASET="nickpai/coco2017-colorization"
10
+ # export DATSET_REVISION="custom-caption"
11
+ # export OUTPUT_DIR="sdxl_light_custom_caption_output/checkpoint-12500"
12
+
13
+ # accelerate launch eval_controlnet_sdxl_light.py \
14
+ # --pretrained_model_name_or_path=$BASE_MODEL \
15
+ # --repo=$REPO \
16
+ # --ckpt=$CKPT \
17
+ # --num_inference_steps=$INFERENCE_STEP \
18
+ # --controlnet_model_name_or_path=$CONTROLNET_MODEL \
19
+ # --dataset=$DATASET \
20
+ # --dataset_revision=$DATSET_REVISION \
21
+ # --mixed_precision="fp16" \
22
+ # --output_dir=$OUTPUT_DIR
23
+
24
+ # sdxl light with negative prompt
25
+ export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
26
+ export REPO="ByteDance/SDXL-Lightning"
27
+ export INFERENCE_STEP=8
28
+ export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step
29
+ export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-22500/controlnet"
30
+ export DATASET="nickpai/coco2017-colorization"
31
+ export DATSET_REVISION="custom-caption"
32
+ export OUTPUT_DIR="sdxl_light_caption_output/checkpoint-22500"
33
+
34
+ accelerate launch eval_controlnet_sdxl_light.py \
35
+ --pretrained_model_name_or_path=$BASE_MODEL \
36
+ --repo=$REPO \
37
+ --ckpt=$CKPT \
38
+ --num_inference_steps=$INFERENCE_STEP \
39
+ --controlnet_model_name_or_path=$CONTROLNET_MODEL \
40
+ --dataset=$DATASET \
41
+ --dataset_revision=$DATSET_REVISION \
42
+ --mixed_precision="fp16" \
43
+ --output_dir=$OUTPUT_DIR \
44
+ --negative_prompt
eval_controlnet_sdxl_light_single.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PIL
3
+ import time
4
+ import torch
5
+ import argparse
6
+
7
+ from typing import Optional, Union
8
+ from accelerate import Accelerator
9
+ from diffusers import (
10
+ AutoencoderKL,
11
+ StableDiffusionXLControlNetPipeline,
12
+ ControlNetModel,
13
+ UNet2DConditionModel,
14
+ )
15
+ from transformers import (
16
+ BlipProcessor, BlipForConditionalGeneration,
17
+ VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
18
+ )
19
+ from huggingface_hub import hf_hub_download
20
+ from safetensors.torch import load_file
21
+
22
+ # Define the function to parse arguments
23
+ def parse_args(input_args=None):
24
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
25
+ parser.add_argument(
26
+ "--image_path",
27
+ type=str,
28
+ default="example/legacy_images/Hollywood-Sign.jpg",
29
+ required=True,
30
+ help="Path to the image",
31
+ )
32
+ parser.add_argument(
33
+ "--pretrained_model_name_or_path",
34
+ type=str,
35
+ default=None,
36
+ required=True,
37
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
38
+ )
39
+ parser.add_argument(
40
+ "--pretrained_vae_model_name_or_path",
41
+ type=str,
42
+ default=None,
43
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
44
+ )
45
+ parser.add_argument(
46
+ "--controlnet_model_name_or_path",
47
+ type=str,
48
+ default=None,
49
+ required=True,
50
+ help="Path to pretrained controlnet model.",
51
+ )
52
+ parser.add_argument(
53
+ "--caption_model_name",
54
+ type=str,
55
+ default="blip-image-captioning-large",
56
+ choices=["blip-image-captioning-large", "blip-image-captioning-base"],
57
+ help="Path to pretrained controlnet model.",
58
+ )
59
+ parser.add_argument(
60
+ "--mixed_precision",
61
+ type=str,
62
+ default=None,
63
+ choices=["no", "fp16", "bf16"],
64
+ help=(
65
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
66
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
67
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
68
+ ),
69
+ )
70
+ parser.add_argument(
71
+ "--variant",
72
+ type=str,
73
+ default=None,
74
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
75
+ )
76
+ parser.add_argument(
77
+ "--revision",
78
+ type=str,
79
+ default=None,
80
+ required=False,
81
+ help="Revision of pretrained model identifier from huggingface.co/models.",
82
+ )
83
+ parser.add_argument(
84
+ "--num_inference_steps",
85
+ type=int,
86
+ default=8,
87
+ help="1-step, 2-step, 4-step, or 8-step distilled models"
88
+ )
89
+ parser.add_argument(
90
+ "--repo",
91
+ type=str,
92
+ default="ByteDance/SDXL-Lightning",
93
+ required=True,
94
+ help="Repository from huggingface.co",
95
+ )
96
+ parser.add_argument(
97
+ "--ckpt",
98
+ type=str,
99
+ default="sdxl_lightning_4step_unet.safetensors",
100
+ required=True,
101
+ help="Available checkpoints from the repository",
102
+ )
103
+ parser.add_argument(
104
+ "--seed",
105
+ type=int,
106
+ default=123,
107
+ help="Random seeds"
108
+ )
109
+ parser.add_argument(
110
+ "--positive_prompt",
111
+ type=str,
112
+ help="Text for positive prompt",
113
+ )
114
+ parser.add_argument(
115
+ "--negative_prompt",
116
+ type=str,
117
+ default="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
118
+ help="Text for negative prompt",
119
+ )
120
+
121
+ if input_args is not None:
122
+ args = parser.parse_args(input_args)
123
+ else:
124
+ args = parser.parse_args()
125
+
126
+ return args
127
+
128
+ def apply_color(image, color_map):
129
+ # Convert input images to LAB color space
130
+ image_lab = image.convert('LAB')
131
+ color_map_lab = color_map.convert('LAB')
132
+
133
+ # Split LAB channels
134
+ l, a, b = image_lab.split()
135
+ _, a_map, b_map = color_map_lab.split()
136
+
137
+ # Merge LAB channels with color map
138
+ merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
139
+
140
+ # Convert merged LAB image back to RGB color space
141
+ result_rgb = merged_lab.convert('RGB')
142
+
143
+ return result_rgb
144
+
145
+ def remove_unlikely_words(prompt: str) -> str:
146
+ """
147
+ Removes unlikely words from a prompt.
148
+
149
+ Args:
150
+ prompt: The text prompt to be cleaned.
151
+
152
+ Returns:
153
+ The cleaned prompt with unlikely words removed.
154
+ """
155
+ unlikely_words = []
156
+
157
+ a1_list = [f'{i}s' for i in range(1900, 2000)]
158
+ a2_list = [f'{i}' for i in range(1900, 2000)]
159
+ a3_list = [f'year {i}' for i in range(1900, 2000)]
160
+ a4_list = [f'circa {i}' for i in range(1900, 2000)]
161
+ b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
162
+ b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
163
+ b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
164
+ b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
165
+
166
+ words_list = [
167
+ "black and white,", "black and white", "black & white,", "black & white", "circa",
168
+ "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
169
+ "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
170
+ "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
171
+ "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
172
+ "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
173
+ "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
174
+ "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
175
+ "black-and-white photo,", "black-and-white photo", "black - and - white photography",
176
+ "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
177
+ "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
178
+ "black - and - white photograph,", "black - and - white photograph", "black on white,",
179
+ "black on white", "black-and-white", "historical image,", "historical picture,",
180
+ "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
181
+ "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
182
+ "historical photo", "historical setting,",
183
+ "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
184
+ "taken in", "shot on leica", "shot on leica sl2", "sl2",
185
+ "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
186
+ "overcast day", "overcast weather", "slight overcast", "overcast",
187
+ "picture taken in", "photo taken in",
188
+ ", photo", ", photo", ", photo", ", photo", ", photograph",
189
+ ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
190
+ ]
191
+
192
+ unlikely_words.extend(a1_list)
193
+ unlikely_words.extend(a2_list)
194
+ unlikely_words.extend(a3_list)
195
+ unlikely_words.extend(a4_list)
196
+ unlikely_words.extend(b1_list)
197
+ unlikely_words.extend(b2_list)
198
+ unlikely_words.extend(b3_list)
199
+ unlikely_words.extend(b4_list)
200
+ unlikely_words.extend(words_list)
201
+
202
+ for word in unlikely_words:
203
+ prompt = prompt.replace(word, "")
204
+ return prompt
205
+
206
+ def blip_image_captioning(image: PIL.Image.Image,
207
+ model_backbone: str,
208
+ weight_dtype: type,
209
+ device: str,
210
+ conditional: bool) -> str:
211
+ # https://huggingface.co/Salesforce/blip-image-captioning-large
212
+ # https://huggingface.co/Salesforce/blip-image-captioning-base
213
+ if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type
214
+ weight_dtype = torch.float16
215
+
216
+ processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
217
+ model = BlipForConditionalGeneration.from_pretrained(
218
+ f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)
219
+
220
+ valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
221
+ if model_backbone not in valid_backbones:
222
+ raise ValueError(f"Invalid model backbone '{model_backbone}'. \
223
+ Valid options are: {', '.join(valid_backbones)}")
224
+
225
+ if conditional:
226
+ text = "a photography of"
227
+ inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
228
+ else:
229
+ inputs = processor(image, return_tensors="pt").to(device)
230
+ out = model.generate(**inputs)
231
+ caption = processor.decode(out[0], skip_special_tokens=True)
232
+ return caption
233
+
234
+ import matplotlib.pyplot as plt
235
+
236
+ def display_images(input_image, output_image, ground_truth):
237
+ """
238
+ Displays a grid of input, output, ground truth images with a caption at the bottom.
239
+
240
+ Args:
241
+ input_image: A grayscale image as a NumPy array.
242
+ output_image: A grayscale image (result) as a NumPy array.
243
+ ground_truth: A grayscale image (ground truth) as a NumPy array.
244
+ """
245
+ fig, axes = plt.subplots(1, 3, figsize=(20, 8))
246
+
247
+ axes[0].imshow(input_image, cmap='gray')
248
+ axes[0].set_title('Input')
249
+ axes[0].axis('off')
250
+
251
+ axes[1].imshow(output_image)
252
+ axes[1].set_title('Output')
253
+ axes[1].axis('off')
254
+
255
+ axes[2].imshow(ground_truth)
256
+ axes[2].set_title('Ground Truth')
257
+ axes[2].axis('off')
258
+
259
+ plt.tight_layout()
260
+ plt.show()
261
+
262
+ # Define a function to process the image with the loaded model
263
+ def process_image(image_path: str,
264
+ controlnet_model_name_or_path: str,
265
+ caption_model_name: str,
266
+ positive_prompt: Optional[str],
267
+ negative_prompt: Optional[str],
268
+ seed: int,
269
+ num_inference_steps: int,
270
+ mixed_precision: str,
271
+ pretrained_model_name_or_path: str,
272
+ pretrained_vae_model_name_or_path: Optional[str],
273
+ revision: Optional[str],
274
+ variant: Optional[str],
275
+ repo: str,
276
+ ckpt: str,) -> PIL.Image.Image:
277
+ # Seed
278
+ generator = torch.manual_seed(seed)
279
+
280
+ # Accelerator Setting
281
+ accelerator = Accelerator(
282
+ mixed_precision=mixed_precision,
283
+ )
284
+
285
+ weight_dtype = torch.float32
286
+ if accelerator.mixed_precision == "fp16":
287
+ weight_dtype = torch.float16
288
+ elif accelerator.mixed_precision == "bf16":
289
+ weight_dtype = torch.bfloat16
290
+
291
+ vae_path = (
292
+ pretrained_model_name_or_path
293
+ if pretrained_vae_model_name_or_path is None
294
+ else pretrained_vae_model_name_or_path
295
+ )
296
+ vae = AutoencoderKL.from_pretrained(
297
+ vae_path,
298
+ subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
299
+ revision=revision,
300
+ variant=variant,
301
+ )
302
+ unet = UNet2DConditionModel.from_config(
303
+ pretrained_model_name_or_path,
304
+ subfolder="unet",
305
+ revision=revision,
306
+ variant=variant,
307
+ )
308
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
309
+
310
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
311
+ # The VAE is in float32 to avoid NaN losses.
312
+ if pretrained_vae_model_name_or_path is not None:
313
+ vae.to(accelerator.device, dtype=weight_dtype)
314
+ else:
315
+ vae.to(accelerator.device, dtype=torch.float32)
316
+ unet.to(accelerator.device, dtype=weight_dtype)
317
+
318
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
319
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
320
+ pretrained_model_name_or_path,
321
+ vae=vae,
322
+ unet=unet,
323
+ controlnet=controlnet,
324
+ )
325
+ pipe.to(accelerator.device, dtype=weight_dtype)
326
+
327
+ image = PIL.Image.open(image_path)
328
+
329
+ # Prepare everything with our `accelerator`.
330
+ pipe, image = accelerator.prepare(pipe, image)
331
+ pipe.safety_checker = None
332
+
333
+ # Convert image into grayscale
334
+ original_size = image.size
335
+ control_image = image.convert("L").convert("RGB").resize((512, 512))
336
+
337
+ # Image captioning
338
+ if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base":
339
+ caption = blip_image_captioning(control_image, caption_model_name,
340
+ weight_dtype, accelerator.device, conditional=True)
341
+ # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k":
342
+ # caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
343
+ # elif caption_model_name == "vit-gpt2-image-captioning":
344
+ # caption = vit_gpt2_image_captioning(control_image, accelerator.device)
345
+ caption = remove_unlikely_words(caption)
346
+
347
+ print("================================================================")
348
+ print(f"Positive prompt: \n>>> {positive_prompt}")
349
+ print(f"Negative prompt: \n>>> {negative_prompt}")
350
+ print(f"Caption results: \n>>> {caption}")
351
+ print("================================================================")
352
+
353
+ # Combine positive prompt and captioning result
354
+ prompt = [positive_prompt + ", " + caption]
355
+
356
+ # Image colorization
357
+ image = pipe(prompt=prompt,
358
+ negative_prompt=negative_prompt,
359
+ num_inference_steps=num_inference_steps,
360
+ generator=generator,
361
+ image=control_image).images[0]
362
+
363
+ # Apply color mapping
364
+ result_image = apply_color(control_image, image)
365
+ result_image = result_image.resize(original_size)
366
+ return result_image, caption
367
+
368
+ def main(args):
369
+ output_image, output_caption = process_image(image_path=args.image_path,
370
+ controlnet_model_name_or_path=args.controlnet_model_name_or_path,
371
+ caption_model_name=args.caption_model_name,
372
+ positive_prompt=args.positive_prompt,
373
+ negative_prompt=args.negative_prompt,
374
+ seed=args.seed,
375
+ num_inference_steps=args.num_inference_steps,
376
+ mixed_precision=args.mixed_precision,
377
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
378
+ pretrained_vae_model_name_or_path=args.pretrained_vae_model_name_or_path,
379
+ revision=args.revision,
380
+ variant=args.variant,
381
+ repo=args.repo,
382
+ ckpt=args.ckpt,)
383
+ input_image = PIL.Image.open(args.image_path)
384
+ display_images(input_image.convert("L"), output_image, input_image)
385
+ return output_image, output_caption
386
+
387
+ # Entry point of the script
388
+ if __name__ == "__main__":
389
+ args = parse_args()
390
+ main(args)
eval_controlnet_sdxl_light_single.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sdxl light for single image
2
+ export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
3
+ export REPO="ByteDance/SDXL-Lightning"
4
+ export INFERENCE_STEP=8
5
+ export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step
6
+ export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-30000/controlnet"
7
+ export CAPTION_MODEL="blip-image-captioning-large"
8
+ export IMAGE_PATH="example/legacy_images/Hollywood-Sign.jpg"
9
+ # export POSITIVE_PROMPT="blue shirt"
10
+
11
+ accelerate launch eval_controlnet_sdxl_light_single.py \
12
+ --pretrained_model_name_or_path=$BASE_MODEL \
13
+ --repo=$REPO \
14
+ --ckpt=$CKPT \
15
+ --num_inference_steps=$INFERENCE_STEP \
16
+ --controlnet_model_name_or_path=$CONTROLNET_MODEL \
17
+ --caption_model_name=$CAPTION_MODEL \
18
+ --mixed_precision="fp16" \
19
+ --image_path=$IMAGE_PATH \
20
+ --positive_prompt="red car"
example/UUColor_results/Hollywood-Sign.jpeg ADDED
example/legacy_images/Big-Ben-vintage.jpg ADDED
example/legacy_images/Central-Park.jpg ADDED
example/legacy_images/Hollywood-Sign.jpg ADDED
example/legacy_images/Little-Mermaid.jpg ADDED
example/legacy_images/Migrant-Mother.jpg ADDED
example/legacy_images/Mount-Everest.jpg ADDED
example/legacy_images/Tower-of-Pisa.jpg ADDED
example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg ADDED
gradio_ui.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import torch
3
+ import subprocess
4
+ import gradio as gr
5
+
6
+ from typing import Optional
7
+ from accelerate import Accelerator
8
+ from diffusers import (
9
+ AutoencoderKL,
10
+ StableDiffusionXLControlNetPipeline,
11
+ ControlNetModel,
12
+ UNet2DConditionModel,
13
+ )
14
+ from transformers import (
15
+ BlipProcessor, BlipForConditionalGeneration,
16
+ VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
17
+ )
18
+ from huggingface_hub import hf_hub_download
19
+ from safetensors.torch import load_file
20
+ from clip_interrogator import Interrogator, Config, list_clip_models
21
+
22
+ def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
23
+ # Convert input images to LAB color space
24
+ image_lab = image.convert('LAB')
25
+ color_map_lab = color_map.convert('LAB')
26
+
27
+ # Split LAB channels
28
+ l, a , b = image_lab.split()
29
+ _, a_map, b_map = color_map_lab.split()
30
+
31
+ # Merge LAB channels with color map
32
+ merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
33
+
34
+ # Convert merged LAB image back to RGB color space
35
+ result_rgb = merged_lab.convert('RGB')
36
+ return result_rgb
37
+
38
+ def remove_unlikely_words(prompt: str) -> str:
39
+ """
40
+ Removes unlikely words from a prompt.
41
+
42
+ Args:
43
+ prompt: The text prompt to be cleaned.
44
+
45
+ Returns:
46
+ The cleaned prompt with unlikely words removed.
47
+ """
48
+ unlikely_words = []
49
+
50
+ a1_list = [f'{i}s' for i in range(1900, 2000)]
51
+ a2_list = [f'{i}' for i in range(1900, 2000)]
52
+ a3_list = [f'year {i}' for i in range(1900, 2000)]
53
+ a4_list = [f'circa {i}' for i in range(1900, 2000)]
54
+ b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
55
+ b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
56
+ b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
57
+ b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
58
+
59
+ words_list = [
60
+ "black and white,", "black and white", "black & white,", "black & white", "circa",
61
+ "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
62
+ "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
63
+ "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
64
+ "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
65
+ "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
66
+ "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
67
+ "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
68
+ "black-and-white photo,", "black-and-white photo", "black - and - white photography",
69
+ "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
70
+ "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
71
+ "black - and - white photograph,", "black - and - white photograph", "black on white,",
72
+ "black on white", "black-and-white", "historical image,", "historical picture,",
73
+ "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
74
+ "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
75
+ "historical photo", "historical setting,",
76
+ "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
77
+ "taken in", "shot on leica", "shot on leica sl2", "sl2",
78
+ "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
79
+ "overcast day", "overcast weather", "slight overcast", "overcast",
80
+ "picture taken in", "photo taken in",
81
+ ", photo", ", photo", ", photo", ", photo", ", photograph",
82
+ ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
83
+ ]
84
+
85
+ unlikely_words.extend(a1_list)
86
+ unlikely_words.extend(a2_list)
87
+ unlikely_words.extend(a3_list)
88
+ unlikely_words.extend(a4_list)
89
+ unlikely_words.extend(b1_list)
90
+ unlikely_words.extend(b2_list)
91
+ unlikely_words.extend(b3_list)
92
+ unlikely_words.extend(b4_list)
93
+ unlikely_words.extend(words_list)
94
+
95
+ for word in unlikely_words:
96
+ prompt = prompt.replace(word, "")
97
+ return prompt
98
+
99
+ def blip_image_captioning(image: PIL.Image.Image,
100
+ model_backbone: str,
101
+ weight_dtype: type,
102
+ device: str,
103
+ conditional: bool) -> str:
104
+ # https://huggingface.co/Salesforce/blip-image-captioning-large
105
+ # https://huggingface.co/Salesforce/blip-image-captioning-base
106
+ if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type
107
+ weight_dtype = torch.float16
108
+
109
+ processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
110
+ model = BlipForConditionalGeneration.from_pretrained(
111
+ f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)
112
+
113
+ valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
114
+ if model_backbone not in valid_backbones:
115
+ raise ValueError(f"Invalid model backbone '{model_backbone}'. \
116
+ Valid options are: {', '.join(valid_backbones)}")
117
+
118
+ if conditional:
119
+ text = "a photography of"
120
+ inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
121
+ else:
122
+ inputs = processor(image, return_tensors="pt").to(device)
123
+ out = model.generate(**inputs)
124
+ caption = processor.decode(out[0], skip_special_tokens=True)
125
+ return caption
126
+
127
+ # def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str:
128
+ # # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
129
+ # model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
130
+ # feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
131
+ # tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
132
+
133
+ # max_length = 16
134
+ # num_beams = 4
135
+ # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
136
+
137
+ # pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
138
+ # pixel_values = pixel_values.to(device)
139
+
140
+ # output_ids = model.generate(pixel_values, **gen_kwargs)
141
+
142
+ # preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
143
+ # caption = [pred.strip() for pred in preds]
144
+
145
+ # return caption[0]
146
+
147
+ # def clip_image_captioning(image: PIL.Image.Image,
148
+ # clip_model_name: str,
149
+ # device: str) -> str:
150
+ # # validate clip model name
151
+ # models = list_clip_models()
152
+ # if clip_model_name not in models:
153
+ # raise ValueError(f"Could not find CLIP model {clip_model_name}! \
154
+ # Available models: {models}")
155
+ # config = Config(device=device, clip_model_name=clip_model_name)
156
+ # config.apply_low_vram_defaults()
157
+ # ci = Interrogator(config)
158
+ # caption = ci.interrogate(image)
159
+ # return caption
160
+
161
+ # Define a function to process the image with the loaded model
162
+ def process_image(image_path: str,
163
+ controlnet_model_name_or_path: str,
164
+ caption_model_name: str,
165
+ positive_prompt: Optional[str],
166
+ negative_prompt: Optional[str],
167
+ seed: int,
168
+ num_inference_steps: int,
169
+ mixed_precision: str,
170
+ pretrained_model_name_or_path: str,
171
+ pretrained_vae_model_name_or_path: Optional[str],
172
+ revision: Optional[str],
173
+ variant: Optional[str],
174
+ repo: str,
175
+ ckpt: str,) -> PIL.Image.Image:
176
+ # Seed
177
+ generator = torch.manual_seed(seed)
178
+
179
+ # Accelerator Setting
180
+ accelerator = Accelerator(
181
+ mixed_precision=mixed_precision,
182
+ )
183
+
184
+ weight_dtype = torch.float32
185
+ if accelerator.mixed_precision == "fp16":
186
+ weight_dtype = torch.float16
187
+ elif accelerator.mixed_precision == "bf16":
188
+ weight_dtype = torch.bfloat16
189
+
190
+ vae_path = (
191
+ pretrained_model_name_or_path
192
+ if pretrained_vae_model_name_or_path is None
193
+ else pretrained_vae_model_name_or_path
194
+ )
195
+ vae = AutoencoderKL.from_pretrained(
196
+ vae_path,
197
+ subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
198
+ revision=revision,
199
+ variant=variant,
200
+ )
201
+ unet = UNet2DConditionModel.from_config(
202
+ pretrained_model_name_or_path,
203
+ subfolder="unet",
204
+ revision=revision,
205
+ variant=variant,
206
+ )
207
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
208
+
209
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
210
+ # The VAE is in float32 to avoid NaN losses.
211
+ if pretrained_vae_model_name_or_path is not None:
212
+ vae.to(accelerator.device, dtype=weight_dtype)
213
+ else:
214
+ vae.to(accelerator.device, dtype=torch.float32)
215
+ unet.to(accelerator.device, dtype=weight_dtype)
216
+
217
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
218
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
219
+ pretrained_model_name_or_path,
220
+ vae=vae,
221
+ unet=unet,
222
+ controlnet=controlnet,
223
+ )
224
+ pipe.to(accelerator.device, dtype=weight_dtype)
225
+
226
+ image = PIL.Image.open(image_path)
227
+
228
+ # Prepare everything with our `accelerator`.
229
+ pipe, image = accelerator.prepare(pipe, image)
230
+ pipe.safety_checker = None
231
+
232
+ # Convert image into grayscale
233
+ original_size = image.size
234
+ control_image = image.convert("L").convert("RGB").resize((512, 512))
235
+
236
+ # Image captioning
237
+ if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base":
238
+ caption = blip_image_captioning(control_image, caption_model_name,
239
+ weight_dtype, accelerator.device, conditional=True)
240
+ # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k":
241
+ # caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
242
+ # elif caption_model_name == "vit-gpt2-image-captioning":
243
+ # caption = vit_gpt2_image_captioning(control_image, accelerator.device)
244
+ caption = remove_unlikely_words(caption)
245
+
246
+ # Combine positive prompt and captioning result
247
+ prompt = [positive_prompt + ", " + caption]
248
+
249
+ # Image colorization
250
+ image = pipe(prompt=prompt,
251
+ negative_prompt=negative_prompt,
252
+ num_inference_steps=num_inference_steps,
253
+ generator=generator,
254
+ image=control_image).images[0]
255
+
256
+ # Apply color mapping
257
+ result_image = apply_color(control_image, image)
258
+ result_image = result_image.resize(original_size)
259
+ return result_image, caption
260
+
261
+ # Define the image gallery based on folder path
262
+ def get_image_paths(folder_path):
263
+ import os
264
+ image_paths = []
265
+ for filename in os.listdir(folder_path):
266
+ if filename.endswith(".jpg") or filename.endswith(".png"):
267
+ image_paths.append([os.path.join(folder_path, filename)])
268
+ return image_paths
269
+
270
+ # Create the Gradio interface
271
+ def create_interface():
272
+ controlnet_model_dict = {
273
+ "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
274
+ "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
275
+ }
276
+ images = get_image_paths("example/legacy_images") # Replace with your folder path
277
+
278
+ interface = gr.Interface(
279
+ fn=process_image,
280
+ inputs=[
281
+ gr.Image(label="Upload image",
282
+ value="example/legacy_images/Hollywood-Sign.jpg",
283
+ type='filepath'),
284
+ gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
285
+ value=controlnet_model_dict["sdxl-light-caption-30000"],
286
+ label="Select ControlNet Model"),
287
+ gr.Dropdown(choices=["blip-image-captioning-large",
288
+ "blip-image-captioning-base",],
289
+ value="blip-image-captioning-large",
290
+ label="Select Image Captioning Model"),
291
+ gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
292
+ gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
293
+ label="Negative Prompt", placeholder="Text for negative prompt"),
294
+ ],
295
+ outputs=[
296
+ gr.Image(label="Colorized image",
297
+ value="example/UUColor_results/Hollywood-Sign.jpeg",
298
+ format="jpeg"),
299
+ gr.Textbox(label="Captioning Result", show_copy_button=True)
300
+ ],
301
+ examples=images,
302
+ additional_inputs=[
303
+ # gr.Radio(choices=["Original", "Square"], value="Original",
304
+ # label="Output resolution"),
305
+ # gr.Slider(minimum=128, maximum=512, value=256, step=128,
306
+ # label="Height & Width",
307
+ # info='Only effect if select "Square" output resolution'),
308
+ gr.Slider(0, 1000, 123, label="Seed"),
309
+ gr.Radio(choices=[1, 2, 4, 8],
310
+ value=8,
311
+ label="Inference Steps",
312
+ info="1-step, 2-step, 4-step, or 8-step distilled models"),
313
+ gr.Radio(choices=["no", "fp16", "bf16"],
314
+ value="fp16",
315
+ label="Mixed Precision",
316
+ info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."),
317
+ gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"],
318
+ value="stabilityai/stable-diffusion-xl-base-1.0",
319
+ label="Base Model",
320
+ info="Path to pretrained model or model identifier from huggingface.co/models."),
321
+ gr.Dropdown(choices=["None"],
322
+ value=None,
323
+ label="VAE Model",
324
+ info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."),
325
+ gr.Dropdown(choices=["None"],
326
+ value=None,
327
+ label="Varient",
328
+ info="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"),
329
+ gr.Dropdown(choices=["None"],
330
+ value=None,
331
+ label="Revision",
332
+ info="Revision of pretrained model identifier from huggingface.co/models."),
333
+ gr.Dropdown(choices=["ByteDance/SDXL-Lightning"],
334
+ value="ByteDance/SDXL-Lightning",
335
+ label="Repository",
336
+ info="Repository from huggingface.co"),
337
+ gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors",
338
+ "sdxl_lightning_2step_unet.safetensors",
339
+ "sdxl_lightning_4step_unet.safetensors",
340
+ "sdxl_lightning_8step_unet.safetensors"],
341
+ value="sdxl_lightning_8step_unet.safetensors",
342
+ label="Checkpoint",
343
+ info="Available checkpoints from the repository. Caution! Checkpoint's 'N'step must match with inference steps"),
344
+ ],
345
+ title="Text-Guided Image Colorization",
346
+ description="Upload an image and select a model to colorize it."
347
+ )
348
+ return interface
349
+
350
+ def main():
351
+ # Launch the Gradio interface
352
+ interface = create_interface()
353
+ interface.launch()
354
+
355
+ if __name__ == "__main__":
356
+ main()
images/000000022935_gray.jpg ADDED
images/000000022935_green_shirt_on_right_girl.jpeg ADDED
images/000000022935_purple_shirt_on_right_girl.jpeg ADDED
images/000000022935_red_shirt_on_right_girl.jpeg ADDED
images/000000025560_color.jpg ADDED
images/000000025560_gray.jpg ADDED
images/000000025560_gt.jpg ADDED
images/000000041633_black_car.jpeg ADDED
images/000000041633_bright_red_car.jpeg ADDED
images/000000041633_dark_blue_car.jpeg ADDED
images/000000041633_gray.jpg ADDED
images/000000065736_color.jpg ADDED
images/000000065736_gray.jpg ADDED
images/000000065736_gt.jpg ADDED
images/000000091779_color.jpg ADDED
images/000000091779_gray.jpg ADDED
images/000000091779_gt.jpg ADDED
images/000000092177_color.jpg ADDED
images/000000092177_gray.jpg ADDED
images/000000092177_gt.jpg ADDED
images/000000166426_color.jpg ADDED
images/000000166426_gray.jpg ADDED
images/000000166426_gt.jpg ADDED
images/000000286708_gray.jpg ADDED
images/000000286708_orange_hat.jpeg ADDED
images/000000286708_pink_hat.jpeg ADDED
images/000000286708_yellow_hat.jpeg ADDED