Upload folder using huggingface_hub

- .gitignore +10 -0
- .gradio/certificate.pem +31 -0
- .python-version +1 -0
- README.md +16 -6
- app.py +80 -0
- ckpt_1000k.pkl +3 -0
- config.yaml +36 -0
- model.py +294 -0
- pyproject.toml +15 -0
- requirements.txt +2 -0
- sample.py +93 -0
- train.py +221 -0
- train_data_samples.png +0 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
.python-version
ADDED
@@ -0,0 +1 @@
3.11
README.md
CHANGED
@@ -1,12 +1,22 @@
 ---
 title: AnimeFlow
-
-colorFrom: indigo
-colorTo: red
+app_file: app.py
 sdk: gradio
 sdk_version: 5.9.1
-app_file: app.py
-pinned: false
 ---
+# Anime Flow
+
+A simple implementation of conditional flow matching for generating anime faces. The model architecture closely follows the Diffusion Transformer model (DiT) found at https://github.com/facebookresearch/DiT/blob/main/models.py.
+
+## Train model
+
+```bash
+pip install uv
+uv run train.py --config ./config.yaml
+```
+
+## Generate images
 
-
+```bash
+uv run sample.py --ckpt ./state_1000000.ckpt --config ./config.yaml --seed 0
+```
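A note on the method the README names: conditional flow matching regresses the network onto the constant velocity of a straight-line path between a noise sample and a data sample, which is exactly what `train.py` below implements. A minimal sketch of that objective, where `model` is an illustrative stand-in for any velocity-predicting callable (the name is not from the repo):

```python
import jax.numpy as jnp

def cfm_loss(model, x0, x1, t):
    # Straight-line path from noise x0 to data x1 at per-example times t.
    xt = x0 + (x1 - x0) * t[:, None, None, None]
    vt = x1 - x0  # target velocity is constant along the path
    return jnp.mean(jnp.square(model(xt, t) - vt))
```

Sampling then integrates the learned ODE dx/dt = model(x, t) from t=0 (noise) to t=1 (image), which is what `app.py` and `sample.py` do with `jax.experimental.ode.odeint`.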
app.py
ADDED
@@ -0,0 +1,80 @@
import gradio as gr
import jax
import jax.numpy as jnp
from jax.experimental import ode
import yaml
from flax import nnx
import pickle


def load_model(config_path, ckpt_path):
    # Load config
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Load model and state
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)

    from model import DiT, DiTConfig

    dit_config = DiTConfig(**config["model"])
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)
    _, treedef = jax.tree_util.tree_flatten(state)
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state


@jax.jit
def sample_images(graphdef, state, x0, t):
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        o = flow(y, t[None])
        return o

    o = ode.odeint(flow_fn, x0, t, rtol=1e-4)
    o = jnp.clip(o[-1], 0, 1)
    return o


def generate_grid(seed, noise_level):
    # Load model (doing this inside the function to avoid global variables)
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    t = jnp.linspace(0, 1, 2)
    x0 = jax.random.truncated_normal(
        nnx.Rngs(seed)(),
        -noise_level,
        noise_level,
        shape=(16, 64, 64, 3),
        dtype=jnp.float32,
    )

    # Generate images
    images = sample_images(graphdef, state, x0, t)

    # Convert to grid of 4x4
    rows = []
    for i in range(4):
        row = jnp.concatenate(images[i * 4 : (i + 1) * 4], axis=1)
        rows.append(row)
    grid = jnp.concatenate(rows, axis=0)

    return jax.device_get(grid)

# Create Gradio interface
demo = gr.Interface(
    fn=generate_grid,
    inputs=[
        gr.Number(label="Random Seed", value=0, precision=0),
        gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale"),
    ],
    outputs=gr.Image(label="Generated Images"),
    title="Anime Flow Generation Demo",
    description="Generate a 4x4 grid of anime faces using Anime Flow",
)

if __name__ == "__main__":
    demo.launch(share=True)
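`load_model` above expects the pickle to hold a flat list of state leaves, which it re-attaches to a tree structure recovered from an abstract (`nnx.eval_shape`) copy of the model, so no real weights are allocated before loading. A hedged sketch of the save side that would produce such a file (assuming `flow` is a trained `DiT`; this helper is not part of the repo):

```python
import pickle
import jax
from flax import nnx

def save_leaves(flow, path):
    # Keep only the leaf arrays; app.py rebuilds the treedef from the model class.
    _, state = nnx.split(flow)
    leaves, _ = jax.tree_util.tree_flatten(state)
    with open(path, "wb") as f:
        pickle.dump(leaves, f)
```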
ckpt_1000k.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38793cde6fa2f5f32134d5f281141cd81fe257d6e160d975ab2a3c6c4559f6c2
size 147069817
config.yaml
ADDED
@@ -0,0 +1,36 @@
# Model architecture
model:
  input_dim: 3 # RGB images
  hidden_dim: 512
  num_blocks: 8
  num_heads: 8
  patch_size: 8
  patch_stride: 4
  time_freq_dim: 256
  time_max_period: 1024
  mlp_ratio: 4
  use_bias: false
  padding: "SAME"
  pos_embed_cls_token: false
  pos_embed_extra_tokens: 0

# Training parameters
training:
  learning_rate: 1.0e-4
  batch_size: 128
  num_steps: 1_000_000
  warmup_pct: 0.01
  weight_decay: 0.0
  grad_clip_norm: 100.0

# Checkpointing and logging
checkpointing:
  log_every: 1_000
  plot_every: 10_000
  save_every: 10_000
  resume_from_checkpoint: null

# Data
data:
  train_split: 0.9 # 90% for training, 10% for testing
  random_seed: 42
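With `patch_size: 8` and `patch_stride: 4`, patches overlap, and `"SAME"` padding makes the token grid depend only on the stride: a 64x64 image yields a 16x16 grid, i.e. 256 tokens of width `hidden_dim: 512`. A quick check of that arithmetic:

```python
# Output size of a SAME-padded convolution is ceil(input / stride).
import math

image_size, patch_stride = 64, 4
grid = math.ceil(image_size / patch_stride)  # 16
print(grid * grid)                           # 256 tokens per image
```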
model.py
ADDED
@@ -0,0 +1,294 @@
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx


@dataclass
class DiTConfig:
    input_dim: int
    hidden_dim: int
    num_blocks: int
    num_heads: int
    patch_size: int
    patch_stride: int
    time_freq_dim: int
    time_max_period: int
    mlp_ratio: int
    use_bias: bool
    padding: str
    pos_embed_cls_token: bool
    pos_embed_extra_tokens: int


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = jnp.arange(grid_size, dtype=jnp.float32)
    grid_w = jnp.arange(grid_size, dtype=jnp.float32)
    grid = jnp.meshgrid(grid_w, grid_h)  # here w goes first
    grid = jnp.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate(
            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = jnp.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = jnp.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 16**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = jnp.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = jnp.sin(out)  # (M, D/2)
    emb_cos = jnp.cos(out)  # (M, D/2)

    emb = jnp.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class PatchEmbedding(nnx.Module):
    """Patch embedding module."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.cnn = nnx.Conv(
            config.input_dim,
            config.hidden_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_stride, config.patch_stride),
            padding=config.padding,
            use_bias=config.use_bias,
            rngs=rngs,
        )

    def __call__(self, x):
        return self.cnn(x)


class TimeEmbedding(nnx.Module):
    """Time embedding module."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.freq_dim = config.time_freq_dim
        self.max_period = config.time_max_period
        self.fc1 = nnx.Linear(
            self.freq_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
        )
        self.fc2 = nnx.Linear(
            config.hidden_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
        )

    @staticmethod
    def cosine_embedding(t, dim, max_period):
        assert dim % 2 == 0
        half = dim // 2
        freqs = jnp.exp(
            -math.log(max_period)
            * jnp.arange(start=0, stop=half, dtype=jnp.float32)
            / half
        )
        args = t[:, None] * freqs[None, :] * 1024
        embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
        return embedding

    def __call__(self, t):
        t_freq = self.cosine_embedding(t, self.freq_dim, self.max_period)
        t_embed = self.fc1(t_freq)
        t_embed = nnx.silu(t_embed)
        t_embed = self.fc2(t_embed)
        return t_embed


class MLP(nnx.Module):
    """MLP module."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.fc1 = nnx.Linear(
            config.hidden_dim,
            config.hidden_dim * config.mlp_ratio,
            use_bias=config.use_bias,
            rngs=rngs,
        )
        self.fc2 = nnx.Linear(
            config.hidden_dim * config.mlp_ratio,
            config.hidden_dim,
            use_bias=config.use_bias,
            rngs=rngs,
        )

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.silu(x)
        x = self.fc2(x)
        return x


class SelfAttention(nnx.Module):
    """Self attention module."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.fc = nnx.Linear(
            config.hidden_dim,
            3 * config.hidden_dim,
            use_bias=config.use_bias,
            rngs=rngs,
        )
        self.heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        assert config.hidden_dim % config.num_heads == 0
        self.q_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)
        self.k_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)

    def __call__(self, x):
        q, k, v = jnp.split(self.fc(x), 3, axis=-1)
        # reshape q, k, v to N, T, H, D
        q = q.reshape(q.shape[0], q.shape[1], self.heads, self.head_dim)
        k = k.reshape(k.shape[0], k.shape[1], self.heads, self.head_dim)
        v = v.reshape(v.shape[0], v.shape[1], self.heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        o = jax.nn.dot_product_attention(q, k, v, is_causal=False)
        o = o.reshape(o.shape[0], o.shape[1], self.heads * self.head_dim)
        return o


def modulate(x, shift, scale):
    return x * (1 + scale[:, None, :]) + shift[:, None, :]


class TransformerBlock(nnx.Module):
    """Transformer block."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.norm1 = nnx.RMSNorm(
            num_features=config.hidden_dim, use_scale=False, rngs=rngs
        )
        self.attn = SelfAttention(config, rngs=rngs)
        self.norm2 = nnx.RMSNorm(
            num_features=config.hidden_dim, use_scale=False, rngs=rngs
        )
        self.mlp = MLP(config, rngs=rngs)
        self.adalm_modulation = nnx.Sequential(
            nnx.silu,
            nnx.Linear(
                config.hidden_dim,
                6 * config.hidden_dim,
                use_bias=config.use_bias,
                rngs=rngs,
            ),
        )

    def __call__(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(
            self.adalm_modulation(c), 6, axis=-1
        )
        attn_x = self.norm1(x)
        attn_x = modulate(attn_x, shift_msa, scale_msa)
        x = x + gate_msa[:, None, :] * self.attn(attn_x)
        mlp_x = self.norm2(x)
        mlp_x = modulate(mlp_x, shift_mlp, scale_mlp)
        x = x + gate_mlp[:, None, :] * self.mlp(mlp_x)
        return x


class FinalLayer(nnx.Module):
    """Final layer."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.norm = nnx.RMSNorm(
            num_features=config.hidden_dim, use_scale=False, rngs=rngs
        )
        self.conv = nnx.ConvTranspose(
            config.hidden_dim,
            config.input_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_stride, config.patch_stride),
            padding=config.padding,
            use_bias=config.use_bias,
            rngs=rngs,
        )
        self.adalm_modulation = nnx.Sequential(
            nnx.silu,
            nnx.Linear(
                config.hidden_dim,
                2 * config.hidden_dim,
                use_bias=config.use_bias,
                rngs=rngs,
            ),
        )

    def __call__(self, x, c):
        shift, scale = jnp.split(self.adalm_modulation(c), 2, axis=-1)
        x = self.norm(x)
        x = modulate(x, shift, scale)
        # reshape to N, H, W, C
        H = W = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], H, W, x.shape[-1])
        x = self.conv(x)
        return x


class DiT(nnx.Module):
    """Diffusion Transformer"""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.config = config
        self.time_embedding = TimeEmbedding(config, rngs=rngs)
        self.patch_embedding = PatchEmbedding(config, rngs=rngs)
        self.blocks = [
            TransformerBlock(config, rngs=rngs) for _ in range(config.num_blocks)
        ]
        self.final_layer = FinalLayer(config, rngs=rngs)

    def __call__(self, xt, t):
        t = self.time_embedding(t)
        x = self.patch_embedding(xt)
        N, H, W, D = x.shape
        x = x.reshape(N, H * W, D)
        x = x + get_2d_sincos_pos_embed(
            D,
            H,
            cls_token=self.config.pos_embed_cls_token,
            extra_tokens=self.config.pos_embed_extra_tokens,
        ).reshape(1, H * W, D)
        c = t
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        return x
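A minimal forward-pass smoke test for `model.py`, using the values shipped in `config.yaml`; since `DiT` predicts a per-pixel velocity, its output shape matches its input:

```python
import jax.numpy as jnp
from flax import nnx
from model import DiT, DiTConfig

config = DiTConfig(
    input_dim=3, hidden_dim=512, num_blocks=8, num_heads=8,
    patch_size=8, patch_stride=4, time_freq_dim=256, time_max_period=1024,
    mlp_ratio=4, use_bias=False, padding="SAME",
    pos_embed_cls_token=False, pos_embed_extra_tokens=0,
)
model = DiT(config, rngs=nnx.Rngs(0))
x = jnp.zeros((2, 64, 64, 3))  # NHWC batch of two 64x64 RGB images
t = jnp.array([0.25, 0.75])    # one flow time per example
v = model(x, t)                # predicted velocity field
assert v.shape == x.shape
```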
pyproject.toml
ADDED
@@ -0,0 +1,15 @@
[project]
name = "anime-flow"
version = "0.1.0"
description = "Generate anime faces using conditional flow matching"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "flax>=0.10.2",
    "gradio>=5.9.1",
    "jax[cuda12]>=0.4.38",
    "kagglehub>=0.3.6",
    "matplotlib>=3.10.0",
    "pillow>=11.0.0",
    "pot>=0.9.5",
]
requirements.txt
ADDED
@@ -0,0 +1,2 @@
jax[cuda12]
flax
sample.py
ADDED
@@ -0,0 +1,93 @@
"""
Generate images from trained model
"""

import argparse
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import yaml
from flax import nnx
from jax.experimental import ode

from model import DiT, DiTConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--ckpt", type=str, default=None, help="Path to checkpoint file"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    return parser.parse_args()


def load_config(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config


@jax.jit
def sample_images(graphdef, state, rng):
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        o = flow(y, t[None])
        return o

    x = jax.random.normal(rng, shape=(16, 64, 64, 3), dtype=jnp.float32)
    o = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 1000))
    o = jnp.clip(o[-1], 0, 1)
    return o


def plot_new_images(graphdef, state, seed):
    images = sample_images(graphdef, state, nnx.Rngs(seed)())

    plt.figure(figsize=(2, 2))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(images[i])
        plt.axis("off")
    plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    plt.savefig("samples.png")
    plt.close()


def main():
    args = parse_args()
    config = load_config(args.config)

    dit_config = DiTConfig(
        input_dim=config["model"]["input_dim"],
        hidden_dim=config["model"]["hidden_dim"],
        num_blocks=config["model"]["num_blocks"],
        num_heads=config["model"]["num_heads"],
        patch_size=config["model"]["patch_size"],
        patch_stride=config["model"]["patch_stride"],
        time_freq_dim=config["model"]["time_freq_dim"],
        time_max_period=config["model"]["time_max_period"],
        mlp_ratio=config["model"]["mlp_ratio"],
        use_bias=config["model"]["use_bias"],
        padding=config["model"]["padding"],
        pos_embed_cls_token=config["model"]["pos_embed_cls_token"],
        pos_embed_extra_tokens=config["model"]["pos_embed_extra_tokens"],
    )

    abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, _ = nnx.split(abstract_flow)
    with open(args.ckpt, "rb") as f:
        state = pickle.load(f, fix_imports=True)
    if "time_embedding" not in state:
        state = state[0]
    plot_new_images(graphdef, state, args.seed)


if __name__ == "__main__":
    main()
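In `sample_images`, `odeint` integrates adaptively; the 1000-point `linspace` only chooses where the trajectory is reported, and `o[-1]` (the state at t=1) is the generated image batch. For intuition, a fixed-step Euler integrator is a rougher sketch of the same loop (assuming, as `flow_fn` does, that the model accepts a length-1 time batch and broadcasts it over the images):

```python
import jax.numpy as jnp

def euler_sample(flow, x, num_steps=100):
    # Integrate dx/dt = flow(x, t) from t=0 (noise) to t=1 (image).
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = jnp.array([i * dt])  # shape (1,), broadcast over the batch
        x = x + dt * flow(x, t)
    return jnp.clip(x, 0, 1)
```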
train.py
ADDED
@@ -0,0 +1,221 @@
"""
A simple implementation of conditional flow matching for generating anime faces.
"""

import argparse
import pickle
import random
import time
from pathlib import Path

import jax
import jax.numpy as jnp
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import optax
import ot
import yaml
from flax import nnx
from jax.experimental import ode
from PIL import Image
from tqdm.cli import tqdm

from model import DiT, DiTConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    return parser.parse_args()


def load_config(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config


def gen_data_batches(data, batch_size):
    N = data.shape[0]
    while True:
        random_indices = np.random.choice(N, size=batch_size, replace=False)
        batch = data[random_indices]
        batch = batch.astype(np.float32) / 256
        yield batch


def loss_fn(flow, batch):
    xt, t, vt = batch
    velocity = flow(xt, t)
    loss = jnp.mean(jnp.square(velocity - vt))
    return loss


def train_step(flow, optimizer, rngs, batch):
    x0, x1 = batch
    noise = jax.random.uniform(rngs(), shape=x1.shape, minval=0, maxval=1 / 256)
    x1 = x1 + noise
    # randomize t
    t = jax.random.uniform(rngs(), (x1.shape[0],), minval=0, maxval=1)
    # randomize x0
    xt = x0 + (x1 - x0) * t[:, None, None, None]
    vt = x1 - x0
    batch = (xt, t, vt)
    loss, grads = nnx.value_and_grad(loss_fn)(flow, batch)
    optimizer.update(grads)
    return loss


@jax.jit
def train_step_raw(graphdef, state, batch):
    flow, optimizer, rngs = nnx.merge(graphdef, state)
    loss = train_step(flow, optimizer, rngs, batch)
    _, state = nnx.split((flow, optimizer, rngs))
    return state, loss


@jax.jit
def sample_images(graphdef, state):
    flow, _, _ = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        o = flow(y, t[None])
        return o

    x = jax.random.normal(nnx.Rngs(0)(), shape=(16, 64, 64, 3), dtype=jnp.float32)
    o = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 1000))
    o = jnp.clip(o[-1], 0, 1)
    return o


def generate_ot_pairs(x1):
    n = x1.shape[0]
    x0 = np.random.randn(*x1.shape)
    d1 = x1.reshape(n, -1)
    d0 = x0.reshape(n, -1)
    # loss matrix
    M = ot.dist(d0, d1)
    a, b = np.ones((n,)), np.ones((n,))
    G0 = ot.emd(a, b, M)
    d1 = np.matmul(G0, d1)
    x1 = d1.reshape(*x1.shape)
    return x0, x1


def plot_new_images(step: int, graphdef, state):
    images = sample_images(graphdef, state)

    plt.figure(figsize=(2, 2))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(images[i])
        plt.axis("off")
    plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    plt.savefig(f"images_{step:06d}.png")
    plt.close()


args = parse_args()
config = load_config(args.config)

# Download latest version
path = kagglehub.dataset_download("thimac/anime-face-64")
data_path = Path(path) / "64x64"
print("Path to dataset files:", data_path)

data_dir = data_path
image_files = sorted(data_dir.glob("*.jpg"))
random.Random(config["data"]["random_seed"]).shuffle(image_files)
N = len(image_files)
dataset = np.empty((N, 64, 64, 3), dtype=np.uint8)
for i, file_path in enumerate(tqdm(image_files)):
    dataset[i] = Image.open(file_path)

L = int(N * config["data"]["train_split"])
train_data = dataset[:L]
test_data = dataset[L:]

plt.figure(figsize=(2, 2))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(train_data[i])
    plt.axis("off")
plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
plt.savefig("train_data_samples.png")
plt.close()

scheduler = optax.cosine_onecycle_schedule(
    transition_steps=config["training"]["num_steps"],
    peak_value=config["training"]["learning_rate"],
    pct_start=config["training"]["warmup_pct"],
)

gradient_transform = optax.chain(
    optax.clip_by_global_norm(config["training"]["grad_clip_norm"]),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.add_decayed_weights(config["training"]["weight_decay"]),
    optax.scale(-1.0),
)

dit_config = DiTConfig(
    input_dim=config["model"]["input_dim"],
    hidden_dim=config["model"]["hidden_dim"],
    num_blocks=config["model"]["num_blocks"],
    num_heads=config["model"]["num_heads"],
    patch_size=config["model"]["patch_size"],
    patch_stride=config["model"]["patch_stride"],
    time_freq_dim=config["model"]["time_freq_dim"],
    time_max_period=config["model"]["time_max_period"],
    mlp_ratio=config["model"]["mlp_ratio"],
    use_bias=config["model"]["use_bias"],
    padding=config["model"]["padding"],
    pos_embed_cls_token=config["model"]["pos_embed_cls_token"],
    pos_embed_extra_tokens=config["model"]["pos_embed_extra_tokens"],
)

flow = DiT(dit_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(flow, gradient_transform)

rngs = nnx.Rngs(0)
graphdef, state = nnx.split((flow, optimizer, rngs))
train_data_iter = gen_data_batches(train_data, config["training"]["batch_size"])

start = time.perf_counter()
losses = []
ckpt_path = config["checkpointing"].get("resume_from_checkpoint")
if ckpt_path:
    del state
    with open(ckpt_path, "rb") as f:
        state = pickle.load(f)
    print(f"Resuming from checkpoint {ckpt_path}")
    step_str = Path(ckpt_path).stem.split("_")[-1]
    start_step = int(step_str) + 1
else:
    start_step = 1

for step, batch in enumerate(train_data_iter, start=start_step):
    x0, x1 = generate_ot_pairs(batch)
    state, loss = train_step_raw(graphdef, state, (x0, x1))

    if step % 100 == 0:
        losses.append(loss.item())

    if step % config["checkpointing"]["log_every"] == 0:
        end = time.perf_counter()
        duration = end - start
        loss = sum(losses) / len(losses)
        start = time.perf_counter()
        losses = []
        print(f"step {step:06d} loss {loss:.3f} duration {duration:.3f}s", flush=True)

    if step % config["checkpointing"]["plot_every"] == 0:
        plot_new_images(step, graphdef, state)

    if step % config["checkpointing"]["save_every"] == 0:
        # save checkpoint
        with open(f"state_{step:06d}.ckpt", "wb") as f:
            pickle.dump(state, f)
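`generate_ot_pairs` is a minibatch optimal-transport coupling (the OT-CFM idea): `ot.emd` with uniform weights returns a hard assignment, a permutation matrix in this case, that matches each Gaussian sample to a nearby data sample, so the straight-line paths the loss regresses onto cross less. A toy run on 2D points makes the behavior concrete:

```python
import numpy as np
import ot

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 2))        # "noise" batch
x1 = rng.normal(size=(4, 2)) + 5.0  # "data" batch, shifted away
M = ot.dist(x0, x1)                 # pairwise squared Euclidean costs
G = ot.emd(np.ones(4), np.ones(4), M)  # hard assignment (a permutation here)
x1_matched = G @ x1                 # reorder data rows to sit near x0 rows
print(G)                            # exactly one 1 per row and column
```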
train_data_samples.png
ADDED
(binary file: the 4x4 grid of training samples plotted by train.py)
uv.lock
ADDED
The diff for this file is too large to render.