# BERT-From-Skratch

## Model Description

This is a language model based on a simple LSTM architecture, trained on a cleaned subset of the WikiText-103 dataset. It is intended to generate coherent and meaningful text, but in its current state its output is largely incoherent. We plan to improve the architecture to enhance its performance in future iterations.

## Model Architecture

- Embedding Dimension: 256
- RNN Units: 1024
- Layers:
  - Embedding Layer
  - LSTM Layer
  - Dense Layer

### Architecture Code

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Input

# VOCAB_SIZE, EMBEDDING_DIM, RNN_UNITS, CHECKPOINT_PREFIX, and BATCH_SIZE are
# defined earlier in the training script (EMBEDDING_DIM = 256, RNN_UNITS = 1024).

def create_model(batch_size):
    # Distribute training across all available GPUs.
    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
        model = Sequential([
            Input(batch_shape=[batch_size, None], name='input_ids'),
            Embedding(VOCAB_SIZE, EMBEDDING_DIM),
            LSTM(RNN_UNITS, return_sequences=True, recurrent_initializer='glorot_uniform'),
            # Keep the output layer in float32 so the logits stay numerically
            # stable under mixed-precision training.
            Dense(VOCAB_SIZE, dtype='float32')
        ])

        # Token-level cross-entropy computed on raw logits.
        def loss(labels, logits):
            return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

        model.compile(optimizer='adam', loss=loss)

    # Save model weights during training.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=CHECKPOINT_PREFIX,
        save_weights_only=True
    )
    return model, checkpoint_callback

model, checkpoint_callback = create_model(BATCH_SIZE)
```

## Training Data

The model was trained on the WikiText-103 dataset, which contains a variety of text from Wikipedia. A subset of 14,000 entries was selected and further cleaned to remove non-ASCII characters, unwanted symbols, and unnecessary spaces.
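The card does not say how the subset was obtained; a minimal sketch using the Hugging Face `datasets` library (the split name, configuration, and slicing below are assumptions) could look like this:

```python
from datasets import load_dataset

# Assumed loading path: the card does not state which configuration or split
# was used, nor how the 14,000 entries were chosen.
wikitext = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')

subset = wikitext.select(range(14_000))           # first 14,000 entries (illustrative)
raw_entries = [row['text'] for row in subset]     # raw text to be cleaned below
```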

## Preprocessing

The preprocessing steps included the following (a rough sketch is shown after the list):

- Removing non-ASCII characters.
- Filtering out empty and very short entries.
- Removing unwanted characters such as those in brackets, parentheses, and curly braces.
- Collapsing consecutive spaces and fixing spaces around punctuation marks.
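The actual cleaning code is not included in this card; the sketch below is one way these steps could be implemented. The `raw_entries` list (the raw texts, e.g. from the loading sketch above) and the minimum-length threshold are assumptions.

```python
import re

def clean_text(text: str) -> str:
    """Hypothetical cleaning routine mirroring the steps listed above."""
    text = text.encode('ascii', errors='ignore').decode('ascii')      # drop non-ASCII characters
    text = re.sub(r'\[[^\]]*\]|\([^)]*\)|\{[^}]*\}', ' ', text)       # strip bracketed, parenthesised, and braced content
    text = re.sub(r'\s+([.,;:!?])', r'\1', text)                      # remove spaces before punctuation
    text = re.sub(r'\s+', ' ', text).strip()                          # collapse consecutive spaces
    return text

cleaned = [clean_text(entry) for entry in raw_entries]
cleaned = [entry for entry in cleaned if len(entry.split()) >= 5]     # drop empty and very short entries
```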

## Tokenization

The BERT tokenizer from the transformers library was used for tokenization. It supports subword tokenization, allowing the model to handle out-of-vocabulary words effectively.
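For illustration, the tokenizer can be loaded and applied as below; the specific BERT checkpoint used for training is not stated in this card, so `bert-base-uncased` is an assumption.

```python
from transformers import BertTokenizerFast

# Assumed checkpoint: the card does not specify which BERT vocabulary was used.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Vocabulary size used for the model's embedding and output layers.
VOCAB_SIZE = tokenizer.vocab_size

encoded = tokenizer("Unrecognisable words are split into subword pieces.",
                    add_special_tokens=False)
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))
# Rare words come out as WordPiece fragments (prefixed with '##'),
# so out-of-vocabulary words can still be represented.
```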

## Training Details

- Optimizer: Adam
- Loss Function: Sparse Categorical Crossentropy computed on logits
- Batch Size: 134 (`int(64 * 2.1)`)
- Epochs: 3 (for debugging purposes; an improved version will be uploaded soon)
- Mixed Precision Training: Enabled (see the sketch after this list)
- Callbacks: Model checkpointing to save weights during training
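A minimal sketch of how these settings fit together, assuming `create_model` from the architecture code above and a `train_dataset` of token-id batches (the dataset pipeline is not shown in this card):

```python
import tensorflow as tf

BATCH_SIZE = int(64 * 2.1)  # 134
EPOCHS = 3

# Enable mixed precision before the model is built; the final Dense layer is
# kept in float32 (see the architecture code above).
tf.keras.mixed_precision.set_global_policy('mixed_float16')

model, checkpoint_callback = create_model(BATCH_SIZE)

# `train_dataset` is assumed to yield (input_ids, target_ids) batches of size BATCH_SIZE.
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback]
)
```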

## Limitations

While the model performs basic text generation, it currently does not generate coherent text. We plan to improve the architecture to enhance its performance in future iterations.

## Acknowledgments

This model was developed using TensorFlow/Keras, with tokenization provided by the Hugging Face transformers library, and trained on the WikiText-103 dataset.