ntt123 committed
Commit 7294a64 · verified · 1 Parent(s): e8774dc

Update model.py

Files changed (1)
  1. model.py +9 -2
model.py CHANGED
@@ -89,6 +89,7 @@ class PatchEmbedding(nnx.Module):
             padding=config.padding,
             use_bias=config.use_bias,
             rngs=rngs,
+            dtype=jnp.bfloat16,
         )
 
     def __call__(self, x):
@@ -103,10 +104,10 @@ class TimeEmbedding(nnx.Module):
         self.freq_dim = config.time_freq_dim
         self.max_period = config.time_max_period
         self.fc1 = nnx.Linear(
-            self.freq_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
+            self.freq_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs, dtype=jnp.bfloat16
         )
         self.fc2 = nnx.Linear(
-            config.hidden_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs
+            config.hidden_dim, config.hidden_dim, use_bias=config.use_bias, rngs=rngs, dtype=jnp.bfloat16
         )
 
     @staticmethod
@@ -140,12 +141,14 @@ class MLP(nnx.Module):
             config.hidden_dim * config.mlp_ratio,
             use_bias=config.use_bias,
             rngs=rngs,
+            dtype=jnp.bfloat16,
         )
         self.fc2 = nnx.Linear(
             config.hidden_dim * config.mlp_ratio,
             config.hidden_dim,
             use_bias=config.use_bias,
             rngs=rngs,
+            dtype=jnp.bfloat16,
         )
 
     def __call__(self, x):
@@ -165,6 +168,7 @@ class SelfAttention(nnx.Module):
             3 * config.hidden_dim,
             use_bias=config.use_bias,
             rngs=rngs,
+            dtype=jnp.bfloat16,
         )
         self.heads = config.num_heads
         self.head_dim = config.hidden_dim // config.num_heads
@@ -209,6 +213,7 @@ class TransformerBlock(nnx.Module):
                 6 * config.hidden_dim,
                 use_bias=config.use_bias,
                 rngs=rngs,
+                dtype=jnp.bfloat16,
             ),
         )
 
@@ -241,6 +246,7 @@ class FinalLayer(nnx.Module):
             padding=config.padding,
            use_bias=config.use_bias,
             rngs=rngs,
+            dtype=jnp.bfloat16,
         )
         self.adalm_modulation = nnx.Sequential(
             nnx.silu,
@@ -249,6 +255,7 @@ class FinalLayer(nnx.Module):
                 2 * config.hidden_dim,
                 use_bias=config.use_bias,
                 rngs=rngs,
+                dtype=jnp.bfloat16,
             ),
         )
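
The commit passes dtype=jnp.bfloat16 to each of the nnx layer constructors touched in model.py. In Flax NNX, dtype is the computation dtype: inputs, kernel, and bias are promoted to it before the matmul or convolution, while the stored parameters keep param_dtype (float32 by default). A minimal sketch of that behavior, assuming a recent flax release with the nnx API; the layer sizes below are arbitrary and not taken from model.py:

import jax.numpy as jnp
from flax import nnx

# dtype sets the dtype used inside the forward computation;
# param_dtype (default float32) sets how the weights are stored.
layer = nnx.Linear(
    4, 8,
    use_bias=True,
    dtype=jnp.bfloat16,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 4), dtype=jnp.float32)
y = layer(x)

print(layer.kernel.value.dtype)  # float32  -> parameters stay in param_dtype
print(y.dtype)                   # bfloat16 -> inputs and weights are cast to dtype for the matmul

The practical effect is lower memory traffic and faster matmuls on hardware with bfloat16 support, at the cost of mantissa precision; the master weights themselves remain float32 unless param_dtype is also changed.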