xiaotinghe commited on
Commit
1991fab
·
1 Parent(s): 376d5a1

Upload BertForSequenceClassification

Browse files
Files changed (5) hide show
  1. bert_layers.py +1101 -0
  2. bert_padding.py +159 -0
  3. config.json +114 -0
  4. configuration_bert.py +26 -0
  5. pytorch_model.bin +3 -0
bert_layers.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
5
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
6
+ # Copyright (c) 2022, Tri Dao.
7
+
8
+ """Implements Mosaic BERT, with an eye towards the Hugging Face API.
9
+
10
+ Mosaic BERT improves performance over Hugging Face BERT through the following:
11
+
12
+ 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
13
+ information through attention biases based on query-key position distance. It improves the effectiveness
14
+ of training with shorter sequence lengths by enabling extrapolation to longer sequences.
15
+
16
+ 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
17
+ to improve overall expressiveness, providing better convergence properties.
18
+
19
+ 3. Flash Attention. The Mosaic BERT's self-attention layer makes use of Flash Attention, which dramatically
20
+ improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that
21
+ supports attention biases, which allows us to use Flash Attention with ALiBi.
22
+
23
+ 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
24
+ implementations waste computation on padded tokens. Mosaic BERT internally unpads to reduce unnecessary computation
25
+ and improve speed. It does this without changing how the user interfaces with the model, thereby
26
+ preserving the simple API of standard implementations.
27
+
28
+
29
+ Currently, Mosaic BERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
30
+ classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
31
+
32
+ See :file:`./mosaic_bert.py` for utilities to simplify working with Mosaic BERT in Composer, and for example usage
33
+ of the core Mosaic BERT classes.
34
+ """
35
+
36
+ import copy
37
+ import logging
38
+ import math
39
+ import os
40
+ import sys
41
+ import warnings
42
+ from typing import List, Optional, Tuple, Union
43
+ from .configuration_bert import BertConfig
44
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
45
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
46
+
47
+ from .bert_padding import (index_first_axis,
48
+ index_put_first_axis, pad_input,
49
+ unpad_input, unpad_input_only)
50
+ import torch
51
+ import torch.nn as nn
52
+ from torch.nn import functional as F
53
+
54
+ from einops import rearrange
55
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
56
+ from transformers.activations import ACT2FN
57
+ from transformers.modeling_outputs import (MaskedLMOutput,
58
+ SequenceClassifierOutput)
59
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
60
+ logger = logging.getLogger(__name__)
61
+
62
+ class RMSNorm(nn.Module):
63
+ def __init__(self, hidden_size, eps=1e-6):
64
+ """
65
+ RMSNorm is equivalent to T5LayerNorm
66
+ """
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
73
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
74
+
75
+ # convert into half-precision if necessary
76
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
77
+ hidden_states = hidden_states.to(self.weight.dtype)
78
+
79
+ return self.weight * hidden_states
80
+
81
+ class RotaryEmbedding(torch.nn.Module):
82
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
83
+ super().__init__()
84
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
85
+ self.max_seq_len_cached = max_position_embeddings
86
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
87
+ freqs = torch.outer(t, self.inv_freq)
88
+ emb = torch.cat((freqs, freqs), dim=-1)
89
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
90
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
91
+ def forward(self, x, seq_len=None):
92
+ # x: [bs, num_attention_heads, seq_len, head_size]
93
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
94
+ if seq_len > self.max_seq_len_cached:
95
+ self.max_seq_len_cached = seq_len
96
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
97
+ freqs = torch.outer(t, self.inv_freq)
98
+ emb = torch.cat((freqs, freqs), dim=-1)
99
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
100
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
101
+ elif self.cos_cached.device != x.device:
102
+ self.cos_cached = self.cos_cached.to(x.device)
103
+ self.sin_cached = self.sin_cached.to(x.device)
104
+ return (
105
+ self.cos_cached[:, :, :seq_len, ...],
106
+ self.sin_cached[:, :, :seq_len, ...],
107
+ )
108
+
109
+
110
+ def rotate_half(x):
111
+ """Rotates half the hidden dims of the input."""
112
+ x1 = x[..., : x.shape[-1] // 2]
113
+ x2 = x[..., x.shape[-1] // 2:]
114
+ return torch.cat((-x2, x1), dim=-1)
115
+
116
+
117
+ def apply_rotary_pos_emb(q, k, cos_, sin_):
118
+ #cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
119
+ #sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
120
+ cos = torch.repeat_interleave(cos_[:, :, None, :], q.shape[0], 0).squeeze(1)
121
+ sin = torch.repeat_interleave(sin_[:, :, None, :], q.shape[0], 0).squeeze(1)
122
+ #position_ids = torch.Tensor([list(range(q.shape[2]))]*q.shape[0]).int().to(q.device)
123
+ #cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
124
+ #sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
125
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
126
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
127
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
128
+
129
+ class BertEmbeddings(nn.Module):
130
+ """Construct the embeddings for words, ignoring position.
131
+
132
+ There are no positional embeddings since we use ALiBi and token_type
133
+ embeddings.
134
+
135
+ This module is modeled after the Hugging Face BERT's
136
+ :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
137
+ modified as part of Mosaic BERT's ALiBi implementation. The key change is
138
+ that position embeddings are removed. Position information instead comes
139
+ from attention biases that scale linearly with the position distance
140
+ between query and key tokens.
141
+
142
+ This module ignores the `position_ids` input to the `forward` method.
143
+ """
144
+
145
+ def __init__(self, config):
146
+ super().__init__()
147
+ self.word_embeddings = nn.Embedding(config.vocab_size,
148
+ config.hidden_size,
149
+ padding_idx=config.pad_token_id)
150
+ # ALiBi doesn't use position embeddings
151
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
152
+ config.hidden_size)
153
+
154
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
155
+ # variable name and be able to load any TensorFlow checkpoint file
156
+ self.norm = RMSNorm(config.hidden_size,
157
+ eps=config.layer_norm_eps)
158
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
159
+ self.register_buffer('token_type_ids',
160
+ torch.zeros(config.max_position_embeddings,
161
+ dtype=torch.long),
162
+ persistent=False)
163
+
164
+ def forward(
165
+ self,
166
+ input_ids: Optional[torch.LongTensor] = None,
167
+ token_type_ids: Optional[torch.LongTensor] = None,
168
+ position_ids: Optional[torch.LongTensor] = None,
169
+ inputs_embeds: Optional[torch.FloatTensor] = None,
170
+ past_key_values_length: int = 0,
171
+ ) -> torch.Tensor:
172
+ if (input_ids is not None) == (inputs_embeds is not None):
173
+ raise ValueError('Must specify either input_ids or input_embeds!')
174
+ if input_ids is not None:
175
+ input_shape = input_ids.size()
176
+ else:
177
+ assert inputs_embeds is not None # just for type checking
178
+ input_shape = inputs_embeds.size()[:-1]
179
+
180
+ seq_length = input_shape[1]
181
+
182
+ if position_ids is None:
183
+ # great! ALiBi
184
+ pass
185
+
186
+ # Setting the token_type_ids to the registered buffer in constructor
187
+ # where it is all zeros, which usually occurs when it's auto-generated;
188
+ # registered buffer helps users when tracing the model without passing
189
+ # token_type_ids, solves issue #5664
190
+ if token_type_ids is None:
191
+ if hasattr(self, 'token_type_ids'):
192
+ assert isinstance(self.token_type_ids, torch.LongTensor)
193
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
194
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
195
+ input_shape[0], seq_length)
196
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
197
+ else:
198
+ token_type_ids = torch.zeros(input_shape, # type: ignore
199
+ dtype=torch.long,
200
+ device=self.word_embeddings.device) # type: ignore # yapf: disable
201
+
202
+ if inputs_embeds is None:
203
+ inputs_embeds = self.word_embeddings(input_ids)
204
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
205
+
206
+ embeddings = inputs_embeds + token_type_embeddings
207
+ # no position embeddings! ALiBi
208
+ embeddings = self.norm(embeddings)
209
+ embeddings = self.dropout(embeddings)
210
+ return embeddings
211
+
212
+
213
+ class BertUnpadSelfAttention(nn.Module):
214
+ """Performs multi-headed self attention on a batch of unpadded sequences.
215
+
216
+ If Triton is installed, this module uses Flash Attention to greatly improve throughput.
217
+ The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
218
+ we use to implement ALiBi), but does not support attention dropout. If either Triton is not installed
219
+ or `config.attention_probs_dropout_prob > 0`, the implementation will default to a
220
+ math-equivalent pytorch version, which is much slower.
221
+
222
+ See `forward` method for additional detail.
223
+ """
224
+
225
+ def __init__(self, config):
226
+ super().__init__()
227
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
228
+ config, 'embedding_size'):
229
+ raise ValueError(
230
+ f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
231
+ f'heads ({config.num_attention_heads})')
232
+
233
+ self.num_attention_heads = config.num_attention_heads
234
+ self.attention_head_size = int(config.hidden_size /
235
+ config.num_attention_heads)
236
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
237
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
238
+ self.p_dropout = config.attention_probs_dropout_prob
239
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
240
+ self.max_position_embeddings = config.max_position_embeddings
241
+ self.rotary_emb = RotaryEmbedding(self.attention_head_size, max_position_embeddings=self.max_position_embeddings)
242
+ # Warn if defaulting to pytorch because of import issues
243
+
244
+
245
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
246
+ max_seqlen_in_batch: int, indices: torch.Tensor,
247
+ attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
248
+ """Perform self-attention.
249
+
250
+ If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
251
+ implementation of self-attention.
252
+
253
+ The arguments are unpadded, and our implementations of attention require padded arguments,
254
+ so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
255
+ The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
256
+ It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.
257
+
258
+ Args:
259
+ hidden_states: (total_nnz, dim)
260
+ cu_seqlens: (batch + 1,)
261
+ max_seqlen_in_batch: int
262
+ indices: (total_nnz,)
263
+ attn_mask: (batch, max_seqlen_in_batch)
264
+ bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
265
+
266
+ Returns:
267
+ attention: (total_nnz, dim)
268
+ """
269
+ qkv = self.Wqkv(hidden_states)
270
+ qkv = pad_input(
271
+ qkv, indices, cu_seqlens.shape[0] - 1,
272
+ max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
273
+ qkv = rearrange(qkv,
274
+ 'b s (t h d) -> b s t h d',
275
+ t=3,
276
+ h=self.num_attention_heads)
277
+ # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
278
+ q = qkv[:, :, 0, :, :].transpose(1, 2)
279
+ k = qkv[:, :, 1, :, :].transpose(1, 2)
280
+ v = qkv[:, :, 2, :, :].transpose(1, 2)
281
+ kv_seq_len = k.shape[-2]
282
+
283
+ cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
284
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
285
+ #q = q.transpose(1, 2)
286
+ k = k.permute(0, 1, 3, 2)
287
+ #v = v.transpose(1, 2)
288
+ # q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
289
+ # k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
290
+ # v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
291
+
292
+ attention_scores = torch.matmul(q, k) / math.sqrt(
293
+ self.attention_head_size)
294
+ attention_scores = attention_scores + bias
295
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
296
+ attention_probs = self.dropout(attention_probs)
297
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
298
+ 3) # b s h d
299
+
300
+ # attn_mask is 1 for attend and 0 for don't
301
+ attention = unpad_input_only(
302
+ attention,
303
+ torch.squeeze(attn_mask) == 1)
304
+ return rearrange(attention, 'nnz h d -> nnz (h d)')
305
+
306
+
307
+ # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
308
+ class BertSelfOutput(nn.Module):
309
+ """Computes the output of the attention layer.
310
+
311
+ This module is modeled after the Hugging Face BERT's
312
+ :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
313
+ The implementation is identical. Rather than use the original module
314
+ directly, we re-implement it here so that Mosaic BERT's modules will not
315
+ be affected by any Composer surgery algorithm that modifies Hugging Face
316
+ BERT modules.
317
+ """
318
+
319
+ def __init__(self, config):
320
+ super().__init__()
321
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
322
+ self.norm = RMSNorm(config.hidden_size,
323
+ eps=config.layer_norm_eps)
324
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
325
+
326
+ def forward(self, hidden_states: torch.Tensor,
327
+ input_tensor: torch.Tensor) -> torch.Tensor:
328
+ hidden_states = self.dense(hidden_states)
329
+ hidden_states = self.dropout(hidden_states)
330
+ hidden_states = self.norm(hidden_states + input_tensor)
331
+ return hidden_states
332
+
333
+
334
+ class BertUnpadAttention(nn.Module):
335
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
336
+
337
+ def __init__(self, config):
338
+ super().__init__()
339
+ self.self = BertUnpadSelfAttention(config)
340
+ self.output = BertSelfOutput(config)
341
+
342
+ def forward(
343
+ self,
344
+ input_tensor: torch.Tensor,
345
+ cu_seqlens: torch.Tensor,
346
+ max_s: int,
347
+ subset_idx: Optional[torch.Tensor] = None,
348
+ indices: Optional[torch.Tensor] = None,
349
+ attn_mask: Optional[torch.Tensor] = None,
350
+ bias: Optional[torch.Tensor] = None,
351
+ ) -> torch.Tensor:
352
+ """Forward pass for scaled self-attention without padding.
353
+
354
+ Arguments:
355
+ input_tensor: (total_nnz, dim)
356
+ cu_seqlens: (batch + 1,)
357
+ max_s: int
358
+ subset_idx: () set of indices whose values we care about at the end of the layer
359
+ (e.g., the masked tokens, if this is the final layer).
360
+ indices: None or (total_nnz,)
361
+ attn_mask: None or (batch, max_seqlen_in_batch)
362
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
363
+ """
364
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
365
+ attn_mask, bias)
366
+ if subset_idx is not None:
367
+ return self.output(
368
+ index_first_axis(self_output, subset_idx),
369
+ index_first_axis(input_tensor, subset_idx))
370
+ else:
371
+ return self.output(self_output, input_tensor)
372
+
373
+ class MLP(nn.Module):
374
+ def __init__(
375
+ self,
376
+ config
377
+ ):
378
+ super().__init__()
379
+ self.config = config
380
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
381
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
382
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
383
+ self.act_fn = ACT2FN[config.hidden_act]
384
+ self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
385
+
386
+ def forward(self, hidden_states):
387
+ residual_connection = hidden_states
388
+ hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
389
+ hidden_states = self.norm(hidden_states + residual_connection)
390
+ return hidden_states
391
+
392
+ # class BertGatedLinearUnitMLP(nn.Module):
393
+ # """Applies the FFN at the end of each Mosaic BERT layer.
394
+
395
+ # Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
396
+ # and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
397
+ # introduces Gated Linear Units.
398
+
399
+ # Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
400
+ # standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
401
+ # `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
402
+ # with the `config.intermediate_size=3072`.
403
+ # However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
404
+ # parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
405
+ # """
406
+
407
+ # def __init__(self, config):
408
+ # super().__init__()
409
+ # self.config = config
410
+ # self.gated_layers = nn.Linear(config.hidden_size,
411
+ # config.intermediate_size * 2,
412
+ # bias=False)
413
+ # self.act = ACT2FN[config.hidden_act]#nn.GELU(approximate='none')
414
+ # self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
415
+ # self.dropout = nn.Dropout(config.hidden_dropout_prob)
416
+ # self.norm = RMSNorm(config.hidden_size,
417
+ # eps=config.layer_norm_eps)
418
+
419
+ # def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
420
+ # """Compute new hidden states from current hidden states.
421
+
422
+ # Args:
423
+ # hidden_states (torch.Tensor): The (unpadded) hidden states from
424
+ # the attention layer [nnz, dim].
425
+ # """
426
+ # residual_connection = hidden_states
427
+ # # compute the activation
428
+ # hidden_states = self.gated_layers(hidden_states)
429
+ # gated = hidden_states[:, :self.config.intermediate_size]
430
+ # non_gated = hidden_states[:, self.config.intermediate_size:]
431
+ # hidden_states = self.act(gated) * non_gated
432
+ # hidden_states = self.dropout(hidden_states)
433
+ # # multiply by the second matrix
434
+ # hidden_states = self.wo(hidden_states)
435
+ # # add the residual connection and post-LN
436
+ # hidden_states = self.norm(hidden_states + residual_connection)
437
+ # return hidden_states
438
+
439
+
440
+ class BertLayer(nn.Module):
441
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
442
+
443
+ def __init__(self, config):
444
+ super(BertLayer, self).__init__()
445
+ self.attention = BertUnpadAttention(config)
446
+ self.mlp = MLP(config)
447
+
448
+ def forward(
449
+ self,
450
+ hidden_states: torch.Tensor,
451
+ cu_seqlens: torch.Tensor,
452
+ seqlen: int,
453
+ subset_idx: Optional[torch.Tensor] = None,
454
+ indices: Optional[torch.Tensor] = None,
455
+ attn_mask: Optional[torch.Tensor] = None,
456
+ bias: Optional[torch.Tensor] = None,
457
+ ) -> torch.Tensor:
458
+ """Forward pass for a BERT layer, including both attention and MLP.
459
+
460
+ Args:
461
+ hidden_states: (total_nnz, dim)
462
+ cu_seqlens: (batch + 1,)
463
+ seqlen: int
464
+ subset_idx: () set of indices whose values we care about at the end of the layer
465
+ (e.g., the masked tokens, if this is the final layer).
466
+ indices: None or (total_nnz,)
467
+ attn_mask: None or (batch, max_seqlen_in_batch)
468
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
469
+ """
470
+ attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
471
+ subset_idx, indices, attn_mask, bias)
472
+ layer_output = self.mlp(attention_output)
473
+ return layer_output
474
+
475
+
476
+ class BertEncoder(nn.Module):
477
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
478
+
479
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
480
+ but with substantial modifications to implement unpadding and ALiBi.
481
+
482
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
483
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
484
+ """
485
+
486
+ def __init__(self, config):
487
+ super().__init__()
488
+ layer = BertLayer(config)
489
+ self.layer = nn.ModuleList(
490
+ [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
491
+
492
+ self.num_attention_heads = config.num_attention_heads
493
+
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states: torch.Tensor,
498
+ attention_mask: torch.Tensor,
499
+ output_all_encoded_layers: Optional[bool] = True,
500
+ subset_mask: Optional[torch.Tensor] = None,
501
+ ) -> List[torch.Tensor]:
502
+
503
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
504
+ extended_attention_mask = extended_attention_mask.to(
505
+ dtype=next(self.parameters()).dtype) # fp16 compatibility
506
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
507
+
508
+ attention_mask_bool = attention_mask.bool()
509
+ batch, seqlen = hidden_states.shape[:2]
510
+ # Unpad inputs and mask. It will remove tokens that are padded.
511
+ # Assume ntokens is total number of tokens (padded and non-padded)
512
+ # and ntokens_unpad is total number of non-padded tokens.
513
+ # Then unpadding performs the following compression of the inputs:
514
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
515
+ hidden_states, indices, cu_seqlens, _ = unpad_input(
516
+ hidden_states, attention_mask_bool)
517
+
518
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
519
+ all_encoder_layers = []
520
+ if subset_mask is None:
521
+ for layer_module in self.layer:
522
+ hidden_states = layer_module(hidden_states,
523
+ cu_seqlens,
524
+ seqlen,
525
+ None,
526
+ indices,
527
+ attn_mask=attention_mask,
528
+ bias=attn_bias)
529
+ if output_all_encoded_layers:
530
+ all_encoder_layers.append(hidden_states)
531
+ # Pad inputs and mask. It will insert back zero-padded tokens.
532
+ # Assume ntokens is total number of tokens (padded and non-padded)
533
+ # and ntokens_unpad is total number of non-padded tokens.
534
+ # Then padding performs the following de-compression:
535
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
536
+ hidden_states = pad_input(
537
+ hidden_states, indices, batch, seqlen)
538
+ else:
539
+ for i in range(len(self.layer) - 1):
540
+ layer_module = self.layer[i]
541
+ hidden_states = layer_module(hidden_states,
542
+ cu_seqlens,
543
+ seqlen,
544
+ None,
545
+ indices,
546
+ attn_mask=attention_mask,
547
+ bias=attn_bias)
548
+ if output_all_encoded_layers:
549
+ all_encoder_layers.append(hidden_states)
550
+ subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
551
+ as_tuple=False).flatten()
552
+ hidden_states = self.layer[-1](hidden_states,
553
+ cu_seqlens,
554
+ seqlen,
555
+ subset_idx=subset_idx,
556
+ indices=indices,
557
+ attn_mask=attention_mask,
558
+ bias=attn_bias)
559
+
560
+ if not output_all_encoded_layers:
561
+ all_encoder_layers.append(hidden_states)
562
+ return all_encoder_layers
563
+
564
+
565
+ class BertPooler(nn.Module):
566
+
567
+ def __init__(self, config):
568
+ super(BertPooler, self).__init__()
569
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
570
+ self.activation = nn.Tanh()
571
+
572
+ def forward(self,
573
+ hidden_states: torch.Tensor,
574
+ pool: Optional[bool] = True) -> torch.Tensor:
575
+ # We "pool" the model by simply taking the hidden state corresponding
576
+ # to the first token.
577
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
578
+ pooled_output = self.dense(first_token_tensor)
579
+ pooled_output = self.activation(pooled_output)
580
+ return pooled_output
581
+
582
+
583
+ class BertPredictionHeadTransform(nn.Module):
584
+
585
+ def __init__(self, config):
586
+ super().__init__()
587
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
588
+ if isinstance(config.hidden_act, str):
589
+ self.transform_act_fn = ACT2FN[config.hidden_act]
590
+ else:
591
+ self.transform_act_fn = config.hidden_act
592
+ self.norm = RMSNorm(config.hidden_size, eps=1e-12)
593
+
594
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
595
+ hidden_states = self.dense(hidden_states)
596
+ hidden_states = self.transform_act_fn(hidden_states)
597
+ hidden_states = self.norm(hidden_states)
598
+ return hidden_states
599
+
600
+
601
+ class BertModel(BertPreTrainedModel):
602
+ """Overall BERT model.
603
+
604
+ Args:
605
+ config: a BertConfig class instance with the configuration to build a new model
606
+
607
+ Inputs:
608
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
609
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
610
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
611
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
612
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
613
+ a `sentence B` token (see BERT paper for more details).
614
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
615
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
616
+ input sequence length in the current batch. It's the mask that we typically use for attention when
617
+ a batch has varying length sentences.
618
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
619
+
620
+ Outputs: Tuple of (encoded_layers, pooled_output)
621
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
622
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
623
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
624
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
625
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
626
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
627
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
628
+ classifier pretrained on top of the hidden state associated to the first character of the
629
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
630
+
631
+ Example usage:
632
+ ```python
633
+ # Already been converted into WordPiece token ids
634
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
635
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
636
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
637
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
638
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
639
+ model = BertModel(config=config)
640
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
641
+ ```
642
+ """
643
+
644
+ def __init__(self, config, add_pooling_layer=True):
645
+ super(BertModel, self).__init__(config)
646
+ self.embeddings = BertEmbeddings(config)
647
+ self.encoder = BertEncoder(config)
648
+ self.pooler = BertPooler(config) if add_pooling_layer else None
649
+ self.post_init()
650
+
651
+ def get_input_embeddings(self):
652
+ return self.embeddings.word_embeddings
653
+
654
+ def set_input_embeddings(self, value):
655
+ self.embeddings.word_embeddings = value
656
+
657
+ def forward(
658
+ self,
659
+ input_ids: torch.Tensor,
660
+ token_type_ids: Optional[torch.Tensor] = None,
661
+ attention_mask: Optional[torch.Tensor] = None,
662
+ position_ids: Optional[torch.Tensor] = None,
663
+ output_all_encoded_layers: Optional[bool] = False,
664
+ masked_tokens_mask: Optional[torch.Tensor] = None,
665
+ **kwargs
666
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
667
+ if attention_mask is None:
668
+ attention_mask = torch.ones_like(input_ids)
669
+ if token_type_ids is None:
670
+ token_type_ids = torch.zeros_like(input_ids)
671
+
672
+ embedding_output = self.embeddings(input_ids, token_type_ids,
673
+ position_ids)
674
+
675
+ subset_mask = []
676
+ first_col_mask = []
677
+
678
+ if masked_tokens_mask is None:
679
+ subset_mask = None
680
+ else:
681
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
682
+ first_col_mask[:, 0] = True
683
+ subset_mask = masked_tokens_mask | first_col_mask
684
+
685
+ encoder_outputs = self.encoder(
686
+ embedding_output,
687
+ attention_mask,
688
+ output_all_encoded_layers=output_all_encoded_layers,
689
+ subset_mask=subset_mask)
690
+
691
+ if masked_tokens_mask is None:
692
+ sequence_output = encoder_outputs[-1]
693
+ pooled_output = self.pooler(
694
+ sequence_output) if self.pooler is not None else None
695
+ else:
696
+ # TD [2022-03-01]: the indexing here is very tricky.
697
+ attention_mask_bool = attention_mask.bool()
698
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
699
+ sequence_output = encoder_outputs[-1][
700
+ masked_tokens_mask[attention_mask_bool][subset_idx]]
701
+ if self.pooler is not None:
702
+ pool_input = encoder_outputs[-1][
703
+ first_col_mask[attention_mask_bool][subset_idx]]
704
+ pooled_output = self.pooler(pool_input, pool=False)
705
+ else:
706
+ pooled_output = None
707
+
708
+ if not output_all_encoded_layers:
709
+ encoder_outputs = sequence_output
710
+
711
+ if self.pooler is not None:
712
+ return encoder_outputs, pooled_output
713
+
714
+ return encoder_outputs, None
715
+
716
+
717
+ ###################
718
+ # Bert Heads
719
+ ###################
720
+ class BertLMPredictionHead(nn.Module):
721
+
722
+ def __init__(self, config, bert_model_embedding_weights):
723
+ super().__init__()
724
+ self.transform = BertPredictionHeadTransform(config)
725
+ # The output weights are the same as the input embeddings, but there is
726
+ # an output-only bias for each token.
727
+ self.weight = nn.Parameter(torch.empty((bert_model_embedding_weights.size(0), bert_model_embedding_weights.size(1))))
728
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
729
+ self.first_flag = True
730
+
731
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
732
+ hidden_states = self.transform(hidden_states)
733
+ if self.training:
734
+ norm_weight = nn.functional.normalize(self.weight)
735
+ self.first_flag = True
736
+ elif self.first_flag:
737
+ self.first_flag = False
738
+ self.weight.data = nn.functional.normalize(self.weight)
739
+ norm_weight = self.weight
740
+ else:
741
+ norm_weight = self.weight
742
+ return nn.functional.linear(hidden_states, norm_weight)
743
+
744
+
745
+ class BertOnlyMLMHead(nn.Module):
746
+
747
+ def __init__(self, config, bert_model_embedding_weights):
748
+ super().__init__()
749
+ self.predictions = BertLMPredictionHead(config,
750
+ bert_model_embedding_weights)
751
+
752
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
753
+ prediction_scores = self.predictions(sequence_output)
754
+ return prediction_scores
755
+
756
+
757
+ class BertOnlyNSPHead(nn.Module):
758
+
759
+ def __init__(self, config):
760
+ super().__init__()
761
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
762
+
763
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
764
+ seq_relationship_score = self.seq_relationship(pooled_output)
765
+ return seq_relationship_score
766
+
767
+
768
+ #####################
769
+ # Various Bert models
770
+ #####################
771
+
772
+
773
+ class BertForPreTraining(BertPreTrainedModel):
774
+ #TBD: Coming in Future Commit
775
+ pass
776
+
777
+
778
+ class BertLMHeadModel(BertPreTrainedModel):
779
+ #TBD: Coming in Future Commit
780
+ pass
781
+
782
+
783
+ class BertForMaskedLM(BertPreTrainedModel):
784
+ config_class = BertConfig
785
+ def __init__(self, config):
786
+ super().__init__(config)
787
+
788
+ if config.is_decoder:
789
+ warnings.warn(
790
+ 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
791
+ 'bi-directional self-attention.')
792
+ self.config = config
793
+ self.bert = BertModel(config, add_pooling_layer=False)
794
+ self.cls = BertOnlyMLMHead(config,
795
+ self.bert.embeddings.word_embeddings.weight)
796
+
797
+ # Initialize weights and apply final processing
798
+ self.post_init()
799
+
800
+ @classmethod
801
+ def from_composer(cls,
802
+ pretrained_checkpoint,
803
+ state_dict=None,
804
+ cache_dir=None,
805
+ from_tf=False,
806
+ config=None,
807
+ *inputs,
808
+ **kwargs):
809
+ """Load from pre-trained."""
810
+ model = cls(config, *inputs, **kwargs)
811
+ if from_tf:
812
+ raise ValueError(
813
+ 'Mosaic BERT does not support loading TensorFlow weights.')
814
+
815
+ state_dict = torch.load(pretrained_checkpoint)
816
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
817
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
818
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
819
+ strict=False)
820
+
821
+ if len(missing_keys) > 0:
822
+ logger.warning(
823
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
824
+ )
825
+ if len(unexpected_keys) > 0:
826
+ logger.warning(
827
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
828
+ )
829
+
830
+ return model
831
+
832
+ def get_output_embeddings(self):
833
+ return self.cls.predictions.weight
834
+
835
+ def set_output_embeddings(self, new_embeddings):
836
+ self.cls.predictions.weight = new_embeddings
837
+
838
+ def forward(
839
+ self,
840
+ input_ids: Optional[torch.Tensor] = None,
841
+ attention_mask: Optional[torch.Tensor] = None,
842
+ token_type_ids: Optional[torch.Tensor] = None,
843
+ position_ids: Optional[torch.Tensor] = None,
844
+ head_mask: Optional[torch.Tensor] = None,
845
+ inputs_embeds: Optional[torch.Tensor] = None,
846
+ encoder_hidden_states: Optional[torch.Tensor] = None,
847
+ encoder_attention_mask: Optional[torch.Tensor] = None,
848
+ labels: Optional[torch.Tensor] = None,
849
+ output_attentions: Optional[bool] = None,
850
+ output_hidden_states: Optional[bool] = None,
851
+ return_dict: Optional[bool] = None,
852
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
853
+ # labels should be a `torch.LongTensor` of shape
854
+ # `(batch_size, sequence_length)`. These are used for computing the
855
+ # masked language modeling loss.
856
+ #
857
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
858
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
859
+ # (masked), the loss is only computed for the tokens with labels in `[0,
860
+ # ..., config.vocab_size]`
861
+ #
862
+ # Prediction scores are only computed for masked tokens and the (bs,
863
+ # seqlen) dimensions are flattened
864
+ if (input_ids is not None) == (inputs_embeds is not None):
865
+ raise ValueError('Must specify either input_ids or input_embeds!')
866
+
867
+ if labels is None:
868
+ masked_tokens_mask = None
869
+ else:
870
+ masked_tokens_mask = labels > 0
871
+
872
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
873
+
874
+ outputs = self.bert(
875
+ input_ids,
876
+ attention_mask=attention_mask,
877
+ token_type_ids=token_type_ids,
878
+ position_ids=position_ids,
879
+ head_mask=head_mask,
880
+ inputs_embeds=inputs_embeds,
881
+ encoder_hidden_states=encoder_hidden_states,
882
+ encoder_attention_mask=encoder_attention_mask,
883
+ output_attentions=output_attentions,
884
+ output_hidden_states=output_hidden_states,
885
+ return_dict=return_dict,
886
+ masked_tokens_mask=masked_tokens_mask,
887
+ )
888
+
889
+ sequence_output = outputs[0]
890
+ prediction_scores = self.cls(sequence_output)
891
+
892
+ loss = None
893
+ if labels is not None:
894
+ # Compute loss
895
+ loss_fct = nn.CrossEntropyLoss()
896
+ softmax_normalizer = prediction_scores.max(-1).values ** 2
897
+ z_loss_weight = 0.2
898
+ z_loss = z_loss_weight * softmax_normalizer.mean()
899
+ # Enable model parallelism
900
+ masked_token_idx = torch.nonzero(labels.flatten() > 0,
901
+ as_tuple=False).flatten()
902
+
903
+ loss = loss_fct(prediction_scores,
904
+ labels.flatten()[masked_token_idx]) + z_loss
905
+ assert input_ids is not None, 'Coding error; please open an issue'
906
+ batch, seqlen = input_ids.shape[:2]
907
+ prediction_scores = rearrange(
908
+ index_put_first_axis(
909
+ prediction_scores, masked_token_idx, batch * seqlen),
910
+ '(b s) d -> b s d',
911
+ b=batch)
912
+
913
+ if not return_dict:
914
+ output = (prediction_scores,) + outputs[2:]
915
+ return ((loss,) + output) if loss is not None else output
916
+
917
+ return MaskedLMOutput(
918
+ loss=loss,
919
+ logits=prediction_scores,
920
+ hidden_states=None,
921
+ attentions=None,
922
+ )
923
+
924
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
925
+ attention_mask: torch.Tensor,
926
+ **model_kwargs):
927
+ input_shape = input_ids.shape
928
+ effective_batch_size = input_shape[0]
929
+
930
+ # add a dummy token
931
+ if self.config.pad_token_id is None:
932
+ raise ValueError('The PAD token should be defined for generation')
933
+
934
+ attention_mask = torch.cat([
935
+ attention_mask,
936
+ attention_mask.new_zeros((attention_mask.shape[0], 1))
937
+ ],
938
+ dim=-1)
939
+ dummy_token = torch.full((effective_batch_size, 1),
940
+ self.config.pad_token_id,
941
+ dtype=torch.long,
942
+ device=input_ids.device)
943
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
944
+
945
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
946
+
947
+
948
+ class BertForNextSentencePrediction(BertPreTrainedModel):
949
+ #TBD: Push in future commit
950
+ pass
951
+
952
+
953
+ class BertForSequenceClassification(BertPreTrainedModel):
954
+ """Bert Model transformer with a sequence classification/regression head.
955
+
956
+ This head is just a linear layer on top of the pooled output. Used for,
957
+ e.g., GLUE tasks.
958
+ """
959
+ config_class = BertConfig
960
+ def __init__(self, config):
961
+ super().__init__(config)
962
+ self.num_labels = config.num_labels
963
+ self.config = config
964
+
965
+ self.bert = BertModel(config)
966
+ classifier_dropout = (config.classifier_dropout
967
+ if config.classifier_dropout is not None else
968
+ config.hidden_dropout_prob)
969
+ self.dropout = nn.Dropout(classifier_dropout)
970
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
971
+
972
+ # Initialize weights and apply final processing
973
+ self.post_init()
974
+
975
+ @classmethod
976
+ def from_composer(cls,
977
+ pretrained_checkpoint,
978
+ state_dict=None,
979
+ cache_dir=None,
980
+ from_tf=False,
981
+ config=None,
982
+ *inputs,
983
+ **kwargs):
984
+ """Load from pre-trained."""
985
+ model = cls(config, *inputs, **kwargs)
986
+ if from_tf:
987
+ raise ValueError(
988
+ 'Mosaic BERT does not support loading TensorFlow weights.')
989
+
990
+ state_dict = torch.load(pretrained_checkpoint)
991
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
992
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
993
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
994
+ strict=False)
995
+
996
+ if len(missing_keys) > 0:
997
+ logger.warning(
998
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
999
+ )
1000
+ if len(unexpected_keys) > 0:
1001
+ logger.warning(
1002
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
1003
+ )
1004
+
1005
+ return model
1006
+
1007
+ def forward(
1008
+ self,
1009
+ input_ids: Optional[torch.Tensor] = None,
1010
+ attention_mask: Optional[torch.Tensor] = None,
1011
+ token_type_ids: Optional[torch.Tensor] = None,
1012
+ position_ids: Optional[torch.Tensor] = None,
1013
+ head_mask: Optional[torch.Tensor] = None,
1014
+ inputs_embeds: Optional[torch.Tensor] = None,
1015
+ labels: Optional[torch.Tensor] = None,
1016
+ output_attentions: Optional[bool] = None,
1017
+ output_hidden_states: Optional[bool] = None,
1018
+ return_dict: Optional[bool] = None,
1019
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1020
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1021
+ # Labels for computing the sequence classification/regression loss.
1022
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
1023
+ # If `config.num_labels == 1` a regression loss is computed
1024
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
1025
+ # is computed (cross-entropy).
1026
+
1027
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1028
+
1029
+ outputs = self.bert(
1030
+ input_ids,
1031
+ attention_mask=attention_mask,
1032
+ token_type_ids=token_type_ids,
1033
+ position_ids=position_ids,
1034
+ head_mask=head_mask,
1035
+ inputs_embeds=inputs_embeds,
1036
+ output_attentions=output_attentions,
1037
+ output_hidden_states=output_hidden_states,
1038
+ return_dict=return_dict,
1039
+ )
1040
+
1041
+ pooled_output = outputs[1]
1042
+
1043
+ pooled_output = self.dropout(pooled_output)
1044
+ logits = self.classifier(pooled_output)
1045
+
1046
+ loss = None
1047
+ if labels is not None:
1048
+ # Compute loss
1049
+ if self.config.problem_type is None:
1050
+ if self.num_labels == 1:
1051
+ self.config.problem_type = 'regression'
1052
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
1053
+ labels.dtype == torch.int):
1054
+ self.config.problem_type = 'single_label_classification'
1055
+ else:
1056
+ self.config.problem_type = 'multi_label_classification'
1057
+
1058
+ if self.config.problem_type == 'regression':
1059
+ loss_fct = nn.MSELoss()
1060
+ if self.num_labels == 1:
1061
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1062
+ else:
1063
+ loss = loss_fct(logits, labels)
1064
+ elif self.config.problem_type == 'single_label_classification':
1065
+ loss_fct = nn.CrossEntropyLoss()
1066
+ loss = loss_fct(logits.view(-1, self.num_labels),
1067
+ labels.view(-1))
1068
+ elif self.config.problem_type == 'multi_label_classification':
1069
+ loss_fct = nn.BCEWithLogitsLoss()
1070
+ loss = loss_fct(logits, labels)
1071
+
1072
+ if not return_dict:
1073
+ output = (logits,) + outputs[2:]
1074
+ return ((loss,) + output) if loss is not None else output
1075
+
1076
+ return SequenceClassifierOutput(
1077
+ loss=loss,
1078
+ logits=logits,
1079
+ hidden_states=None,
1080
+ attentions=None,
1081
+ )
1082
+
1083
+
1084
+ class BertForMultipleChoice(BertPreTrainedModel):
1085
+ #TBD: Push in future commit
1086
+ pass
1087
+
1088
+
1089
+ class BertForTokenClassification(BertPreTrainedModel):
1090
+ #TBD: Push in future commit
1091
+ pass
1092
+
1093
+
1094
+ class BertForQuestionAnswering(BertPreTrainedModel):
1095
+ """Bert Model with a span classification head.
1096
+
1097
+ This is used for extractive question-answering tasks like SQuAD (a linear
1098
+ layers on top of the hidden states' output to compute `span start logits`
1099
+ and `span end logits`).
1100
+ """
1101
+ #TBD: Push in future commit
bert_padding.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
5
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
6
+
7
+ """Helper functions for padding and unpadding batches.
8
+
9
+ These functions are used extensively throughout the Mosaic BERT implementation
10
+ in `bert_layers.py`.
11
+ """
12
+
13
+ from typing import Tuple, cast
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from einops import rearrange, repeat
18
+
19
+
20
+ class IndexFirstAxis(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(ctx, input: torch.Tensor,
24
+ indices: torch.Tensor) -> torch.Tensor:
25
+ """Get just the values of `input` which are at `indices`.
26
+
27
+ Arguments:
28
+ ctx: the autograd context object
29
+ input: (b, ...) 2+ dimensional tensor
30
+ indices: (num_idx) 1D tensor
31
+ """
32
+ ctx.save_for_backward(indices)
33
+ assert input.ndim >= 2
34
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
35
+ 1:] # type: ignore
36
+ second_dim = other_shape.numel(
37
+ ) # product of sizes of all but first dimension
38
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
39
+ return torch.gather(
40
+ rearrange(input, 'b ... -> b (...)'), # (b, ...) -> (b, second_dim)
41
+ 0,
42
+ repeat(indices, 'z -> z d',
43
+ d=second_dim) # (indices,) -> (indices, second_dim)
44
+ ).reshape(-1, *other_shape) # (num_idx, ...)
45
+
46
+ @staticmethod
47
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
48
+ indices, = ctx.saved_tensors
49
+ assert grad_output.ndim >= 2
50
+ other_shape = grad_output.shape[1:]
51
+ grad_output = rearrange(grad_output, 'b ... -> b (...)')
52
+ grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
53
+ device=grad_output.device,
54
+ dtype=grad_output.dtype)
55
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
56
+ # grad_input[indices] = grad_output
57
+ grad_input.scatter_(0,
58
+ repeat(indices, 'z -> z d', d=grad_output.shape[1]),
59
+ grad_output)
60
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
61
+
62
+
63
+ index_first_axis = IndexFirstAxis.apply
64
+
65
+
66
+ class IndexPutFirstAxis(torch.autograd.Function):
67
+
68
+ @staticmethod
69
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
70
+ first_axis_dim) -> torch.Tensor:
71
+ ctx.save_for_backward(indices)
72
+ assert indices.ndim == 1
73
+ assert values.ndim >= 2
74
+ output = torch.zeros(first_axis_dim,
75
+ *values.shape[1:],
76
+ device=values.device,
77
+ dtype=values.dtype)
78
+ output[indices] = values
79
+ return output
80
+
81
+ @staticmethod
82
+ def backward(ctx,
83
+ grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
84
+ indices, = ctx.saved_tensors
85
+ grad_values = grad_output[indices]
86
+ return grad_values, None, None
87
+
88
+
89
+ index_put_first_axis = IndexPutFirstAxis.apply
90
+
91
+
92
+ def unpad_input(
93
+ hidden_states: torch.Tensor,
94
+ attention_mask: torch.Tensor,
95
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
96
+ """Remove padding from input sequences.
97
+
98
+ Arguments:
99
+ hidden_states: (batch, seqlen, ...)
100
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
101
+
102
+ Returns:
103
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
104
+ indices: (total_nnz)
105
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
106
+ max_seqlen_in_batch: int ()
107
+ """
108
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
109
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
110
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
111
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
112
+ (1, 0))
113
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
114
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
115
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
116
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
117
+ # so we write custom forward and backward to make it a bit faster.
118
+ hidden_states = cast(
119
+ torch.Tensor,
120
+ index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
121
+ indices))
122
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
123
+
124
+
125
+ def unpad_input_only(
126
+ hidden_states: torch.Tensor,
127
+ attention_mask: torch.Tensor,
128
+ ) -> torch.Tensor:
129
+ """Like unpad_input, but only return the unpadded first tensor.
130
+
131
+ Save a small amount of overhead.
132
+
133
+ Arguments:
134
+ hidden_states: (batch, seqlen, ...)
135
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
136
+
137
+ Returns:
138
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
139
+ """
140
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
141
+ return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
142
+ indices)
143
+
144
+
145
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
146
+ seqlen: int) -> torch.Tensor:
147
+ """Add padding to sequences.
148
+
149
+ Arguments:
150
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
151
+ indices: (total_nnz)
152
+ batch: int batch_size
153
+ seqlen: int max sequence length
154
+
155
+ Returns:
156
+ hidden_states: (batch, seqlen, ...)
157
+ """
158
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
159
+ return rearrange(output, '(b s) ... -> b s ...', b=batch)
config.json ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "output_dir_XTBert",
3
+ "alibi_starting_size": 512,
4
+ "architectures": [
5
+ "BertForSequenceClassification"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_bert.BertConfig",
10
+ "AutoModelForMaskedLM": "xiaotinghe/XTBert--bert_layers.BertForMaskedLM",
11
+ "AutoModelForSequenceClassification": "bert_layers.BertForSequenceClassification"
12
+ },
13
+ "bos_token_id": 0,
14
+ "classifier_dropout": null,
15
+ "directionality": "bidi",
16
+ "eos_token_id": 2,
17
+ "gradient_checkpointing": false,
18
+ "hidden_act": "silu",
19
+ "hidden_dropout_prob": 0.1,
20
+ "hidden_size": 768,
21
+ "id2label": {
22
+ "0": "academic disciplines",
23
+ "1": "business",
24
+ "2": "code",
25
+ "3": "communication",
26
+ "4": "culture",
27
+ "5": "economy",
28
+ "6": "education",
29
+ "7": "energy",
30
+ "8": "engineering",
31
+ "9": "entertainment",
32
+ "10": "food and drink",
33
+ "11": "geography",
34
+ "12": "government",
35
+ "13": "history",
36
+ "14": "human behavior",
37
+ "15": "humanities",
38
+ "16": "information",
39
+ "17": "internet",
40
+ "18": "knowledge",
41
+ "19": "language",
42
+ "20": "law",
43
+ "21": "life health",
44
+ "22": "mass media",
45
+ "23": "mathematics",
46
+ "24": "military",
47
+ "25": "nature",
48
+ "26": "people",
49
+ "27": "philosophy",
50
+ "28": "politics",
51
+ "29": "religion",
52
+ "30": "science",
53
+ "31": "society",
54
+ "32": "sports",
55
+ "33": "time"
56
+ },
57
+ "initializer_range": 0.02,
58
+ "intermediate_size": 2048,
59
+ "label2id": {
60
+ "academic disciplines": 0,
61
+ "business": 1,
62
+ "code": 2,
63
+ "communication": 3,
64
+ "culture": 4,
65
+ "economy": 5,
66
+ "education": 6,
67
+ "energy": 7,
68
+ "engineering": 8,
69
+ "entertainment": 9,
70
+ "food and drink": 10,
71
+ "geography": 11,
72
+ "government": 12,
73
+ "history": 13,
74
+ "human behavior": 14,
75
+ "humanities": 15,
76
+ "information": 16,
77
+ "internet": 17,
78
+ "knowledge": 18,
79
+ "language": 19,
80
+ "law": 20,
81
+ "life health": 21,
82
+ "mass media": 22,
83
+ "mathematics": 23,
84
+ "military": 24,
85
+ "nature": 25,
86
+ "people": 26,
87
+ "philosophy": 27,
88
+ "politics": 28,
89
+ "religion": 29,
90
+ "science": 30,
91
+ "society": 31,
92
+ "sports": 32,
93
+ "time": 33
94
+ },
95
+ "layer_norm_eps": 1e-12,
96
+ "max_position_embeddings": 4096,
97
+ "model_type": "bert",
98
+ "num_attention_heads": 12,
99
+ "num_hidden_layers": 12,
100
+ "output_past": true,
101
+ "pad_token_id": 1,
102
+ "pooler_fc_size": 768,
103
+ "pooler_num_attention_heads": 12,
104
+ "pooler_num_fc_layers": 3,
105
+ "pooler_size_per_head": 128,
106
+ "pooler_type": "first_token_transform",
107
+ "position_embedding_type": "absolute",
108
+ "problem_type": "single_label_classification",
109
+ "torch_dtype": "float32",
110
+ "transformers_version": "4.33.2",
111
+ "type_vocab_size": 2,
112
+ "use_cache": true,
113
+ "vocab_size": 39984
114
+ }
configuration_bert.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from transformers import BertConfig as TransformersBertConfig
5
+
6
+
7
+ class BertConfig(TransformersBertConfig):
8
+
9
+ def __init__(
10
+ self,
11
+ alibi_starting_size: int = 512,
12
+ attention_probs_dropout_prob: float = 0.0,
13
+ **kwargs,
14
+ ):
15
+ """Configuration class for MosaicBert.
16
+
17
+ Args:
18
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
19
+ create when initializing the model. You should be able to ignore this parameter in most cases.
20
+ Defaults to 512.
21
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
22
+ (otherwise, Flash Attention will be off by default). Defaults to 0.0.
23
+ """
24
+ super().__init__(
25
+ attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
26
+ self.alibi_starting_size = alibi_starting_size
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:715241907dbdcbd5783080f5e62ef3fda5985b3883b051016d2154f0b71843a5
3
+ size 465304470