Correct the output dtype of rmsnorm_func
#13
Opened by ag0
Changed file: modeling_flash_llama.py (+1 −1)
modeling_flash_llama.py
CHANGED
@@ -68,7 +68,7 @@ def rmsnorm_func(hidden_states, weight, variance_epsilon):
     hidden_states = hidden_states.to(torch.float32)
     variance = hidden_states.pow(2).mean(-1, keepdim=True)
     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
-    return weight * hidden_states.to(input_dtype)
+    return (weight * hidden_states).to(input_dtype)


 class LlamaRMSNorm(nn.Module):