duanyuxuan committed
Commit 69cffc7 · 1 Parent(s): c1b3775

demo starts

Files changed (7)
  1. .gitattributes +0 -0
  2. .gitignore +4 -0
  3. README.md +0 -0
  4. app.py +165 -129
  5. requirements.txt +1 -1
  6. tdd_svd_scheduler.py +487 -0
  7. utils.py +37 -0
.gitattributes CHANGED
File without changes
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ /examples
3
+ /outputs_gradio
4
+ svd-xt-1-1_tdd_lora_weights.safetensors
README.md CHANGED
File without changes
app.py CHANGED
@@ -1,142 +1,178 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
- #import spaces #[uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
  import torch
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
10
 
11
- if torch.cuda.is_available():
12
- torch_dtype = torch.float16
13
  else:
14
- torch_dtype = torch.float32
 
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
18
 
19
- MAX_SEED = np.iinfo(np.int32).max
20
- MAX_IMAGE_SIZE = 1024
21
 
22
- #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
24
 
25
  if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
-
28
- generator = torch.Generator().manual_seed(seed)
29
-
30
- image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
- ).images[0]
39
-
40
- return image, seed
41
-
42
- examples = [
43
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44
- "An astronaut riding a green horse",
45
- "A delicious ceviche cheesecake slice",
46
- ]
47
-
48
- css="""
49
- #col-container {
50
- margin: 0 auto;
51
- max-width: 640px;
52
- }
53
- """
54
-
55
- with gr.Blocks(css=css) as demo:
56
-
57
- with gr.Column(elem_id="col-container"):
58
- gr.Markdown(f"""
59
- # Text-to-Image Gradio Template
60
- """)
61
-
62
- with gr.Row():
63
-
64
- prompt = gr.Text(
65
- label="Prompt",
66
- show_label=False,
67
- max_lines=1,
68
- placeholder="Enter your prompt",
69
- container=False,
70
- )
71
-
72
- run_button = gr.Button("Run", scale=0)
73
-
74
- result = gr.Image(label="Result", show_label=False)
75
-
76
- with gr.Accordion("Advanced Settings", open=False):
77
-
78
- negative_prompt = gr.Text(
79
- label="Negative prompt",
80
- max_lines=1,
81
- placeholder="Enter a negative prompt",
82
- visible=False,
83
- )
84
-
85
- seed = gr.Slider(
86
- label="Seed",
87
- minimum=0,
88
- maximum=MAX_SEED,
89
- step=1,
90
- value=0,
91
- )
92
-
93
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
95
- with gr.Row():
96
-
97
- width = gr.Slider(
98
- label="Width",
99
- minimum=256,
100
- maximum=MAX_IMAGE_SIZE,
101
- step=32,
102
- value=1024, #Replace with defaults that work for your model
103
- )
104
-
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=256,
108
- maximum=MAX_IMAGE_SIZE,
109
- step=32,
110
- value=1024, #Replace with defaults that work for your model
111
- )
112
-
113
- with gr.Row():
114
-
115
- guidance_scale = gr.Slider(
116
- label="Guidance scale",
117
- minimum=0.0,
118
- maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
121
- )
122
-
123
- num_inference_steps = gr.Slider(
124
- label="Number of inference steps",
125
- minimum=1,
126
- maximum=50,
127
- step=1,
128
- value=2, #Replace with defaults that work for your model
129
- )
130
-
131
- gr.Examples(
132
- examples = examples,
133
- inputs = [prompt]
134
  )
135
- gr.on(
136
- triggers=[run_button.click, prompt.submit],
137
- fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
140
  )
141
 
142
- demo.queue().launch()
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
+ import torchvision as tv
5
+ import random, os
6
+ from diffusers import StableVideoDiffusionPipeline
7
+ from PIL import Image
8
+ from glob import glob
9
+ from typing import Optional
10
 
11
+ from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler
12
+ from utils import load_lora_weights, save_video
13
 
14
+ # LOCAL = True
15
+ LOCAL = False
16
+
17
+ if LOCAL:
18
+ svd_path = '/share2/duanyuxuan/diff_playground/diffusers_models/stable-video-diffusion-img2vid-xt-1-1'
19
+ lora_file_path = '/share2/duanyuxuan/diff_playground/SVD-TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
20
  else:
21
+ svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
22
+ lora_file_path = 'RED-AIGC/TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
23
 
24
+ if torch.cuda.is_available():
25
+ noise_scheduler = TDDSVDStochasticIterativeScheduler(num_train_timesteps = 250, sigma_min = 0.002, sigma_max = 700.0, sigma_data = 1.0,
26
+ s_noise = 1.0, rho = 7, clip_denoised = False)
27
+
28
+ pipeline = StableVideoDiffusionPipeline.from_pretrained(svd_path, scheduler = noise_scheduler, torch_dtype = torch.float16, variant = "fp16").to('cuda')
29
+ load_lora_weights(pipeline.unet, lora_file_path)
30
 
31
+ max_64_bit_int = 2**63 - 1
 
32
 
33
+ @spaces.GPU
34
+ def sample(
35
+ image: Image,
36
+ seed: Optional[int] = 1,
37
+ randomize_seed: bool = False,
38
+ num_inference_steps: int = 4,
39
+ eta: float = 0.3,
40
+ min_guidance_scale: float = 1.0,
41
+ max_guidance_scale: float = 1.0,
42
+
43
+ fps: int = 7,
44
+ width: int = 512,
45
+ height: int = 512,
46
+ num_frames: int = 25,
47
+ motion_bucket_id: int = 127,
48
+ output_folder: str = "outputs_gradio",
49
+ ):
50
+ pipeline.scheduler.set_eta(eta)
51
 
52
  if randomize_seed:
53
+ seed = random.randint(0, max_64_bit_int)
54
+ generator = torch.manual_seed(seed)
55
+
56
+ os.makedirs(output_folder, exist_ok=True)
57
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
58
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
59
+
60
+ with torch.autocast("cuda"):
61
+ frames = pipeline(
62
+ image, height = height, width = width,
63
+ num_inference_steps = num_inference_steps,
64
+ min_guidance_scale = min_guidance_scale,
65
+ max_guidance_scale = max_guidance_scale,
66
+ num_frames = num_frames, fps = fps, motion_bucket_id = motion_bucket_id,
67
+ decode_chunk_size = 8,
68
+ noise_aug_strength = 0.02,
69
+ generator = generator,
70
+ ).frames[0]
71
+ save_video(frames, video_path, fps = fps, quality = 5.0)
72
+ torch.manual_seed(seed)
73
+
74
+ return video_path, seed
75
+
76
+
77
+ def preprocess_image(image, height = 512, width = 512):
78
+ image = image.convert('RGB')
79
+ if image.size[0] != image.size[1]:
80
+ image = tv.transforms.functional.pil_to_tensor(image)
81
+ image = tv.transforms.functional.center_crop(image, min(image.shape[-2:]))
82
+ image = tv.transforms.functional.to_pil_image(image)
83
+ image = image.resize((width, height))
84
+ return image
85
+
86
+
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown(
89
+ """
90
+ # Stable Video Diffusion distilled by ✨Target-Driven Distillation✨
91
+
92
+ Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation method that greatly accelerates the inference process of diffusion models. With its carefully designed strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images in only a few steps.
93
+
94
+ TDD can also be used to distill video generation models. This Space presents the TDD-distilled version of [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1).
95
+
96
+ [**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)
97
+
98
+ The code of this Space is built on [AnimateLCM-SVD](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD), and we acknowledge their contribution.
99
+ """
100
+ )
101
+ with gr.Row():
102
+ with gr.Column():
103
+ image = gr.Image(label="Upload your image", type="pil")
104
+ generate_btn = gr.Button("Generate")
105
+ video = gr.Video()
106
+ with gr.Accordion("Options", open = True):
107
+ seed = gr.Slider(
108
+ label="Seed",
109
+ value=1,
110
+ randomize=False,
111
+ minimum=0,
112
+ maximum=max_64_bit_int,
113
+ step=1,
114
  )
115
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
116
+ min_guidance_scale = gr.Slider(
117
+ label="Min guidance scale",
118
+ info="min strength of classifier-free guidance",
119
+ value=1.0,
120
+ minimum=1.0,
121
+ maximum=1.5,
122
+ )
123
+ max_guidance_scale = gr.Slider(
124
+ label="Max guidance scale",
125
+ info="max strength of classifier-free guidance, it should not be less than Min guidance scale",
126
+ value=1.0,
127
+ minimum=1.0,
128
+ maximum=3.0,
129
+ )
130
+ num_inference_steps = gr.Slider(
131
+ label="Num inference steps",
132
+ info="steps for inference",
133
+ value=4,
134
+ minimum=4,
135
+ maximum=8,
136
+ step=1,
137
+ )
138
+ eta = gr.Slider(
139
+ label = "Eta",
140
+ info = "the value of gamma in gamma-sampling",
141
+ value = 0.3,
142
+ minimum = 0.0,
143
+ maximum = 1.0,
144
+ step = 0.1,
145
+ )
146
+
147
+ image.upload(fn = preprocess_image, inputs = image, outputs = image, queue = False)
148
+ generate_btn.click(
149
+ fn = sample,
150
+ inputs = [
151
+ image,
152
+ seed,
153
+ randomize_seed,
154
+ num_inference_steps,
155
+ eta,
156
+ min_guidance_scale,
157
+ max_guidance_scale,
158
+ ],
159
+ outputs = [video, seed],
160
+ api_name = "video",
161
  )
162
+ # safetensors_dropdown.change(fn=model_select, inputs=safetensors_dropdown)
163
+
164
+ # gr.Examples(
165
+ # examples=[
166
+ # ["examples/ipadapter_cat.jpg"],
167
+ # ],
168
+ # inputs=[image],
169
+ # outputs=[video, seed],
170
+ # fn=sample,
171
+ # cache_examples=True,
172
+ # )
173
 
174
+ if __name__ == "__main__":
175
+ if LOCAL:
176
+ demo.queue().launch(share=True, server_name='0.0.0.0')
177
+ else:
178
+ demo.queue(api_open=False).launch(show_api=False)
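For reference, a minimal standalone sketch of the same inference path the new app.py wires up through Gradio, using the same scheduler arguments, model id, and LoRA path as above. The input image path here is hypothetical, and the call assumes a CUDA device is available:

import torch
from PIL import Image
from diffusers import StableVideoDiffusionPipeline
from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler
from utils import load_lora_weights, save_video

# Scheduler and pipeline set up as in app.py above.
noise_scheduler = TDDSVDStochasticIterativeScheduler(
    num_train_timesteps=250, sigma_min=0.002, sigma_max=700.0,
    sigma_data=1.0, s_noise=1.0, rho=7, clip_denoised=False)
pipeline = StableVideoDiffusionPipeline.from_pretrained(
    'stabilityai/stable-video-diffusion-img2vid-xt-1-1',
    scheduler=noise_scheduler, torch_dtype=torch.float16, variant="fp16").to('cuda')
load_lora_weights(pipeline.unet, 'RED-AIGC/TDD/svd-xt-1-1_tdd_lora_weights.safetensors')

pipeline.scheduler.set_eta(0.3)                  # gamma-sampling strength, the demo default
image = Image.open('input.jpg').convert('RGB')   # hypothetical input path
image = image.resize((512, 512))
with torch.autocast("cuda"):
    frames = pipeline(
        image, height=512, width=512,
        num_inference_steps=4,
        min_guidance_scale=1.0, max_guidance_scale=1.0,
        num_frames=25, fps=7, motion_bucket_id=127,
        decode_chunk_size=8, noise_aug_strength=0.02,
        generator=torch.manual_seed(1),
    ).frames[0]
save_video(frames, 'output.mp4', fps=7, quality=5.0)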
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  accelerate
2
  diffusers
3
- invisible_watermark
4
  torch
 
5
  transformers
6
  xformers
 
1
  accelerate
2
  diffusers
 
3
  torch
4
+ torchvision
5
  transformers
6
  xformers
tdd_svd_scheduler.py ADDED
@@ -0,0 +1,487 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import BaseOutput, logging
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class TDDSVDStochasticIterativeSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class TDDSVDStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Multistep and onestep sampling for consistency models.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 40):
53
+ The number of diffusion steps to train the model.
54
+ sigma_min (`float`, defaults to 0.002):
55
+ Minimum noise magnitude in the sigma schedule. Defaults to 0.002 from the original implementation.
56
+ sigma_max (`float`, defaults to 80.0):
57
+ Maximum noise magnitude in the sigma schedule. Defaults to 80.0 from the original implementation.
58
+ sigma_data (`float`, defaults to 0.5):
59
+ The standard deviation of the data distribution from the EDM
60
+ [paper](https://huggingface.co/papers/2206.00364). Defaults to 0.5 from the original implementation.
61
+ s_noise (`float`, defaults to 1.0):
62
+ The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
63
+ 1.011]. Defaults to 1.0 from the original implementation.
64
+ rho (`float`, defaults to 7.0):
65
+ The parameter for calculating the Karras sigma schedule from the EDM
66
+ [paper](https://huggingface.co/papers/2206.00364). Defaults to 7.0 from the original implementation.
67
+ clip_denoised (`bool`, defaults to `True`):
68
+ Whether to clip the denoised outputs to `(-1, 1)`.
69
+ timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*):
70
+ An explicit timestep schedule that can be optionally specified. The timesteps are expected to be in
71
+ increasing order.
72
+ """
73
+
74
+ order = 1
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ num_train_timesteps: int = 40,
80
+ sigma_min: float = 0.002,
81
+ sigma_max: float = 80.0,
82
+ sigma_data: float = 0.5,
83
+ s_noise: float = 1.0,
84
+ rho: float = 7.0,
85
+ clip_denoised: bool = True,
86
+ eta: float = 0.3,
87
+ ):
88
+ # standard deviation of the initial noise distribution
89
+ self.init_noise_sigma = (sigma_max**2 + 1) ** 0.5
90
+ # self.init_noise_sigma = sigma_max
91
+
92
+ ramp = np.linspace(0, 1, num_train_timesteps)
93
+ sigmas = self._convert_to_karras(ramp)
94
+ sigmas = np.concatenate([sigmas, np.array([0])])
95
+ timesteps = self.sigma_to_t(sigmas)
96
+
97
+ # setable values
98
+ self.num_inference_steps = None
99
+ self.sigmas = torch.from_numpy(sigmas)
100
+ self.timesteps = torch.from_numpy(timesteps)
101
+ self.custom_timesteps = False
102
+ self.is_scale_input_called = False
103
+ self._step_index = None
104
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
105
+
106
+ self.set_eta(eta)
107
+ self.original_timesteps = self.timesteps.clone()
108
+ self.original_sigmas = self.sigmas.clone()
109
+
110
+
111
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
112
+ if schedule_timesteps is None:
113
+ schedule_timesteps = self.timesteps
114
+
115
+ indices = (schedule_timesteps == timestep).nonzero()
116
+ return indices.item()
117
+
118
+ @property
119
+ def step_index(self):
120
+ """
121
+ The index counter for the current timestep. It will increase by 1 after each scheduler step.
122
+ """
123
+ return self._step_index
124
+
125
+ def scale_model_input(
126
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
127
+ ) -> torch.FloatTensor:
128
+ """
129
+ Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`.
130
+
131
+ Args:
132
+ sample (`torch.FloatTensor`):
133
+ The input sample.
134
+ timestep (`float` or `torch.FloatTensor`):
135
+ The current timestep in the diffusion chain.
136
+
137
+ Returns:
138
+ `torch.FloatTensor`:
139
+ A scaled input sample.
140
+ """
141
+ # Get sigma corresponding to timestep
142
+ if self.step_index is None:
143
+ self._init_step_index(timestep)
144
+
145
+ sigma = self.sigmas[self.step_index]
146
+ sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
147
+
148
+ self.is_scale_input_called = True
149
+ return sample
150
+
151
+ # def _sigma_to_t(self, sigma, log_sigmas):
152
+ # # get log sigma
153
+ # log_sigma = np.log(np.maximum(sigma, 1e-10))
154
+
155
+ # # get distribution
156
+ # dists = log_sigma - log_sigmas[:, np.newaxis]
157
+
158
+ # # get sigmas range
159
+ # low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
160
+ # high_idx = low_idx + 1
161
+
162
+ # low = log_sigmas[low_idx]
163
+ # high = log_sigmas[high_idx]
164
+
165
+ # # interpolate sigmas
166
+ # w = (low - log_sigma) / (low - high)
167
+ # w = np.clip(w, 0, 1)
168
+
169
+ # # transform interpolation to time range
170
+ # t = (1 - w) * low_idx + w * high_idx
171
+ # t = t.reshape(sigma.shape)
172
+ # return t
173
+
174
+ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
175
+ """
176
+ Gets scaled timesteps from the Karras sigmas for input to the consistency model.
177
+
178
+ Args:
179
+ sigmas (`float` or `np.ndarray`):
180
+ A single Karras sigma or an array of Karras sigmas.
181
+
182
+ Returns:
183
+ `float` or `np.ndarray`:
184
+ A scaled input timestep or scaled input timestep array.
185
+ """
186
+ if not isinstance(sigmas, np.ndarray):
187
+ sigmas = np.array(sigmas, dtype=np.float64)
188
+
189
+ timesteps = 0.25 * np.log(sigmas + 1e-44)
190
+
191
+ return timesteps
192
+
193
+ def set_timesteps(
194
+ self,
195
+ num_inference_steps: Optional[int] = None,
196
+ device: Union[str, torch.device] = None,
197
+ timesteps: Optional[List[int]] = None,
198
+ ):
199
+ """
200
+ Sets the timesteps used for the diffusion chain (to be run before inference).
201
+
202
+ Args:
203
+ num_inference_steps (`int`):
204
+ The number of diffusion steps used when generating samples with a pre-trained model.
205
+ device (`str` or `torch.device`, *optional*):
206
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
207
+ timesteps (`List[int]`, *optional*):
208
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
209
+ timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
210
+ `num_inference_steps` must be `None`.
211
+ """
212
+ if num_inference_steps is None and timesteps is None:
213
+ raise ValueError(
214
+ "Exactly one of `num_inference_steps` or `timesteps` must be supplied."
215
+ )
216
+
217
+ if num_inference_steps is not None and timesteps is not None:
218
+ raise ValueError(
219
+ "Can only pass one of `num_inference_steps` or `timesteps`."
220
+ )
221
+
222
+ # Follow DDPMScheduler custom timesteps logic
223
+ if timesteps is not None:
224
+ for i in range(1, len(timesteps)):
225
+ if timesteps[i] >= timesteps[i - 1]:
226
+ raise ValueError("`timesteps` must be in descending order.")
227
+
228
+ if timesteps[0] >= self.config.num_train_timesteps:
229
+ raise ValueError(
230
+ f"`timesteps` must start before `self.config.train_timesteps`:"
231
+ f" {self.config.num_train_timesteps}."
232
+ )
233
+
234
+ timesteps = np.array(timesteps, dtype=np.int64)
235
+ self.custom_timesteps = True
236
+ else:
237
+ if num_inference_steps > self.config.num_train_timesteps:
238
+ raise ValueError(
239
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
240
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
241
+ f" maximal {self.config.num_train_timesteps} timesteps."
242
+ )
243
+
244
+ self.num_inference_steps = num_inference_steps
245
+
246
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
247
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
248
+ self.custom_timesteps = False
249
+
250
+ self.original_indices = timesteps
251
+ # Map timesteps to Karras sigmas directly for multistep sampling
252
+ # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
253
+ num_train_timesteps = self.config.num_train_timesteps
254
+ ramp = timesteps.copy()
255
+ ramp = ramp / (num_train_timesteps - 1)
256
+ sigmas = self._convert_to_karras(ramp)
257
+ timesteps = self.sigma_to_t(sigmas)
258
+
259
+ sigmas = np.concatenate([sigmas, [0]]).astype(np.float32)
260
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
261
+
262
+ if str(device).startswith("mps"):
263
+ # mps does not support float64
264
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
265
+ else:
266
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
267
+
268
+ self._step_index = None
269
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
270
+
271
+ # Modified _convert_to_karras implementation that takes in ramp as argument
272
+ def _convert_to_karras(self, ramp):
273
+ """Constructs the noise schedule of Karras et al. (2022)."""
274
+
275
+ sigma_min: float = self.config.sigma_min
276
+ sigma_max: float = self.config.sigma_max
277
+
278
+ rho = self.config.rho
279
+ min_inv_rho = sigma_min ** (1 / rho)
280
+ max_inv_rho = sigma_max ** (1 / rho)
281
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
282
+ return sigmas
283
+
284
+ def get_scalings(self, sigma):
285
+ sigma_data = self.config.sigma_data
286
+
287
+ c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
288
+ c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
289
+ return c_skip, c_out
290
+
291
+ def get_scalings_for_boundary_condition(self, sigma):
292
+ """
293
+ Gets the scalings used in the consistency model parameterization (from Appendix C of the
294
+ [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
295
+
296
+ <Tip>
297
+
298
+ `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
299
+
300
+ </Tip>
301
+
302
+ Args:
303
+ sigma (`torch.FloatTensor`):
304
+ The current sigma in the Karras sigma schedule.
305
+
306
+ Returns:
307
+ `tuple`:
308
+ A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
309
+ (which weights the consistency model output) is the second element.
310
+ """
311
+ sigma_min = self.config.sigma_min
312
+ sigma_data = self.config.sigma_data
313
+
314
+ c_skip = sigma_data**2 / ((sigma) ** 2 + sigma_data**2)
315
+ c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
316
+ return c_skip, c_out
317
+
318
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
319
+ def _init_step_index(self, timestep):
320
+ if isinstance(timestep, torch.Tensor):
321
+ timestep = timestep.to(self.timesteps.device)
322
+
323
+ index_candidates = (self.timesteps == timestep).nonzero()
324
+
325
+ # The sigma index that is taken for the **very** first `step`
326
+ # is always the second index (or the last index if there is only 1)
327
+ # This way we can ensure we don't accidentally skip a sigma in
328
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
329
+ if len(index_candidates) > 1:
330
+ step_index = index_candidates[1]
331
+ else:
332
+ step_index = index_candidates[0]
333
+
334
+ self._step_index = step_index.item()
335
+
336
+ def step(
337
+ self,
338
+ model_output: torch.FloatTensor,
339
+ timestep: Union[float, torch.FloatTensor],
340
+ sample: torch.FloatTensor,
341
+ generator: Optional[torch.Generator] = None,
342
+ return_dict: bool = True,
343
+ ) -> Union[TDDSVDStochasticIterativeSchedulerOutput, Tuple]:
344
+ """
345
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
346
+ process from the learned model outputs (most often the predicted noise).
347
+
348
+ Args:
349
+ model_output (`torch.FloatTensor`):
350
+ The direct output from the learned diffusion model.
351
+ timestep (`float`):
352
+ The current timestep in the diffusion chain.
353
+ sample (`torch.FloatTensor`):
354
+ A current instance of a sample created by the diffusion process.
355
+ generator (`torch.Generator`, *optional*):
356
+ A random number generator.
357
+ return_dict (`bool`, *optional*, defaults to `True`):
358
+ Whether or not to return a
359
+ [`~schedulers.scheduling_consistency_models.TDDSVDStochasticIterativeSchedulerOutput`] or `tuple`.
360
+
361
+ Returns:
362
+ [`~schedulers.scheduling_consistency_models.TDDSVDStochasticIterativeSchedulerOutput`] or `tuple`:
363
+ If return_dict is `True`,
364
+ [`~schedulers.scheduling_consistency_models.TDDSVDStochasticIterativeSchedulerOutput`] is returned,
365
+ otherwise a tuple is returned where the first element is the sample tensor.
366
+ """
367
+
368
+ if (
369
+ isinstance(timestep, int)
370
+ or isinstance(timestep, torch.IntTensor)
371
+ or isinstance(timestep, torch.LongTensor)
372
+ ):
373
+ raise ValueError(
374
+ (
375
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
376
+ f" `{self.__class__}.step()` is not supported. Make sure to pass"
377
+ " one of the `scheduler.timesteps` as a timestep."
378
+ ),
379
+ )
380
+
381
+ if not self.is_scale_input_called:
382
+ logger.warning(
383
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
384
+ "See `StableDiffusionPipeline` for a usage example."
385
+ )
386
+
387
+ sigma_min = self.config.sigma_min
388
+ sigma_max = self.config.sigma_max
389
+
390
+ if self.step_index is None:
391
+ self._init_step_index(timestep)
392
+
393
+ # sigma_next corresponds to next_t in original implementation
394
+ next_step_index = self.step_index + 1
395
+
396
+ sigma = self.sigmas[self.step_index]
397
+ if next_step_index < len(self.sigmas):
398
+ sigma_next = self.sigmas[next_step_index]
399
+ else:
400
+ # Set sigma_next to sigma_min
401
+ sigma_next = self.sigmas[-1]
402
+
403
+ # Get scalings for boundary conditions
404
+ c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)
405
+
406
+ if next_step_index < len(self.original_indices):
407
+ next_step_original_index = self.original_indices[next_step_index]
408
+ step_s_original_index = int(next_step_original_index + self.eta * (self.config.num_train_timesteps - 1 - next_step_original_index))
409
+ sigma_s = self.original_sigmas[step_s_original_index]
410
+ else:
411
+ sigma_s = self.sigmas[-1]
412
+
413
+ # 1. Denoise model output using boundary conditions
414
+ denoised = c_out * model_output + c_skip * sample
415
+ if self.config.clip_denoised:
416
+ denoised = denoised.clamp(-1, 1)
417
+
418
+ d = (sample - denoised) / sigma
419
+ sample_s = sample + d * (sigma_s - sigma)
420
+
421
+ # 2. Sample z ~ N(0, s_noise^2 * I)
422
+ # Noise is not used for onestep sampling.
423
+ if len(self.timesteps) > 1:
424
+ noise = randn_tensor(
425
+ model_output.shape,
426
+ dtype=model_output.dtype,
427
+ device=model_output.device,
428
+ generator=generator,
429
+ )
430
+ else:
431
+ noise = torch.zeros_like(model_output)
432
+ z = noise * self.config.s_noise
433
+
434
+ sigma_hat = sigma_next.clamp(min = 0, max = sigma_max)
435
+ # sigma_hat = sigma_next.clamp(min = sigma_min, max = sigma_max)
436
+
437
+ # print("denoise currently")
438
+ # print(sigma_hat)
439
+
440
+ # origin
441
+ # prev_sample = denoised + z * sigma_hat
442
+ prev_sample = sample_s + z * (sigma_hat - sigma_s)
443
+
444
+ # upon completion increase step index by one
445
+ self._step_index += 1
446
+
447
+ if not return_dict:
448
+ return (prev_sample,)
449
+
450
+ return TDDSVDStochasticIterativeSchedulerOutput(prev_sample=prev_sample)
451
+
452
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
453
+ def add_noise(
454
+ self,
455
+ original_samples: torch.FloatTensor,
456
+ noise: torch.FloatTensor,
457
+ timesteps: torch.FloatTensor,
458
+ ) -> torch.FloatTensor:
459
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
460
+ sigmas = self.sigmas.to(
461
+ device=original_samples.device, dtype=original_samples.dtype
462
+ )
463
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
464
+ # mps does not support float64
465
+ schedule_timesteps = self.timesteps.to(
466
+ original_samples.device, dtype=torch.float32
467
+ )
468
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
469
+ else:
470
+ schedule_timesteps = self.timesteps.to(original_samples.device)
471
+ timesteps = timesteps.to(original_samples.device)
472
+
473
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
474
+
475
+ sigma = sigmas[step_indices].flatten()
476
+ while len(sigma.shape) < len(original_samples.shape):
477
+ sigma = sigma.unsqueeze(-1)
478
+
479
+ noisy_samples = original_samples + noise * sigma
480
+ return noisy_samples
481
+
482
+ def __len__(self):
483
+ return self.config.num_train_timesteps
484
+
485
+ def set_eta(self, eta: float):
486
+ assert 0.0 <= eta <= 1.0
487
+ self.eta = eta
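For reference, a small CPU-only sketch (assuming this file is importable as tdd_svd_scheduler) that instantiates the scheduler with the same arguments app.py passes, prints the resulting Karras sigma schedule, and shows the gamma-sampling target index that set_eta controls inside step():

from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler

# Same constructor arguments as app.py.
sched = TDDSVDStochasticIterativeScheduler(
    num_train_timesteps=250, sigma_min=0.002, sigma_max=700.0,
    sigma_data=1.0, s_noise=1.0, rho=7, clip_denoised=False)
sched.set_timesteps(num_inference_steps=4)

print(sched.sigmas)     # 4 Karras sigmas in descending order, plus a trailing 0
print(sched.timesteps)  # t = 0.25 * log(sigma) for each of the 4 sigmas

# In step(), eta moves each denoising jump part-way back toward the noisy end of the
# original schedule: i_s = int(i_next + eta * (num_train_timesteps - 1 - i_next)).
eta = 0.3
i_next = int(sched.original_indices[1])
print(int(i_next + eta * (250 - 1 - i_next)))  # index into the original 250-entry sigma table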
utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ from diffusers.loaders.lora import LoraLoaderMixin
3
+ from typing import Dict, Union
4
+ import numpy as np
5
+ import imageio
6
+
7
+ def load_lora_weights(unet, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name = None, **kwargs):
8
+ # if a dict is passed, copy it instead of modifying it inplace
9
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
10
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
11
+
12
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
13
+ state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
14
+
15
+ # remove prefix if not removed when saved
16
+ state_dict = {name.replace('base_model.model.', ''): param for name, param in state_dict.items()}
17
+
18
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
19
+ if not is_correct_format:
20
+ raise ValueError("Invalid LoRA checkpoint.")
21
+
22
+ low_cpu_mem_usage = True
23
+
24
+ LoraLoaderMixin.load_lora_into_unet(
25
+ state_dict,
26
+ network_alphas=network_alphas,
27
+ unet = unet,
28
+ low_cpu_mem_usage=low_cpu_mem_usage,
29
+ adapter_name=adapter_name,
30
+ )
31
+
32
+ def save_video(frames, save_path, fps, quality=9):
33
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality)
34
+ for frame in frames:
35
+ frame = np.array(frame)
36
+ writer.append_data(frame)
37
+ writer.close()
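A small usage sketch for save_video with synthetic frames. Note that utils.py imports imageio (and writing .mp4 also requires an FFmpeg backend such as imageio-ffmpeg), which does not appear in requirements.txt:

from PIL import Image
from utils import save_video

# 25 dummy grayscale frames, just to exercise the writer.
frames = [Image.new("RGB", (512, 512), (10 * i, 10 * i, 10 * i)) for i in range(25)]
save_video(frames, "dummy.mp4", fps=7, quality=5.0)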