BERT-From-Skratch
Model Description
This model is a language model built on a simple LSTM architecture and trained on a cleaned subset of the Wikitext-103 dataset. It is intended to generate coherent, meaningful text, but it does not yet do so reliably; we plan to improve the architecture in future iterations.
Model Architecture
- Embedding Dimension: 256
- RNN Units: 1024

Layers:
- Embedding Layer
- LSTM Layer
- Dense Layer
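The architecture code below assumes a handful of configuration constants. A minimal sketch of plausible definitions is shown here; the vocabulary size and the checkpoint path are assumptions, not values stated in this card.

# Hyperparameters stated in this card
EMBEDDING_DIM = 256
RNN_UNITS = 1024
BATCH_SIZE = int(64 * 2.1)  # 134, as listed under Training Details

# Assumption: vocabulary size of the bert-base-uncased tokenizer used for tokenization
VOCAB_SIZE = 30522

# Assumption: illustrative checkpoint path, not the author's actual one
CHECKPOINT_PREFIX = './checkpoints/ckpt_{epoch}'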
Architecture Code
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Input

def create_model(batch_size):
    # Mirror the model across all available GPUs
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = Sequential([
            Input(batch_shape=[batch_size, None], name='input_ids'),
            Embedding(VOCAB_SIZE, EMBEDDING_DIM),
            LSTM(RNN_UNITS, return_sequences=True, recurrent_initializer='glorot_uniform'),
            # Keep the output layer in float32 so the logits stay numerically stable under mixed precision
            Dense(VOCAB_SIZE, dtype='float32')
        ])

        def loss(labels, logits):
            return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

        model.compile(optimizer='adam', loss=loss)

    # Save the model weights during training
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=CHECKPOINT_PREFIX,
        save_weights_only=True
    )
    return model, checkpoint_callback

model, checkpoint_callback = create_model(BATCH_SIZE)
Training Data
The model was trained on the Wikitext-103 dataset, which contains a variety of text from Wikipedia. A subset of 14,000 entries was selected and further cleaned to remove non-ASCII characters, unwanted symbols, and unnecessary spaces.
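A minimal sketch of how such a subset could be pulled in with the Hugging Face datasets library; the wikitext-103-raw-v1 configuration and the way the 14,000-entry slice is taken are assumptions, not the author's exact code.

from datasets import load_dataset

# Assumption: the raw Wikitext-103 configuration from the Hugging Face Hub
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')

# Keep non-empty entries and take a 14,000-entry subset, as described above
texts = [t for t in dataset['text'] if t.strip()][:14000]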
Preprocessing
The preprocessing steps, sketched in code after the list, included:
- Removing non-ASCII characters.
- Filtering out empty and very short entries.
- Removing unwanted characters such as those in brackets, parentheses, and curly braces.
- Removing consecutive spaces and fixing spaces around punctuation marks.
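A rough sketch of these cleaning steps using Python's re module, continuing from the texts list in the previous sketch. The exact patterns and the minimum-length threshold are assumptions, and clean_text is a hypothetical helper, not the author's code.

import re

def clean_text(text):
    # Remove non-ASCII characters
    text = text.encode('ascii', errors='ignore').decode()
    # Remove content in brackets, parentheses, and curly braces
    text = re.sub(r'\[[^\]]*\]|\([^)]*\)|\{[^}]*\}', '', text)
    # Collapse consecutive spaces and fix spaces before punctuation
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\s+([.,;:!?])', r'\1', text)
    return text.strip()

# Filter out empty and very short entries (the length threshold is assumed)
cleaned = [clean_text(t) for t in texts]
cleaned = [t for t in cleaned if len(t) > 20]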
Tokenization
The BERT tokenizer from the transformers library was used for tokenization. It supports subword tokenization, allowing the model to handle out-of-vocabulary words effectively.
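A minimal sketch of the tokenization step, assuming the bert-base-uncased checkpoint (the card does not specify which BERT tokenizer was used) and the cleaned list from the preprocessing sketch.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Encode each cleaned entry into token IDs; out-of-vocabulary words are
# split into known subword pieces by WordPiece
token_ids = [tokenizer.encode(t, add_special_tokens=False) for t in cleaned]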
Training Details
- Optimizer: Adam
- Loss Function: Sparse Categorical Crossentropy with logits
- Batch Size: 134 (computed as int(64 * 2.1))
- Epochs: 3 (a debugging run; an improved version will be uploaded soon)
- Mixed Precision Training: Enabled (see the setup sketch after this list)
- Callbacks: Model checkpointing to save the model weights during training
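A minimal sketch of how these settings fit together, assuming the TensorFlow mixed-precision API and a tf.data pipeline named train_dataset (hypothetical; the card does not show the actual input pipeline).

import tensorflow as tf

# Enable mixed precision before building the model; the float32 Dense output
# layer above keeps the logits in full precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')

model, checkpoint_callback = create_model(BATCH_SIZE)

# train_dataset is assumed to yield (input_ids, target_ids) batches of shape
# [BATCH_SIZE, sequence_length], with targets shifted one token to the right
model.fit(
    train_dataset,
    epochs=3,
    callbacks=[checkpoint_callback]
)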
Limitations
While the model can generate basic text, its output is not yet coherent. We plan to improve the architecture in future iterations.
Acknowledgments
This model was developed using the Hugging Face transformers library and trained on the Wikitext-103 dataset.