ntt123 committed
Commit 0c9bb32 · verified · 1 Parent(s): 225efdd

Upload folder using huggingface_hub
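Commits like this one are typically created with the `upload_folder` API from the `huggingface_hub` library. A minimal sketch of that call (the repo id and folder path below are placeholders, not taken from this page):

```python
from huggingface_hub import HfApi

# Sketch: push a local project folder to a Space in a single commit.
api = HfApi()
api.upload_folder(
    folder_path=".",                    # local folder with app.py, model.py, config.yaml, ...
    repo_id="your-username/AnimeFlow",  # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)
```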

Files changed (14)
  1. .gitignore +10 -0
  2. .gradio/certificate.pem +31 -0
  3. .python-version +1 -0
  4. README.md +16 -6
  5. app.py +80 -0
  6. ckpt_1000k.pkl +3 -0
  7. config.yaml +36 -0
  8. model.py +294 -0
  9. pyproject.toml +15 -0
  10. requirements.txt +2 -0
  11. sample.py +93 -0
  12. train.py +221 -0
  13. train_data_samples.png +0 -0
  14. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md CHANGED
@@ -1,12 +1,22 @@
  ---
  title: AnimeFlow
- emoji: 👀
- colorFrom: indigo
- colorTo: red
+ app_file: app.py
  sdk: gradio
  sdk_version: 5.9.1
- app_file: app.py
- pinned: false
  ---
+ # Anime Flow
+
+ A simple implementation of conditional flow matching for generating anime faces. The model architecture closely follows the Diffusion Transformer model (DiT) found at https://github.com/facebookresearch/DiT/blob/main/models.py.
+
+ ## Train model
+
+ ```bash
+ pip install uv
+ uv run train.py --config ./config.yaml
+ ```
+
+ ## Generate images
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ```bash
+ uv run sample.py --ckpt ./state_1000000.ckpt --config ./config.yaml --seed 0
+ ```
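The README summarizes the method in one sentence; concretely, conditional flow matching here means drawing an image `x1`, a Gaussian noise sample `x0`, and a time `t`, forming the interpolation `xt = x0 + t * (x1 - x0)`, and regressing the model's output at `(xt, t)` onto the constant velocity `x1 - x0`. A minimal sketch of that objective (the `flow` callable stands in for the DiT defined in model.py; train.py in this commit adds minibatch OT pairing, dequantization noise, and the optimizer on top):

```python
import jax
import jax.numpy as jnp


def flow_matching_loss(flow, x1, key):
    """Conditional flow matching loss for a batch of images x1 in [0, 1], shape (N, H, W, C)."""
    k0, kt = jax.random.split(key)
    x0 = jax.random.normal(k0, x1.shape)          # noise endpoint of the probability path
    t = jax.random.uniform(kt, (x1.shape[0],))    # one time per sample
    xt = x0 + t[:, None, None, None] * (x1 - x0)  # linear interpolation between noise and data
    vt = x1 - x0                                  # target (constant) velocity along the path
    return jnp.mean(jnp.square(flow(xt, t) - vt))
```

Sampling then integrates `dx/dt = flow(x, t)` from `t = 0` to `t = 1` starting from noise, which is what sample.py and app.py do with `jax.experimental.ode.odeint`.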
app.py ADDED
@@ -0,0 +1,80 @@
+ import gradio as gr
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental import ode
+ import yaml
+ from flax import nnx
+ import pickle
+
+
+ def load_model(config_path, ckpt_path):
+     # Load config
+     with open(config_path) as f:
+         config = yaml.safe_load(f)
+
+     # Load model and state
+     with open(ckpt_path, "rb") as f:
+         leaves = pickle.load(f)
+
+     from model import DiT, DiTConfig
+
+     dit_config = DiTConfig(**config["model"])
+     model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
+     graphdef, state = nnx.split(model)
+     _, treedef = jax.tree_util.tree_flatten(state)
+     state = jax.tree_util.tree_unflatten(treedef, leaves)
+     return graphdef, state
+
+
+ @jax.jit
+ def sample_images(graphdef, state, x0, t):
+     flow = nnx.merge(graphdef, state)
+
+     def flow_fn(y, t):
+         o = flow(y, t[None])
+         return o
+
+     o = ode.odeint(flow_fn, x0, t, rtol=1e-4)
+     o = jnp.clip(o[-1], 0, 1)
+     return o
+
+
+ def generate_grid(seed, noise_level):
+     # Load model (doing this inside function to avoid global variables)
+     graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")
+
+     t = jnp.linspace(0, 1, 2)
+     x0 = jax.random.truncated_normal(
+         nnx.Rngs(seed)(),
+         -noise_level,
+         noise_level,
+         shape=(16, 64, 64, 3),
+         dtype=jnp.float32,
+     )
+
+     # Generate images
+     images = sample_images(graphdef, state, x0, t)
+
+     # Convert to grid of 4x4
+     rows = []
+     for i in range(4):
+         row = jnp.concatenate(images[i * 4 : (i + 1) * 4], axis=1)
+         rows.append(row)
+     grid = jnp.concatenate(rows, axis=0)
+
+     return jax.device_get(grid)
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=generate_grid,
+     inputs=[
+         gr.Number(label="Random Seed", value=0, precision=0),
+         gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale"),
+     ],
+     outputs=gr.Image(label="Generated Images"),
+     title="Anime Flow Generation Demo",
+     description="Generate a 4x4 grid of anime faces using Anime Flow",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
ckpt_1000k.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38793cde6fa2f5f32134d5f281141cd81fe257d6e160d975ab2a3c6c4559f6c2
+ size 147069817
config.yaml ADDED
@@ -0,0 +1,36 @@
+ # Model architecture
+ model:
+   input_dim: 3  # RGB images
+   hidden_dim: 512
+   num_blocks: 8
+   num_heads: 8
+   patch_size: 8
+   patch_stride: 4
+   time_freq_dim: 256
+   time_max_period: 1024
+   mlp_ratio: 4
+   use_bias: false
+   padding: "SAME"
+   pos_embed_cls_token: false
+   pos_embed_extra_tokens: 0
+
+ # Training parameters
+ training:
+   learning_rate: 1.0e-4
+   batch_size: 128
+   num_steps: 1_000_000
+   warmup_pct: 0.01
+   weight_decay: 0.0
+   grad_clip_norm: 100.0
+
+ # Checkpointing and logging
+ checkpointing:
+   log_every: 1_000
+   plot_every: 10_000
+   save_every: 10_000
+   resume_from_checkpoint: null
+
+ # Data
+ data:
+   train_split: 0.9  # 90% for training, 10% for testing
+   random_seed: 42
model.py ADDED
@@ -0,0 +1,294 @@
+ import math
+ from dataclasses import dataclass
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+
+
+ @dataclass
+ class DiTConfig:
+     input_dim: int
+     hidden_dim: int
+     num_blocks: int
+     num_heads: int
+     patch_size: int
+     patch_stride: int
+     time_freq_dim: int
+     time_max_period: int
+     mlp_ratio: int
+     use_bias: bool
+     padding: str
+     pos_embed_cls_token: bool
+     pos_embed_extra_tokens: int
+
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = jnp.arange(grid_size, dtype=jnp.float32)
+     grid_w = jnp.arange(grid_size, dtype=jnp.float32)
+     grid = jnp.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = jnp.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate(
+             [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
+         )
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = jnp.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = jnp.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 16**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = jnp.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = jnp.sin(out)  # (M, D/2)
+     emb_cos = jnp.cos(out)  # (M, D/2)
+
+     emb = jnp.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ class PatchEmbedding(nnx.Module):
+     """Patch embedding module."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.cnn = nnx.Conv(
+             config.input_dim,
+             config.hidden_dim,
+             kernel_size=(config.patch_size, config.patch_size),
+             strides=(config.patch_stride, config.patch_stride),
+             padding=config.padding,
+             use_bias=config.use_bias,
+             rngs=rngs,
+         )
+
+     def __call__(self, x):
+         return self.cnn(x)
+
+
+ class TimeEmbedding(nnx.Module):
+     """Time embedding module."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.freq_dim = config.time_freq_dim
+         self.max_period = config.time_max_period
+         self.fc1 = nnx.Linear(
+             self.freq_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
+         )
+         self.fc2 = nnx.Linear(
+             config.hidden_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
+         )
+
+     @staticmethod
+     def cosine_embedding(t, dim, max_period):
+         assert dim % 2 == 0
+         half = dim // 2
+         freqs = jnp.exp(
+             -math.log(max_period)
+             * jnp.arange(start=0, stop=half, dtype=jnp.float32)
+             / half
+         )
+         args = t[:, None] * freqs[None, :] * 1024
+         embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
+         return embedding
+
+     def __call__(self, t):
+         t_freq = self.cosine_embedding(t, self.freq_dim, self.max_period)
+         t_embed = self.fc1(t_freq)
+         t_embed = nnx.silu(t_embed)
+         t_embed = self.fc2(t_embed)
+         return t_embed
+
+
+ class MLP(nnx.Module):
+     """MLP module."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.fc1 = nnx.Linear(
+             config.hidden_dim,
+             config.hidden_dim * config.mlp_ratio,
+             use_bias=config.use_bias,
+             rngs=rngs,
+         )
+         self.fc2 = nnx.Linear(
+             config.hidden_dim * config.mlp_ratio,
+             config.hidden_dim,
+             use_bias=config.use_bias,
+             rngs=rngs,
+         )
+
+     def __call__(self, x):
+         x = self.fc1(x)
+         x = nnx.silu(x)
+         x = self.fc2(x)
+         return x
+
+
+ class SelfAttention(nnx.Module):
+     """Self attention module."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.fc = nnx.Linear(
+             config.hidden_dim,
+             3 * config.hidden_dim,
+             use_bias=config.use_bias,
+             rngs=rngs,
+         )
+         self.heads = config.num_heads
+         self.head_dim = config.hidden_dim // config.num_heads
+         assert config.hidden_dim % config.num_heads == 0
+         self.q_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)
+         self.k_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)
+
+     def __call__(self, x):
+         q, k, v = jnp.split(self.fc(x), 3, axis=-1)
+         # reshape q, k, v to N, T, H, D
+         q = q.reshape(q.shape[0], q.shape[1], self.heads, self.head_dim)
+         k = k.reshape(k.shape[0], k.shape[1], self.heads, self.head_dim)
+         v = v.reshape(v.shape[0], v.shape[1], self.heads, self.head_dim)
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+         o = jax.nn.dot_product_attention(q, k, v, is_causal=False)
+         o = o.reshape(o.shape[0], o.shape[1], self.heads * self.head_dim)
+         return o
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale[:, None, :]) + shift[:, None, :]
+
+
+ class TransformerBlock(nnx.Module):
+     """Transformer block."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.norm1 = nnx.RMSNorm(
+             num_features=config.hidden_dim, use_scale=False, rngs=rngs
+         )
+         self.attn = SelfAttention(config, rngs=rngs)
+         self.norm2 = nnx.RMSNorm(
+             num_features=config.hidden_dim, use_scale=False, rngs=rngs
+         )
+         self.mlp = MLP(config, rngs=rngs)
+         self.adalm_modulation = nnx.Sequential(
+             nnx.silu,
+             nnx.Linear(
+                 config.hidden_dim,
+                 6 * config.hidden_dim,
+                 use_bias=config.use_bias,
+                 rngs=rngs,
+             ),
+         )
+
+     def __call__(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(
+             self.adalm_modulation(c), 6, axis=-1
+         )
+         attn_x = self.norm1(x)
+         attn_x = modulate(attn_x, shift_msa, scale_msa)
+         x = x + gate_msa[:, None, :] * self.attn(attn_x)
+         mlp_x = self.norm2(x)
+         mlp_x = modulate(mlp_x, shift_mlp, scale_mlp)
+         x = x + gate_mlp[:, None, :] * self.mlp(mlp_x)
+         return x
+
+
+ class FinalLayer(nnx.Module):
+     """Final layer."""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.norm = nnx.RMSNorm(
+             num_features=config.hidden_dim, use_scale=False, rngs=rngs
+         )
+         self.conv = nnx.ConvTranspose(
+             config.hidden_dim,
+             config.input_dim,
+             kernel_size=(config.patch_size, config.patch_size),
+             strides=(config.patch_stride, config.patch_stride),
+             padding=config.padding,
+             use_bias=config.use_bias,
+             rngs=rngs,
+         )
+         self.adalm_modulation = nnx.Sequential(
+             nnx.silu,
+             nnx.Linear(
+                 config.hidden_dim,
+                 2 * config.hidden_dim,
+                 use_bias=config.use_bias,
+                 rngs=rngs,
+             ),
+         )
+
+     def __call__(self, x, c):
+         shift, scale = jnp.split(self.adalm_modulation(c), 2, axis=-1)
+         x = self.norm(x)
+         x = modulate(x, shift, scale)
+         # reshape to N, H, W, C
+         H = W = int(x.shape[1] ** 0.5)
+         x = x.reshape(x.shape[0], H, W, x.shape[-1])
+         x = self.conv(x)
+         return x
+
+
+ class DiT(nnx.Module):
+     """Diffusion Transformer"""
+
+     def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
+         super().__init__()
+         self.config = config
+         self.time_embedding = TimeEmbedding(config, rngs=rngs)
+         self.patch_embedding = PatchEmbedding(config, rngs=rngs)
+         self.blocks = [
+             TransformerBlock(config, rngs=rngs) for _ in range(config.num_blocks)
+         ]
+         self.final_layer = FinalLayer(config, rngs=rngs)
+
+     def __call__(self, xt, t):
+         t = self.time_embedding(t)
+         x = self.patch_embedding(xt)
+         N, H, W, D = x.shape
+         x = x.reshape(N, H * W, D)
+         x = x + get_2d_sincos_pos_embed(
+             D,
+             H,
+             cls_token=self.config.pos_embed_cls_token,
+             extra_tokens=self.config.pos_embed_extra_tokens,
+         ).reshape(1, H * W, D)
+         c = t
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         return x
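As a quick sanity check of the shapes implied by config.yaml (a sketch, assuming model.py and config.yaml from this commit are on the import path): the patch embedding turns a 64x64 RGB image into a 16x16 grid of 512-dimensional tokens (stride 4, SAME padding), and the final `ConvTranspose` maps the tokens back to 64x64x3, so the predicted velocity has the same shape as the input.

```python
import jax.numpy as jnp
import yaml
from flax import nnx

from model import DiT, DiTConfig

with open("config.yaml") as f:
    config = yaml.safe_load(f)

model = DiT(DiTConfig(**config["model"]), rngs=nnx.Rngs(0))

x = jnp.zeros((2, 64, 64, 3))  # dummy batch of 64x64 RGB images
t = jnp.zeros((2,))            # one flow time per sample
print(model(x, t).shape)       # expected: (2, 64, 64, 3)
```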
pyproject.toml ADDED
@@ -0,0 +1,15 @@
+ [project]
+ name = "anime-flow"
+ version = "0.1.0"
+ description = "Generate anime faces using conditional flow matching"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "flax>=0.10.2",
+     "gradio>=5.9.1",
+     "jax[cuda12]>=0.4.38",
+     "kagglehub>=0.3.6",
+     "matplotlib>=3.10.0",
+     "pillow>=11.0.0",
+     "pot>=0.9.5",
+ ]
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ jax[cuda12]
+ flax
sample.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Generate images from trained model
+ """
+
+ import argparse
+ import pickle
+
+ import jax
+ import jax.numpy as jnp
+ import matplotlib.pyplot as plt
+ import yaml
+ from flax import nnx
+ from jax.experimental import ode
+
+ from model import DiT, DiTConfig
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--config", type=str, default="config.yaml", help="Path to config file"
+     )
+     parser.add_argument(
+         "--ckpt", type=str, default=None, help="Path to checkpoint file"
+     )
+     parser.add_argument("--seed", type=int, default=0, help="Random seed")
+     return parser.parse_args()
+
+
+ def load_config(config_path):
+     with open(config_path) as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ @jax.jit
+ def sample_images(graphdef, state, rng):
+     flow = nnx.merge(graphdef, state)
+
+     def flow_fn(y, t):
+         o = flow(y, t[None])
+         return o
+
+     x = jax.random.normal(rng, shape=(16, 64, 64, 3), dtype=jnp.float32)
+     o = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 1000))
+     o = jnp.clip(o[-1], 0, 1)
+     return o
+
+
+ def plot_new_images(graphdef, state, seed):
+     images = sample_images(graphdef, state, nnx.Rngs(seed)())
+
+     plt.figure(figsize=(2, 2))
+     for i in range(16):
+         plt.subplot(4, 4, i + 1)
+         plt.imshow(images[i])
+         plt.axis("off")
+     plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
+     plt.savefig(f"samples.png")
+     plt.close()
+
+
+ def main():
+     args = parse_args()
+     config = load_config(args.config)
+
+     dit_config = DiTConfig(
+         input_dim=config["model"]["input_dim"],
+         hidden_dim=config["model"]["hidden_dim"],
+         num_blocks=config["model"]["num_blocks"],
+         num_heads=config["model"]["num_heads"],
+         patch_size=config["model"]["patch_size"],
+         patch_stride=config["model"]["patch_stride"],
+         time_freq_dim=config["model"]["time_freq_dim"],
+         time_max_period=config["model"]["time_max_period"],
+         mlp_ratio=config["model"]["mlp_ratio"],
+         use_bias=config["model"]["use_bias"],
+         padding=config["model"]["padding"],
+         pos_embed_cls_token=config["model"]["pos_embed_cls_token"],
+         pos_embed_extra_tokens=config["model"]["pos_embed_extra_tokens"],
+     )
+
+     abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
+     graphdef, _ = nnx.split(abstract_flow)
+     with open(args.ckpt, "rb") as f:
+         state = pickle.load(f, fix_imports=True)
+     if "time_embedding" not in state:
+         state = state[0]
+     plot_new_images(graphdef, state, args.seed)
+
+
+ if __name__ == "__main__":
+     main()
train.py ADDED
@@ -0,0 +1,221 @@
+ """
+ A simple implementation of conditional flow matching for generating anime faces.
+ """
+
+ import argparse
+ import pickle
+ import random
+ import time
+ from pathlib import Path
+
+ import jax
+ import jax.numpy as jnp
+ import kagglehub
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import optax
+ import ot
+ import yaml
+ from flax import nnx
+ from jax.experimental import ode
+ from PIL import Image
+ from tqdm.cli import tqdm
+
+ from model import DiT, DiTConfig
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--config", type=str, default="config.yaml", help="Path to config file"
+     )
+     return parser.parse_args()
+
+
+ def load_config(config_path):
+     with open(config_path) as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ def gen_data_batches(data, batch_size):
+     N = data.shape[0]
+     while True:
+         random_indices = np.random.choice(N, size=batch_size, replace=False)
+         batch = data[random_indices]
+         batch = batch.astype(np.float32) / 256
+         yield batch
+
+
+ def loss_fn(flow, batch):
+     xt, t, vt = batch
+     velocity = flow(xt, t)
+     loss = jnp.mean(jnp.square(velocity - vt))
+     return loss
+
+
+ def train_step(flow, optimizer, rngs, batch):
+     x0, x1 = batch
+     noise = jax.random.uniform(rngs(), shape=x1.shape, minval=0, maxval=1 / 256)
+     x1 = x1 + noise
+     # randomize t
+     t = jax.random.uniform(rngs(), (x1.shape[0],), minval=0, maxval=1)
+     # randomize x0
+     xt = x0 + (x1 - x0) * t[:, None, None, None]
+     vt = x1 - x0
+     batch = (xt, t, vt)
+     loss, grads = nnx.value_and_grad(loss_fn)(flow, batch)
+     optimizer.update(grads)
+     return loss
+
+
+ @jax.jit
+ def train_step_raw(graphdef, state, batch):
+     flow, optimizer, rngs = nnx.merge(graphdef, state)
+     loss = train_step(flow, optimizer, rngs, batch)
+     _, state = nnx.split((flow, optimizer, rngs))
+     return state, loss
+
+
+ @jax.jit
+ def sample_images(graphdef, state):
+     flow, _, _ = nnx.merge(graphdef, state)
+
+     def flow_fn(y, t):
+         o = flow(y, t[None])
+         return o
+
+     x = jax.random.normal(nnx.Rngs(0)(), shape=(16, 64, 64, 3), dtype=jnp.float32)
+     o = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 1000))
+     o = jnp.clip(o[-1], 0, 1)
+     return o
+
+
+ def generate_ot_pairs(x1):
+     n = x1.shape[0]
+     x0 = np.random.randn(*x1.shape)
+     d1 = x1.reshape(n, -1)
+     d0 = x0.reshape(n, -1)
+     # loss matrix
+     M = ot.dist(d0, d1)
+     a, b = np.ones((n,)), np.ones((n,))
+     G0 = ot.emd(a, b, M)
+     d1 = np.matmul(G0, d1)
+     x1 = d1.reshape(*x1.shape)
+     return x0, x1
+
+
+ def plot_new_images(step: int, graphdef, state):
+     images = sample_images(graphdef, state)
+
+     plt.figure(figsize=(2, 2))
+     for i in range(16):
+         plt.subplot(4, 4, i + 1)
+         plt.imshow(images[i])
+         plt.axis("off")
+     plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
+     plt.savefig(f"images_{step:06d}.png")
+     plt.close()
+
+
+ args = parse_args()
+ config = load_config(args.config)
+
+ # Download latest version
+ path = kagglehub.dataset_download("thimac/anime-face-64")
+ data_path = Path(path) / "64x64"
+ print("Path to dataset files:", data_path)
+
+ data_dir = data_path
+ image_files = sorted(data_dir.glob("*.jpg"))
+ random.Random(config["data"]["random_seed"]).shuffle(image_files)
+ N = len(image_files)
+ dataset = np.empty((N, 64, 64, 3), dtype=np.uint8)
+ for i, file_path in enumerate(tqdm(image_files)):
+     dataset[i] = Image.open(file_path)
+
+ L = int(N * config["data"]["train_split"])
+ train_data = dataset[:L]
+ test_data = dataset[L:]
+
+ plt.figure(figsize=(2, 2))
+ for i in range(16):
+     plt.subplot(4, 4, i + 1)
+     plt.imshow(train_data[i])
+     plt.axis("off")
+ plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
+ plt.savefig("train_data_samples.png")
+ plt.close()
+
+ scheduler = optax.cosine_onecycle_schedule(
+     transition_steps=config["training"]["num_steps"],
+     peak_value=config["training"]["learning_rate"],
+     pct_start=config["training"]["warmup_pct"],
+ )
+
+ gradient_transform = optax.chain(
+     optax.clip_by_global_norm(config["training"]["grad_clip_norm"]),
+     optax.scale_by_adam(),
+     optax.scale_by_schedule(scheduler),
+     optax.add_decayed_weights(config["training"]["weight_decay"]),
+     optax.scale(-1.0),
+ )
+
+ dit_config = DiTConfig(
+     input_dim=config["model"]["input_dim"],
+     hidden_dim=config["model"]["hidden_dim"],
+     num_blocks=config["model"]["num_blocks"],
+     num_heads=config["model"]["num_heads"],
+     patch_size=config["model"]["patch_size"],
+     patch_stride=config["model"]["patch_stride"],
+     time_freq_dim=config["model"]["time_freq_dim"],
+     time_max_period=config["model"]["time_max_period"],
+     mlp_ratio=config["model"]["mlp_ratio"],
+     use_bias=config["model"]["use_bias"],
+     padding=config["model"]["padding"],
+     pos_embed_cls_token=config["model"]["pos_embed_cls_token"],
+     pos_embed_extra_tokens=config["model"]["pos_embed_extra_tokens"],
+ )
+
+ flow = DiT(dit_config, rngs=nnx.Rngs(0))
+ optimizer = nnx.Optimizer(flow, gradient_transform)
+
+ rngs = nnx.Rngs(0)
+ graphdef, state = nnx.split((flow, optimizer, rngs))
+ train_data_iter = gen_data_batches(train_data, config["training"]["batch_size"])
+
+ start = time.perf_counter()
+ losses = []
+ ckpt_path = config["checkpointing"].get("resume_from_checkpoint")
+ if ckpt_path:
+     del state
+     with open(ckpt_path, "rb") as f:
+         state = pickle.load(f)
+     print(f"Resuming from checkpoint {ckpt_path}")
+     step_str = Path(ckpt_path).stem.split("_")[-1]
+     start_step = int(step_str) + 1
+ else:
+     start_step = 1
+
+ for step, batch in enumerate(train_data_iter, start=start_step):
+     x0, x1 = generate_ot_pairs(batch)
+     state, loss = train_step_raw(graphdef, state, (x0, x1))
+
+     if step % 100 == 0:
+         losses.append(loss.item())
+
+     if step % config["checkpointing"]["log_every"] == 0:
+         end = time.perf_counter()
+         duration = end - start
+         loss = sum(losses) / len(losses)
+         start = time.perf_counter()
+         losses = []
+         print(f"step {step:06d} loss {loss:.3f} duration {duration:.3f}s", flush=True)
+
+     if step % config["checkpointing"]["plot_every"] == 0:
+         plot_new_images(step, graphdef, state)
+
+     if step % config["checkpointing"]["save_every"] == 0:
+         # save checkpoint
+         with open(f"state_{step:06d}.ckpt", "wb") as f:
+             pickle.dump(state, f)
train_data_samples.png ADDED
uv.lock ADDED
The diff for this file is too large to render. See raw diff