babyGPT (version 1.1.3, 2025-September-16)

babyGPT.py
 
Version: 1.1.3
   
Author: Avinash Kak (kak@purdue.edu)
 
Date: 2025-September-16
 
 

Download Version 1.1.3:  gztar  

 
     Total number of downloads (all versions): 403
     This count is automatically updated at every rotation of
     the weblogs (normally once every two to four days)
     Last updated: Mon Oct 13 06:01:01 EDT 2025

View the main module code file in your browser  
 
Download the text datasets for babyGPT  
 
 
                [See my Twitter posts for examples of the latest results with babyGPT]  
 
 
 
 
THE QUICKEST WAY TO START USING babyGPT:

    1.  Download the module code archive by clicking on the "gztar" link shown above.
        Unpack and install the archive by following the instruction in the
        INSTALLATION section of this documentation page.
        
    2.  Next, download the training dataset by clicking on the link "Download the
        text datasets for babyGPT" shown above.  See the Section "Training Datasets
        Provided" for further details.
 
    3.  Now enter the Example subdirectory of the distribution and enter the pathname
        to the training data in the script
 
                  create_base_model_with_buffered_context.py
 
        The pathname is the value of the variable 'articles_dir' near the beginning of
        the script.
 
    4.  Finally, execute the script named above.  That's all!
 
 
 
CHANGE LOG:
 
Version 1.1.3:
 
    This version allows you to use gradient accumulation to experiment with longer
    sequence lengths for transformer-based learning.  If you don't wish to use the
    gradient accumulation feature, just set gradient_accumulation_steps to zero in
    the call to the constructor for the MasterDecoderWithMasking class.  Another
    improvement in this version is better documentation regarding why I designed
    babyGPT to be trained with streaming datasets.  You will find this explanation in
    the doc-string associated with the class TokenStreamDataset that is derived from
    the class torch.utils.data.IterableDataset.
 
Version 1.1.2:
 
    There was an error in the code that creates the vocab dictionary for the cleaned
    up tokens. I have fixed the error in this version and also provided new
    tokenizers trained on the athlete-news dataset. These are the base tokenizer with
    a vocab of size 50002 tokens and its cleaned-up version whose vocab size is
    35035. The base tokenizer was trained with the target vocab size set to 50000.
    Both tokenizers are in the Examples directory of the distribution.
 
Version 1.1.1:
 
    I have made further enhancements to the tokenizer training code in this version
    in order to discard superfluous tokens, these being tokens that contribute almost
    nothing to the downstream learning tasks.  For example, if the tokenizer training
    has learned two tokens 'abc' and 'abcd' and both predict exactly the same set of
    corpus words, you can discard one of the two without affecting the overall
    expressive power of all the learned tokens. Additionally, since unsupervised
    learning requires estimating the maximum-likelihood probabilities for the next
    token over all the possibilities at that position, I believe that getting rid of
    superfluous tokens can only reduce the noise in the estimation process. I refer
    to getting rid of such tokens as "cleaning up of the tokenizer vocabulary".  In
    this module, the token clean-up logic is implemented in a new function called
    "post_training_cleanup()" that is defined for the inner class TrainTokenizer of
    babyGPT.  As shown by the example script
    create_base_model_with_buffered_context.py in the Examples directory, I call the
    "post_training_cleanup()" function after I have trained a tokenizer with the
    "train_tokenizer()" function.
 
Version 1.1.0:
 
    There was a module packaging error with 1.0.9.  I have fixed the problem in
    1.1.0.
 
Version 1.0.9:
 
    This version applies filtering to the training corpus for improving both the
    tokenizer and the working of the transformer network for next token prediction.
    Text files downloaded from the internet --- and especially the news media
    articles --- include long URL strings that should play no role in the training of
    a tokenizer or the learning required by the transformer network. (It is not
    uncommon for the URL strings to consist of hundreds of characters.)  Version
    1.0.9 accepts a string for downstream processing only if it is shorter than 20
    characters. I chose 20 because (according to Google) the average word length in
    English is only 5 and "virtually all, very close to 100%, of English words have
    fewer than 20 letters." Words that are longer than 20 characters tend to be
    mostly scientific or technical jargon.
 
Version 1.0.8:
 
    This is a PyTorch-Lightning compatible version of babyGPT. It is not so uncommon
    today for a university lab to deploy multiple low-cost GPUs for training a
    network. Using more than one GPU requires refactoring your code so that it
    conforms to the Lightning API.  Version 1.0.8 is an attempt in that direction. In
    addition to code reorganization, I have also made other minor changes to make the
    code more efficient.  For example, I eliminated a not-really-needed inner-loop in
    the overall training loop for the transformer network.  [IMPORTANT: You can still
    use this version for single-GPU based training.  The code will automatically
    detect the number of GPUs available and proceed accordingly.]  Finally, note that
    I have tested Lightning based execution of the code with only the DDP
    (Distributed Data Parallel) strategy for multi-GPU processing.  With this
    strategy, the computational graph created for the model has to fit in each GPU.
    So you cannot construct a larger model just because you are using Lightning for
    multi-GPU support.  All that you get (with the DDP strategy) is that the learning
    process will digest more data faster.  For example, if you are using a 2-GPU VM,
    your effective batch size will double because the two GPUs will be consuming the
    batches in parallel.
 
Version 1.0.7:
 
    There was an error in the definition of BasicDecoderWithMasking that I have fixed
    in this version.  Despite the error, the module worked as intended but not as
    efficiently as one would have expected.
 
Version 1.0.6:
 
    I have fixed the error that caused the predicted tokens to be shifted by one
    position vis-a-vis the ground-truth tokens.
 
Version 1.0.5:
    
    Had a URL error in the setup.py of the previous version. The rest of the module
    remains unchanged.
 
Version 1.0.4:
 
    This is the first public release version of the module. This module was created
    for the Deep Learning class at Purdue University.
 
 
INTRODUCTION:
 
    SPECIFIC GOALS FOR THIS MODULE:
 
    1) To introduce the students in Purdue's Deep Learning class to the foundational
       concepts in how to create a Base Language Model through self-supervised
       learning.  Large Language Models start out as Base Models that are
       subsequently fine-tuned with reinforcement learning.  The focus of this module
       is solely on Base Modeling.
 
    2) To demonstrate small-scale large-language modeling (SS-LLM) that, for
       educational purposes, can be run in the hardware available in a typical
       university lab.
 
    3) To create a self-contained module that, given a set of media URLs, will
       download the articles from those websites (assuming they are not behind
       paywalls), train a BPE tokenizer for the corpus of the articles collected,
       create a Base Model from the corpus, and, subsequently, let you play with the
       model using the prompting script in the module.
 
    My main goal in babyGPT is to demonstrate that, for the purpose of teaching and
    learning, it is possible to create a small-scale end-to-end implementation that
    downloads a corpus of news articles, trains a BPE tokenizer if you need a new one
    for the domain of the corpus you have collected, and, finally, uses the corpus
    for training an autoregressive model for next-token-prediction based on
    unsupervised learning. After you have trained the model, you can test it with the
    prompting script that is included in the Examples directory.
 
 
    LANGUAGE MODELING AND UNSUPERVISED LEARNING:
 
    There is no denying the fact that the recent advances in chatbots have set the
    world on fire. It's truly amazing to see a chatbot returning (most of the time) a
    smooth-reading well-structured narrative in response to a prompt. As if that were
    not enough, it can also supply you with variants of the same narrative depending
    on how you prompt it and your randomization settings for the bot.
 
    One would think that this degree of competency shown by a chatbot would require a
    vast amount of human annotated data for training the neural networks used for the
    bot.
 
    The truth is exactly the opposite.  Most of the learning that takes place in
    order to train a chatbot is unsupervised --- that is, without any human
    supervision. The bot is given the simplest of the goals: To predict the next
    token given the tokens that have been seen so far.  To master this goal, the bot
    needs zero supervision.  All it needs to do is to use its neural network to make
    a prediction for the next token.  And, at training time, should this prediction
    be wrong, to estimate the error made, and then to backpropagate that error while
    adjusting the learnable weights in the network.  Until not too long ago most
    people would have thought that this type of learning would be much too weak to be
    of any practical use. But, as in all engineering, you cannot argue with something
    that actually works.  One great thing that has come out of AI research of the
    last two decades is that unsupervised learning not only works, it actually lends
    itself to designing powerful data driven frameworks without too much human
    intervention.
 
 
    TRANSFORMERS:
 
    The unsupervised learning of the sort described above is best implemented with
    Transformers. (See my material for the Week 13 lecture at Purdue's Deep Learning
    class for a detailed presentation on how one can implement an English-to-Spanish
    translation framework using Transformers.)  And central to a Transformer-based
    architecture is the notion of Attention.  Attention means the extent to which
    each element at the input to a neural network attends to every other element in
    the same input.  For example, in a network for language translation, the network
    would use Attention to figure out the significance of each token in a
    source-language sentence to every other token in the same sentence.  If "car" was
    one of the tokens in a sentence at the input and a subsequent clause in the same
    sentence used the pronoun "it" that pointed to that car, the network would be
    able to figure out the connection between the "it" and the "car" tokens through
    Attention.  Along the same lines, the network would use Cross Attention to figure
    out the importance of each token in the source language to the different tokens
    in the target language.  As you can imagine, understanding such connections
    between the tokens would be critical to any network that is learning how to
    translate a source language sentence into a target language sentence.
 
 
THE MAJOR COMPONENTS of babyGPT:
 
    babyGPT module contains the following Python classes:
 
             (1) ArticleGatherer 
 
             (2) ArticleDataset              [supplies the data downloader for training]
 
             (3) TrainTokenizer 
 
             (4) TransformerFG               [borrowed from Transformers in DLStudio]
 
             (5) MasterDecoderWithMasking;   [borrowed from Transformers in DLStudio]
 
             (6) PromptResponder
 
    In what follows, I'll introduce each of these components one by one.  Each
    component is a separate inner class of the main module class babyGPT.
 
 
    ArticleGatherer:
 
    About the ArticleGatherer, you supply it with a list of URLs to media news sites.
    It then uses the Newspaper module (which understands the structure of a typical
    news HTML file) to download the articles from each of those URLs.  It is
    important to keep in mind that ArticleGatherer skips over non-HTML article files
    at the media websites. Unfortunately, many popular news websites now hide their
    content behind paywalls implemented with JavaScript.  [Examples of such websites
    include www.nyt.com, www.wsj.com, www.bbc.com, etc.] For obvious reasons, if the
    list of the URLs you provide ArticleGatherer consists of mostly such websites,
    the size of the corpus you create for experimenting with babyGPT could be much
    too small to be of any use.
 
 
    ArticleDataset:
 
    After you have used ArticleGatherer to download the news articles for the
    training corpus, the next thing you are going to need is a dataloader. That's
    exactly what's provided by the ArticleDataset class.  It randomly shuffles all
    the articles gathered and creates a number of dataloading streams equal to the
    batch-size that you are using for training babyGPT. The data input for the i^th
    batch instance is provided by the i^th stream. Logically speaking, you can think
    of each stream as a concatenation of the news articles that were randomly chosen
    for that batch instance.
 
 
    TrainTokenizer:
    
    Tokenizers play a critical role in language modeling because they create a
    bounded vocabulary of the tokens that is needed for maximum-likelihood prediction
    of the next token in a narrative.  The token vocabulary is generally constructed
    by using a split-and-merge approach in which you start by considering each
    different word in your corpus as a sequence of the most basic symbols, which can
    be ASCII characters if you are considering purely English text; or the individual
    bytes, as in the BPE (Byte Pair Encoding) tokenizer that can be used for most
    Western languages; or the even more general individual utf-8 encoded multi-byte
    characters if your goal is to create a language agnostic tokenizer.
    Subsequently, you form subwords by, first, merging the most basic constituents
    and, then, merging together smaller subwords into longer subwords by choosing at
    each iteration the most frequently occurring contiguously occurring pair of
    subwords.  The merging process continues until you have reached the specified
    vocabulary size.  What this logic implies is that if a long word in the corpus
    occurs sufficiently frequently, it will be represented by a single token.  On the
    other hand, a relatively short word that occurs rarely in the original corpus
    could be decomposed into shorter tokens.  It is in this manner that, with the
    WordPiece tokenizer, the BERT LLM has a vocabulary of around 30,000 tokens and,
    with the BPE tokenizer, the GPT-3 has a vocabulary of 50,000 tokens. Without such
    tokenization, the size of the vocabulary could grow continuously with the
    size of the corpus.  As you can imagine, if a language modeler is ingesting
    terabytes of text, the vocabulary of the words it sees could run into millions.
    It is not possible to devise the probability-based logic for next-word prediction
    if your underlying vocabulary is unbounded.
 
    The module comes with a pretrained tokenizer with a vocab size of around 50,000
    tokens.  I trained this tokenizer using the babyGPT module on the "Athlete News"
    dataset created by Adrien Dubois. The name of the tokenizer JSON in the Examples
    directory is like: "XYZ_babygpt_tokenizer_PQRST.json".  The prefix "XYZ" says
    that JSON was created with the tokenization code in version X.Y.Z of babyGPT.
    And "PQRST" is an integer that is the actual size of the token vocab.
 
    Starting with Version 1.1.1, you will find the tokenizer named above in a
    subdirectory named "tokenizer_outputs" in the Examples directory of the distro.
    You will also find a cleaned version of the tokenizer in a subdirectory named
    "cleaned_tokenizer_outputs".
 
 
    TransformerFG:
 
    About the TransformerFG component of babyGPT, as mentioned already, language
    modeling is best carried out with Transformer based implementations. To that end,
    I borrowed TransformerFG from DLStudio's Transformers module.  TransformerFG is
    my implementation of the concept of the Transformer as proposed by Vaswani et
    al. in their seminal paper "Attention is All You Need."  The suffix "FG" stands
    for "First Generation."
 
 
    MasterDecoderWithMasking:
 
    The MasterDecoderWithMasking part of babyGPT has also been borrowed from
    DLStudio's Transformers module.  To see the need for this component, note that
    unsupervised learning that is needed for autoregressive language modeling only
    uses the Decoder side of the Encode-Decoder architecture that would otherwise be
    needed for a Transformer-based framework for translating one language into
    another. An example of such a framework is presented in the notes for my Week 14
    lecture at Purdue's Deep Learning class. That framework has two decoder
    implementations: MasterDecoder and MasterDecoderWithMasking.  If you are engaged
    in autoregressive modeling, you have no choice but to use the "WithMasking"
    version of the decoder.  As to the reason for the "Master" prefix in the name of
    the decoder, language modeling typically requires a number of Transformer layers,
    with each layer using multiple Attention Heads to calculate what's known as Self
    Attention. In my DLStudio code, I refer to this layered organization of the
    Transformers as MasterEncoder and MasterDecoder, and to each Transformer layer as
    the BasicEncoder and the BasicDecoder.  
 
    Note that there's an interesting difference between the decoder logic as used in
    language translation and what you need for unsupervised learning in a GPT: When
    used for language translation, the decoder would also calculate Cross Attention,
    which is the attention between each element of the data coursing through the
    decoder and all the elements at the final output of the encoder.  The decoder as
    used for unsupervised learning in a GPT only needs to calculate Self Attention.
 
 
    PromptResponder:
 
    About the final component of babyGPT, PromptResponder, its purpose is to put the
    trained babyGPT model to use by having it respond appropriately to the prompts
    supplied by a user.  Given a prompt in the form of a sentence fragment, the
    PromptResponder uses its next-token prediction ability to keep on generating the
    tokens until it reaches the end-of-sentence token or until it has generated a
    specified number of sentences through this process.
 
 
 
DEALING WITH THE PROBLEM OF CONTEXT DISRUPTION CAUSED BY THE "<SOS>" TOKEN:
 
    What comes in the way of training babyGPT well are the textual discontinuities
    created by how a batch is constructed for each new iteration of training.  As
    explained elsewhere in this doc page, the list of all the documents in the
    training corpus is first randomized and then divided into a number of token
    streams, with one stream for each batch instance. (This randomization of the
    files and the division into token streams is carried out afresh at the beginning
    of each epoch.)  Subsequently, when a fresh batch is needed, for each batch
    instance you "draw" from its corresponding stream a "Context Window" number of
    tokens. The special <SOS> token is placed at the beginning of each such token
    sequence.
 
    This insertion of the <SOS> token disrupts the continuity of the token streams,
    as you can well imagine, and it violates the main point of the learning involved
    which is to learn the continuity properties of the text. Since these continuities
    are context dependent, it would be fair to say that the <SOS> token causes a
    context disruption for the token that comes after <SOS> at the beginning of each
    batch instance.  Over the years, various strategies have been proposed to
    circumvent this problem, one of the most recent being the "sliding-window based
    Attention" as presented by Beltagy, Peters, and Cohan in their 2023 paper
    "Longformer: The Long-Document Transformer".  In this approach, a fixed-sized
    window is used to calculate the attention at the token that is at the center of
    the window.  In this manner, what is calculated for Self Attention is the extent
    to which each token attends to the W/2 tokens on each side of the token at the
    center.  As the authors say: "Using multiple stacked layers of such windowed
    attention results in a large receptive field, where top layers have access to all
    input locations and have the capacity to build representations that incorporate
    information across the entire input."
 
    In keeping with the spirit of babyGPT, I have used a much simpler approach to
    deal with the context disruption problem caused by the <SOS> token.  My solution
    is based on an idea I call "Context Buffer".  The Context Buffer for each token
    sequence in CURRENT batch consists of the last n tokens in the corresponding
    token sequence in the PREVIOUS batch.  These last n tokens, inserted after the
    <SOS> token in the current batch, provided the context for the prediction at the
    first token positions in the current batch.
 
    To elaborate, let's assume that N is the size of the Context Window for your
    Transformer based processing of text and n is the size of the Context Buffer.  At
    every training iteration, for each batch instance you will pull N fresh tokens
    from the dataloader.  You will prepend the n last tokens for the same instance in
    the previous batch to the sequence of N tokens supplied by the dataloader for the
    current batch.  This will result in n+N tokens for transformer-based
    processing. I refer to n+N as the max_seq_length for which you have designed the
    transformer.
 
    It is interesting to note that the above mentioned problem with context
    disruption does NOT arise with sentence-based language modeling (as in BERT)
    since <SOS> is what you would want to use for designating the start of a
    sentence.  (For such learning, you would also use another token, denoted <EOS>
    for the "End of Sequence" indication.)
 
 
 
CONFORMING TO THE LIGHTNING API:
 
Right off the bat, your code must create an instance of the Lightning's Trainer
class.  In my code, this call looks like:
 
        trainer =  Lightning.Trainer(devices=-1, 
                                     accelerator="gpu", 
                                     strategy='ddp', 
                                     enable_progress_bar=False,
                                     logger=logger,
                                     max_epochs=-1,
                                     log_every_n_steps=100
                                    )
 
About the constructor parameters used above, I have used "devices=-1" because I want
babyGPT to run without any changes in both my laptop for debugging and code
development purposes and my 2-GPU VM in our lab cloud.  With the option "devices=-1",
Lightning will discover the number of GPUs available and automatically set "devices"
to that number.  Understanding the option "strategy='ddp'is important if you want to
have realistic expectations of what Lightning can do for you.  The "ddp" strategy
stands for "Distributed Data Parallel".  This strategy launches a number of processes
that are executed in parallel, with one process for each GPU.  While each process
creates its own instance of the dataloader and has its own training loop for forward
propagation of the data and backpropagation of the loss gradients, the processes are
synchronized for updating the model parameters.  The updating of the learnable
parameters is synchronized in the sense that it is based on the backpropagated loss
gradients in all of the processes.
 
Subsequently, you must call "fit()" on the Trainer instance with at least the two
required arguments: the model you want Lightning to train and the dataloader to be
used for training.  Here's an example of this call in babyGPT:
 
        trainer.fit( model=master_decoder,  
                     train_dataloaders= StreamingDataModule(
                                            data_source=dataloader,
                                            context_window_size=dataloader.context_window_size,
                                            batchsize=dataloader.batch_size,
                                            context_buffer_size=dataloader.context_buffer_size,
                                            inv_lookup_fn=dataloader.inverse_lookup)
                                        )
 
The "model" that is in the first argument to "trainer.fit()" is a network that you
want Lightning to train; this model must be subclassed from the class
'Lightning.LightningModule'. Starting with version 1.0.8, in babyGPT, it is the class
MasterDecoderWithMasking that is derived from 'Lightning.LightningModule'.
 
If you have been following the evolution of babyGPT, you will see a sea change
between the versions 1.0.7 and 1.0.8 for how the class MasterDecoderWithMasking is
implemented. You see, a network that is meant to be learned with Lightning must have
following methods defined for it: training_step(), train_dataloader(), and
configure_optimizers().  The first, training_step(), has the code that is meant to be
executed at every iteration of training.  The second, train_dataloader(), must return
a dataloader that is implemented as a Python generator.  That is, the dataloader must
return a new batch through a 'yield' statement in a 'while' loop.  Lightning can
invoke the train_dataloader() automatically to fetch the next batch for a new
training cycle.  Finally, about the function configure_optimizers(), note that
Lightning allows for only specific types of optimizers and learning-rate schedulers
for the optimizers.
 
In the code that follows, the required three functions named above are in the
definition of the class MasterDecoderWithMasking. Note that it is the network created
for this class that carries out autoregressive modeling of a text corpus.  [In the
previous version of the module, v.1.0.7, all of this code was in the method
"run_code_with_buffered_context_for_training_TransformerFG" of the top-level babyGPT
class.] 
 
In addition to the above, here are some pointers relevant to making a software module
ready for Lightning: (1) For single GPU based processing, when you cast a tensor to
type CUDA, it is obvious to the system that you want to prepare that tensor for its
storage in the memory of the GPU that is available.  However, when you have multiple
GPUs, how do you cast a tensor to be of type CUDA for a specific GPU?  Here is an
example of how to best address this problem:
 
       mask = torch.ones(1, device=input_sequence.device, dtype=torch.long)   
 
The goal here is to create a mask that is needed for autoregressive learning. Recall
from my Week 14 lecture, autoregressive modeling is all about predicting the next
token based on the tokens seen so far in the input. This requires progressively
masking the input sequence to indicate to the model how much of the input sequence to
use for the next token prediction. That is, for a fresh iteration of the training
cycle, you will start with the mask consisting of a single 1, which tells the model
to make a prediction for the token at second position in the input based on its
knowing just the first token. Subsequently, you will concatenate another 1 to the
mask and now the model will try to make a prediction for the third token in the input
based on the its knowledge of the first two tokens.  In the code line shown above,
what you see the mask being initialized with a single '1' for a new input sequence.
 
The important point is that the initialization of the mask tensor shown above must
take place separately for each of the GPUs available to Lightning. As it turns out,
Lightning creates a separate process for each of the GPUs. The question then becomes:
How to inform each process as to the type of the CUDA tensor being created as shown
above. As you can see, it is done with the simple expedient of declaring the "device"
to be the same as for the "input_sequence".  That makes completely unambiguous the
destination of the new mask tensor. Regarding the "input_sequence" tensor, for the
DDP (Distributed Data Parallel) strategy, each GPU process runs the dataloader
separately.  Therefore, each process will create its own version of the
"input_sequence" tensor.
 
Another thing to bear in mind about Lightning is that it assumes that all of the
tensors created in the implementation for "training_step()" are meant to be of type
CUDA. So you are spared the need to append ".cuda()" or ".to(device)" to the tensor
initializations as is common for the case of single-GPU based code.
 
 
 
GETTING RID OF SUPERFLUOUS TOKENS
 
Version 1.1.1 includes a new function named "post_training_cleanup()" defined for the
TrainTokenizer class that you can invoke after you have trained a tokenizer in order
to get rid of superfluous tokens.  A token A is superfluous vis-a-vis another token B
if A is a substring of B and if the number of the corpus words that contain the
tokens A and B are exactly the same.  
 
When you invoke the function "post_training_cleanup()", the cleaned-up tokenizer JSON
is deposited in the subdirectory:
 
            cleaned_tokenizer_outputs
 
Ordinarily, the tokenizer JSON that is produced by the function "train_tokenizer()" 
is deposited in the subdirectory:
 
            tokenizer_outputs
 
 
 
EXTENDING A PREVIOUSLY LEARNED TOKENIZER JSON:
 
Let's say you have trained a tokenizer with a target vocabulary of 40,000 and you
want to extend the target vocabulary to, say, 50,000 without having to retrain the
whole thing from scratch.  How does one do that?  It is an important question in a
university lab because tokenizer training is CPU intensive and can take days depending
on your hardware and the size of the target vocabulary.
 
Version 1.1.1 allows you to extend the target vocabulary size for a previously
trained tokenizer.  It is accomplished with the function 
 
              extend_tokenizer_with_additional_training()
 
defined for the inner class TrainTokenizer. The starting point for using this function
should be the following script in the Examples directory:
 
              extend_previously_trained_tokenizer.py
 
Note that this script expects to command-line arguments, with the first being the 
pathname to the previously trained JSON and the second the new target vocab size.
Here's an example:
 
   python3   extend_previously_trained_tokenizer.py   tokenizer_outputs/112_babygpt_tokenizer_50002.json    60000
 
You'll obviously need to make sure that the numbers "112", "20025", and "60000"  
are specific to your request.
 
 
 
TRAINING WITH GRADIENT ACCUMULATION:
 
With the hardware typically available in a university lab (say, you are using GPUs
like NVIDIA A5000 with 24 GB memory), you will find yourself trading off batch-size
for max-sequence-length for the tokens you feed into the transformer (while also
taking into account the embedding size).  Batch size plays a critical role in
learning and you want it to be as large as possible, but not at the cost of making
max-sequence-length too small to be useful.  In Version 1.1.2, I was able to use a
batch-size of 50 for a max-sequence-length of 50 tokens (and with the embedding size
set to 384).
 
How does one get past the constraints described above and give demonstrations with
longer token sequence?  Gradient accumulation is one answer.  Gradient accumulation
allows you to reduce the batch-size and increase the maximum sequence length without
losing learning effectiveness. That is because you accumulate the backpropagated
gradients over multiple steps of training before actually updating the learnable
parameters. So your effective batch-size becomes the number of accumulation steps
times the actual batch size.
 
When using Lightning, it takes only one extra statement in your training code if you
want to take advantage of gradient accumulation.  That is because Lightning is happy
to take care of the rest of the details under the hood.  However, as I have explained
in the doc-string associated with the TokenStreamDataset class, babyGPT was designed
specifically to work with streaming training datasets.  True streaming datasets are
of type IterableDataset and they do NOT lend themselves to distributed sampling that
is the forte of PyTorch Lightning. In such cases, it takes a little bit more work to
take advantage of gradient accumulation during training.  In babyGPT, you will see
the extra statements for gradient accumulation if you search for the string
"gradient_accumulation_steps". If, say, you have set gradient_accumulation_steps to
2, the backproped gradients would be accumulated over 2 steps before the accumulated
values are used to update the model parameters.
 
Beware that there is a price to pay for using gradient accumulation --- your training
time will go up by approximately the same factor as the number of step you are using
for accumulation.  Let's say that without gradient accumulation it takes you two or
three days of training before you start seeing some evidence of learning.  If you
decide to train babyGPT with "gradient_accumulation_steps" set to 2, now it could
take you the better part of a week of training before you start seeing the same sort
of evidence.
 
 
 
EXAMPLES OF THE OUTPUT PRODUCED DURING TRAINING:
 
The examples in this section are based on the assumptions listed below. These
examples were generated by Version 1.1.2 of babyGPT.  As stated elsewhere, the use of
gradient accumulation in Version 1.1.3 allows babyGPT to handle longer sequences than
in the examples here.
 
--  You are using the "Athlete News" dataset that you can download from the module
    website at Purdue.  As mentioned, this dataset was derived by Adrien Dubois from
    the 2017-2018 news reports.
 
--  You are training the model with the following config parameters:
 
 
                    GPUs used   :   TWO NVIDIA A5000 (each with 24 GB)
          Max Sequence Length   :   55
                   Batch Size   :   50
               Embedding Size   :   384                                        
           Transformer Stages   :   6 
    Number of Attention Heads   :   8 
                Learning Rate   :   10^-4  (with Cosine scheduler)
                 Warmup Steps   :   4000
 
 
--  You are running the following script in the Examples directory of the
    distribution:
 
            create_base_model_with_buffered_context.py                       
 
    
The script named above will show the following sort of output every 100 iterations.
For each output, it will randomly select four out of for 50 batch instances and present
the information arranged as follows:
 
 
            Ground-truth
 
            GT Token Seq
 
               Predicted
 
             Detokenized
 
 
where "Ground-truth" is a segment of the text extracted from the downloader for each
batch instance; "GT Token Seq" is tokenized version of the ground-truth; "Predicted"
is the sequence of predictions at each token position by the transformer; and
"Detokenized" is the output of the decoder that joins the tokens back into words.  To
help out with detokenization, babyGPT inserts underscores between the whole words
that are dropped in the detokenization step.  You can see these underscores in the
outputs shown for "GT Token Seq".
 
In the rest of this section, I have displayed an example of the output produced for
each epoch:
 
 
Epoch 0 example:
 
Ground-truth:  ” It was just the Lakers ’ third loss in 11 games since the All - Star break ,  but their fourth against the Warriors_
GT Token Seq: . ” It _ was _ just _ the _ Lakers ’ third _ loss _ in _ 11 _ games _ since _ the _ All - Star _ break , but _ the ir _ fourth _ again st _ the _ War ri or s _
   Predicted: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
 Detokenized: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
 
 
Epoch 1 example:
 
Ground-truth: _with a seven - point lead .  However ,  the Warriors caught up soon enough and then left them in a trail of smoke
GT Token Seq: _ with _ a _ seven - point _ lead . How ever , the _ War ri or s _ caught _ up _ soon _ enough _ an d _ then _ left _ them _ in _ a _ tra il _ of _ smo ke
   Predicted: _ of _ a _ two - time _ lead _ The ever, the _ War ri or s _ are _ up _ a _ to _ to d _ the _ the _ the _ to _ the _ row _ _ of _ the y
 Detokenized: _of a two - time lead Theever,  the Warriors are up a to tod the the the to the row _of they
 
 
Epoch 2 example:
 
Ground-truth: Feb .  2 ,  2018 ,  in Sacramento ,  Calif .  ( AP Photo / Rich Pedroncelli ) Golden State Warriors guard Stephen Curry goes to the
GT Token Seq: Fe b . 2 , 2018 , in _ Sacramento , Cali f . ( AP _ Photo / Rich _ Pe d ron cel li ) Golden _ State _ War ri or s _ guar d _ Stephen _ Curry _ go es _ to _ the
   Predicted: Fe b. 22, 2018. in _ the, Cali f. ( AP _ Photo / Mark _ J der le _ _ ). _ State _ War ri or s _ forward d _ Stephen _ Curry _ ( _ _ to _ the
 Detokenized: Feb.  22,  2018.  in the,  Calif.  ( AP Photo / Mark Jderle _ ). State Warriors forwardd Stephen Curry ( _to the
 
 
Epoch 3 example:
 
Ground-truth: got a Finals MVP [ Andre Iguodala ] that comes off their bench .  ” James pointed out that the Warriors are
GT Token Seq: got _ a _ Finals _ MVP _ [ An d re _ Igu od al a ] that _ co mes _ off _ the ir _ bench . ” James _ poin te d _ out _ that _ the _ War ri or s _ are
   Predicted: got _ to _ lot _ MVP _ an Stephen d re _ Igu od al a _, _ is mes _ to _ the _ _ season _ The The _ is te d _ to _ the _ the _ War ri or s _ have
 Detokenized: got to lot MVP anStephendre Iguodala_, ismes to the _season TheThe isted to the the Warriors have
 
 
Epoch 4 example:
 
Ground-truth: _Warriors ( 58 - 24 ) vs .  No .  7 San Antonio Spurs ( 47 - 35 ) How to watch Game 5 Date : Tuesday ,  April 24_
GT Token Seq: _ War ri or s _ ( 58 - 24 ) vs . No . 7 _ San _ Antonio _ Spurs _ ( 4 7 - 35 ) How _ to _ watch _ Ga me _ 5 _ Da te : Tuesday , April _ 24 _
   Predicted: _ War ri or s _ in 1 - 0 ).. Cleveland _ 1 _ seed _ Antonio _ Spurs : ( 10 ) ) 3 ) an _ did _ watch _ the me _ 1 _ of vi _ " _ June _ 22,
 Detokenized: _Warriors in1 - 0 ).  Cleveland 1 seed Antonio Spurs : ( 10 ) ) 3 ) an did watch theme 1 ofvi " June 22,
 
 
Epoch 5 example:
 
Ground-truth: Kevin Durant and Klay Thompson ,  the Warriors lost Draymond Green in the second quarter to a bruise in_
GT Token Seq: K ev in _ Durant _ an d _ K lay _ Thom p son , the _ War ri or s _ lo st _ Dra y mon d _ Green _ in _ the _ second _ quarter _ to _ a _ bruise _ in _
   Predicted: K lay in _ Durant _ an d _ Dra lay _ Thom p son _ Dra _ War ri or s _ an st _ to y mon d _ Green _ an _ the _ second _ half _ to _ a _ four _ in _
 Detokenized: Klayin Durant and Dralay Thompson Dra Warriors anst toymond Green an the second half to a four in_
 
 
...and so on
 
 
 
For truth in advertising, I must hasten to add that, in what you see above, I have
chosen some of the better looking examples for each epoch.  For every good output
example like those shown above, you will see a large number meaningless gibberish
examples.  As you would expect, at the beginning of training, all output is mostly
gibberish.  However, as the training continues, you begin to see more and more
examples of the output that makes sense.
 
 
 
INSTALLATION:
 
    The babyGPT class was packaged using setuptools.  For installation, execute
    the following command in the source directory (this is the directory that
    contains the setup.py file after you have downloaded and uncompressed the
    gzipped tar archive for the module):
 
            sudo python3 setup.py install
 
    On Linux distributions, this will install the module file at a location that
    looks like
 
             /usr/local/lib/python3.10/dist-packages/
 
    If you do not have root access, you have the option of working directly off
    the directory in which you downloaded the software by simply placing the
    following statements at the top of your scripts that use the
    babyGPT class:
 
            import sys
            sys.path.append( "pathname_to_babyGPT_directory" )
 
    To uninstall the module, simply delete the source directory, locate where the
    babyGPT module was installed with "locate
    babyGPT" and delete those files.  As mentioned above, the full
    pathname to the installed version is likely to look like
    /usr/local/lib/python2.7/dist-packages/babyGPT*
 
    If you want to carry out a non-standard install of the babyGPT
    module, look up the on-line information on Disutils by pointing your browser
    to
 
              http://docs.python.org/dist/dist.html
 
USAGE:
 
    If you want to use babyGPT for unsupervised learning of a base model for a text
    corpus, you would need to construct an instance of the main babyGPT class and its
    supporting classes as follows:
 
    baby_gpt = babyGPT(
                        max_seq_length = max_seq_length,
                        batch_size = batch_size,
                        embedding_size = embedding_size,
                        num_basic_decoders = num_basic_decoders,
                        num_atten_heads = num_atten_heads,
                        optimizer_params = optimizer_params,
                        num_warmup_steps = num_warmup_steps,
                        masking = masking,
                        verify_text_corpus = False,
                        path_saved_model = {"decoder" : "./saved_decoder",                                                             
                                            "embedding_generator" : "./saved_embedding_generator",                             
                                           },
                      )
    
    xformer = baby_gpt.TransformerFG( 
                        max_seq_length = max_seq_length,
                        embedding_size = embedding_size,
                        tokenizer_json = tokenizer_json,
                        num_warmup_steps = num_warmup_steps,
                        optimizer_params = optimizer_params,
              )
    
 
    master_decoder = baby_gpt.MasterDecoderWithMasking(
                        xformer,
                        num_basic_decoders = num_basic_decoders,
                        num_atten_heads = num_atten_heads,
                        context_window_size = context_window_size,
                        context_buffer_size = context_buffer_size,
                        batch_size = batch_size,
                        gradient_accumulation_steps = gradient_accumulation_steps,
                        masking = masking
                     )
    
    dataloader = baby_gpt.ArticleDatasetWithBufferedContext(
                        gpt = baby_gpt,
                        tokenizer_json = tokenizer_json,
                        context_window_size = context_window_size,
                        context_buffer_size = context_buffer_size,
                        articles_dir = articles_dir,
                 )
 
    
 
THE Examples DIRECTORY:
 
    This directory contains the following six scripts for working with babyGPT:
 
 
        1.  run_gatherer.py
 
            This script is for collecting a corpus for experimenting with babyGPT.
            The script requires a list of URLs as article sources as illustrated
            by the following example:
 
                urls = ['https://finance.yahoo.com','http://cnn.com',
                        'https://sports.yahoo.com',
                        'https://purdueexponent.org','https://slate.com',
                        'https://timesofindia.indiatimes.com',
                        'http://cnn.com',
                        'https://slate.com'
                       ]
 
 
        2.  train_tokenizer.py
 
            If the text corpus you have collected is for a specialized domain (such
            as movies, sports, healthcare, etc.), you are likely to get better
            results from babyGPT if you first train a new tokenizer for that domain.
            You train a new tokenizer merely by invoking this script after you have
            set its variable "articles_dir" so that it points to the corpus 
            directory.
 
 
        3.  apply_tokenizer.py
 
            If you have created a new JSON file for the tokenizer, this script is
            just to test the tokenizer on a small txt file.  To get started with using
            this script, try it out with the following command line:
 
               python3  apply_tokenizer.py   text_sample_for_testing.txt   112_babygpt_tokenizer_50002.json
 
           where the sample file "text_sample_for_testing.txt" should already be in
           the Examples directory of the distro and where the last arg is the JSON
           you are testing.  Make sure the name of tokenizer JSON is what you are
           testing.
 
        
        4.  extend_previously_trained_tokenizer.py
 
            You need to run this script only if you wish to extend a previously
            trained tokenizer with a larger target vocabulary.  Pay attention to the
            call syntax for this script since it expects command-line arguments.
            Here is an example:
 
              python3   extend_previously_trained_tokenizer.py   tokenizer_outputs/112_babygpt_tokenizer_50002.json    60000
 
            which says you want to extend the JSON in the penultimate arg with a
            new target vocab size of 60000.
 
 
        5.  create_base_model_with_buffered_context.py
 
            This is the script to run if you want to create a Base Model for your
            corpus.  By Base Model I mean a language model acquired through
            unsupervised learning from a training corpus.  Since this script calls on
            the core language modeling functionality of babyGPT, you have to set a
            relatively large number of parameters in the script.  These parameters
            are shown below:
 
                articles_dir
                tokenizer_json 
                max_seq_length 
                context_window_size
                context_buffer_size
                batch_size 
                embedding_size
                num_atten_heads 
                num_basic_decoders 
                optimizer_params
                num_warmup_steps
 
 
        6.  interact_with_prompts.py
 
            This is the script for interacting with a trained babyGPT model through
            prompts.  The idea is that you supply a small number of words (as, say,
            the beginning of a new thought) as a prompt and the model supplies the
            rest of the words to complete the thought.  At this time, the model
            extends your prompt until it reaches a period (or the end dictated by the
            size of the "max_seq_length" parameter.
 
 
 
THE TRAINING DATASETS PROVIDED:
 
    Click on the following link near the beginning of this documentation page:
 
                    "Download the text datasets for babyGPT" 
 
    in order to download the following training data archive
 
                          datasets_for_babyGPT.tar.gz
 
    Save the archive in the Examples directory of the distribution.  Now execute the
    following command:
 
                 tar zxvf datasets_for_babyGPT.tar.gz
 
    This command will create a 'data' subdirectory in the 'Examples' directory                                                  
    and deposit the datasets mentioned below in that subdirectory:
 
                 saved_Adrien_News_Articles_56M
 
                 saved_articles_dir_12M
 
    The first is the "Athlete News" corpus created by Adrien Dubois. The suffix "56M"
    in the name of the corpus refers to the fact that the corpus consists of roughly
    56 Million multi-byte Unicode characters with utf-8 encoding.
 
    The second is a much smaller corpus for debugging purposes.  It is based on the
    news articles I downloaded with the "run_gatherer.py" script in the Examples 
    directory.
 
 
 
BUGS:
 
    Please notify the author if you encounter any bugs.  When sending email,
    please place the string 'babyGPT' in the subject line to get past his
    spam filter.
 
 
ACKNOWLEDGMENTS:
 
    I must thank Aditya Chauhan for pulling me into the world of multi-GPU training
    with PyTorch Lightning.  If you find useful any of the pointers I have provided
    for making your code compatible with the Lightning API, the primary credit for
    that should go to Aditya.  Aditya, currently pursuing a PhD in Purdue RVL, is
    also our foremost expert in OpenStack based cloud computing.  I'd also like to
    thank Amith Kashyap and Adrien Dubois, both also PhD candidates in RVL, for many
    insightful conversations about deep learning, in general, and about mult-GPU 
    computing in particular.  As previously mentioned in this page, Adrien is also 
    the creator of the "Athlete News" dataset that I provide through this module 
    webpage at Purdue.
 
 
ABOUT THE AUTHOR:
 
    The author, Avinash Kak, is a professor of Electrical and Computer Engineering at
    Purdue University.  For all issues related to this module, contact the author at
    "kak@purdue.edu". If you send email, please place the string "babyGPT" in your
    subject line to get past his spam filter.
 
 
COPYRIGHT:
 
    Python Software Foundation License
 
    Copyright 2025 Avinash Kak
 
@endofdocs

 
Imported Modules
       
lightning
matplotlib.animation
blingfire
copy
glob
itertools
json
logging
math
newspaper
torch.nn
numpy
torch.optim
os
matplotlib.pyplot
random
re
signal
string
sys
time
torch
torchvision
torchvision.transforms

 
Classes
       
builtins.object
babyGPT

 
class babyGPT(builtins.object)
    babyGPT(*args, **kwargs)
 

 
  Methods defined here:
__init__(self, *args, **kwargs)
Initialize self.  See help(type(self)) for accurate signature.
run_code_with_buffered_context_for_training_TransformerFG(self, xformer, master_decoder, dataloader, checkpoint_frequency=4000, display_train_loss=False)
Drawn from the training routines in the Transformer module of DLStudio
save_checkpoint_decoder(self, decoder, dir_name, iter_index)
Save the decoder checkpoint
save_checkpoint_embedding_generator(self, embedding_generator, dir_name, iter_index)
save checkpoint for the embedding_generator
save_decoder(self, decoder)
Save the trained decoder to a disk file
save_embedding_generator(self, embedding_generator)

Data descriptors defined here:
__dict__
dictionary for instance variables (if defined)
__weakref__
list of weak references to the object (if defined)

Data and other attributes defined here:
ArticleDatasetWithBufferedContext = <class 'babyGPT.babyGPT.ArticleDatasetWithBufferedContext'>
This class supplies the 'foundational' dataloader for training. When using the PyTorch Lightning 
module for for multi-GPU training, this dataloader is routed through Lightning's LightningDataModule 
class as you will see later in this code file.  Lightning requires its dataloaders to be Python 
generators.
 
The parameter 'context_window_size' is the number of fresh tokens that the dataloader must supply in
each training iteration.  And the parameter 'context_buffer_size' is the number of trailing tokens
in the previous batch that are prepended to the fresh tokens in the current batch. The number of 
tokens that the transformer network sees is the sum of these two sizes.  
 
The sum of context_window_size and context_buffer_size is referred to as 'max_seq_length' in the 
code.
ArticleGatherer = <class 'babyGPT.babyGPT.ArticleGatherer'>
This script is for collecting data for experimenting with the Transformer based unsupervised learning 
code in baby_gpt.py.  
 
The articles are downloaded from the URLs that are specified by the argument 'urls' in the constructor 
shown below.  See the script "create_base_model_with_buffered_context.py" in the Examples directory
for how to set the URL strings for this argument.  Here are some examples:
 
    urls = ['https://finance.yahoo.com','http://cnn.com',
             'https://timesofindia.indiatimes.com',
             'https://purdueexponent.org','https://slate.com', 
             'https://sports.yahoo.com']
 
    urls = ['http://cnn.com']
 
    urls = ['https://slate.com']
 
    urls = ['https://timesofindia.indiatimes.com']
AttentionHead = <class 'babyGPT.babyGPT.AttentionHead'>
Borrowed from the Transformers module of DLStudio
BasicDecoderWithMasking = <class 'babyGPT.babyGPT.BasicDecoderWithMasking'>
Borrowed from the Transformers module of DLStudio
EmbeddingGenerator = <class 'babyGPT.babyGPT.EmbeddingGenerator'>
MasterDecoderWithMasking = <class 'babyGPT.babyGPT.MasterDecoderWithMasking'>
This class was borrowed initially from the Transformers module of the DLStudio platform.  Subsequently, its 
definition was significantly expanded to fulfill the constraints imposed by the PyTorch Lightning API.
For information regarding the operation of this class, please visit the website for DLStudio at Purdue.
PromptResponder = <class 'babyGPT.babyGPT.PromptResponder'>
Prompting a trained babyGPT models means that you supply a small number of words (as, say, the 
beginning of a new thought) as a prompt and the model supplies the rest of the words to complete 
the thought.  The class comes with two methods, the first for extending your prompt until it 
reaches a period, and the second for going beyond the first period encountered.
 
Any interaction with a trained GPT model has to deal with the following issue:  What to do with
the context buffer that is meant to be a continuation of the last part of the previous "sentence"
fed into the transformer. 
 
Ideally, we should be placing in the context buffer words that create a context for the prompt.
But there is no easy way to that without a more elaborate model. An example of more elaborate
modeling would be to have the input to the transformer consist of, say, an SOS token, a special
context token consisting possibly of integer index values beyond the tokenizer vocab, followed
by a context buffer that would be the last part of the previous sentence, followed, finally, 
by the new input tokens.
 
babyGPT gives you two options regarding what to do with the context buffer for your prompt:
 
        --  all_zeros
 
        --  get_from_prompt
 
With the first option, all of the integer encoding values in the context buffer are set to
the integer zero.  And, with the second option, at this time, the context buffer contains
a portion or all of the prompt itself.  If the tokenized version of the prompt is shorter
than the size of the context buffer, only the context_buffer_size number of elements of the
prompt are retained for the context buffer.  In the opposite case, just the initial 
context_buffer_size number of elements of the prompt are retained.
SelfAttention = <class 'babyGPT.babyGPT.SelfAttention'>
Borrowed from the Transformers module of DLStudio
TrainTokenizer = <class 'babyGPT.babyGPT.TrainTokenizer'>
Tokenizers play a critical role in language modeling because they create a fixed-sized vocabulary 
for the corpus you are working with --- regardless of the size of the corpus itself.  Unless your 
text corpus is based on a set of documents frozen in time, ordinarily, as the size of a text corpus 
goes up, so does the size of the vocabulary --- despite the illusion to the contrary created by 
the fixed sizes of the language dictionaries you have seen all your life.  How we express ourselves 
is a living thing.  We are constantly inventing new words and new expressions; these form important 
components of  what's referred to as the zeitgeist.
 
Having a fixed-sized vocab is important because the loss functions used in deep-learning network 
used for language processing are based on maximum-likelihood prediction of the next token given 
the tokens seen previously.  That requires estimating the probabilities associated with all
possible tokens at the next position.  As you can imagine, it would be impossible to engage in 
such probabilistic reasoning if you did not know in advance the size of the vocabulary.
 
Added in version 1.0.9: Here's an important point to remember when you are training a tokenizer 
on media articles collected from the internet at large: The articles frequently contain long 
URL strings that should play no role in either training the tokenizer for a new corpus or in 
training a transformer network for next-token prediction. What makes this problem worse is that 
such strings may consist of hundreds of characters --- because some media URLs these days 
contain the full title of the articles they point to. In addition, parts of a URL (such as the 
Query part) may also be encoded --- that is, consist of a seemingly gibberish sequence of
characters.  To get rid of such strings, starting with Version 1.0.9, I filter out all strings 
that are longer than 20 characters.  This I do both for Tokenizer training and when reading in 
the text data for training the transformer network.
TransformerFG = <class 'babyGPT.babyGPT.TransformerFG'>
This I have borrowed from the DLStudio's Transformers module.  "FG" stands for "First Generation" --- which is 
the Transformer as originally proposed by Vaswani et al.

 
Functions
       
ctrl_c_handler(signum, frame)
gen(container)

 
        __author__ = 'Avinash Kak (kak@purdue.edu)'
__copyright__ = '(C) 2025 Avinash Kak. Python Software Foundation.'
__date__ = '2025-September-16
__url__ = 'https://engineering.purdue.edu/kak/distBabyGPT/babyGPT-1.1.3.html'
__version__ = '1.1.3'
 
Author
        Avinash Kak (kak@purdue.edu)