babyGPT (version 1.0.6, 2025-April-22)

babyGPT.py
 
Version: 1.0.6
   
Author: Avinash Kak (kak@purdue.edu)
 
Date: 2025-April-22
 
 

Download Version 1.0.6:  gztar  

 
     Total number of downloads (all versions): 109
     This count is automatically updated at every rotation of
     the weblogs (normally once every two to four days)
     Last updated: Sun May 4 06:09:01 EDT 2025

View the main module code file in your browser  
 
Download the text datasets for babyGPT  
 
 
 
CHANGES:

  Version 1.0.6:
 
     I have fixed the error that caused the predicted tokens to be shifted by one
     position vis-a-vis the ground-truth tokens.
 
  Version 1.0.5:
    
     Had a URL error in the setup.py of the previous version. The rest of the module
     remains unchanged.
 
  Version 1.0.4:
 
    This is the first public release version of the module. This module was created
    for the Deep Learning class at Purdue University.
 
 
INTRODUCTION:
 
    SPECIFIC GOALS FOR THIS MODULE:
 
    1) To introduce the students in Purdue's Deep Learning class to the foundational
       concepts in how to create a Base Language Model through self-supervised
       learning.  Large Language Models start out as Base Models that are
       subsequently fine-tuned with reinforcement learning.  The focus of this module
       is solely on Base Modeling.
 
    2) To demonstrate small-scale large-language modeling that for educational
       purposes can be run on a typical university lab GPU.
 
    3) To create a self-contained module that, given a set of media URLs, will
       download  the articles from those websites (assuming they are not behind a
       paywall), train a BPE tokenizer from the corpus of the articles collected,
       create a Base Model from the corpus, and, subsequently, let you play with the
       model using the prompting script in the module.
 
    My main goal in babyGPT is to demonstrate that, for the purpose of teaching and
    learning, it is possible to create a small-scale end-to-end implementation that
    downloads a corpus of news media articles, trains a BPE tokenizer if you need a
    new one for the domain of the corpus you have collected, and, finally, uses the
    corpus for training an autoregressive model for the next token prediction based
    on unsupervised learning. After you have trained the model, you can test it with
    the prompting script that is included in the Examples directory. 
 
 
    LANGUAGE MODELING AND UNSUPERVISED LEARNING:
 
    There is no denying the fact that the recent advances in chatbots have set the
    world on fire. It's truly amazing to see a chatbot returning (most of the time) a
    smooth-reading and well-structured narrative in response to a prompt. As if that
    were not enough, it can also supply you with variants of the same narrative
    depending on how you prompt it and your randomization settings for the bot.
 
    One would think that this degree of competency shown by a chatbot would require a
    vast amount of human annotated data for training the neural networks used for the
    bot.
 
    The truth is exactly the opposite.  Most of the learning that takes place in
    order to train a chatbot is unsupervised --- that is, without any human
    supervision. The bot is given the simplest of the goals: To predict the next word
    given the words that have been seen so far.  To master this goal, the bot needs
    zero supervision.  All it needs to do use its neural network make a prediction
    for the next word.  And, at training time, should this prediction be wrong, to
    estimate the error made, to backpropagate that error and thus adjust the
    learnable weights in the network.  Until not too long ago most people would have
    thought that this type of learning would be much too weak to be of any practical
    use. But, as in all engineering, you cannot argue with something that actually
    works.  One great thing that has come out of AI research of the last two decades
    is that unsupervised learning not only works, it actually lends itself to
    designing powerful data driven frameworks without too much human intervention.
 
 
    TRANSFORMERS:
 
    The unsupervised learning of the sort described above is best implemented with
    Transformers. (See my material for the Week 13 lecture at Purdue's Deep Learning
    class for a detailed presentation on how can implement an English-to-Spanish
    translation framework using Transformers.)  And central to the Transformer
    architecture element is the notion of Attention.  Attention means how each
    element at the input to a neural network attends to every other element in the
    same input.  In a network for language translation, the network would also use
    Attention to figure out the significance of each word in a source-language
    sentence to every other word. For example, if the car was one of the words in a
    sentence at the input and a clause in the sentence used the pronoun "it" that
    pointed to that car, the network would be able to figure out the connection
    between the "it" and the "car" through attention.  Along the same lines, the
    network would use Cross Attention to figure out the importance of each word in the
    source language to the different words in the target language.  As you can
    imagine, understanding such connections between the words would be critical for
    any network that is learning how to translate a source language sentence into a
    target language sentence.
 
 
THE MAJOR COMPONENTS of babyGPT:
 
    babyGPT module contains the following Python classes:
 
             (1) ArticleGatherer 
 
             (2) ArticleDataset              [supplies the data downloader for training]
 
             (3) TrainTokenizer 
 
             (4) TransformerFG               [borrowed from Transformers in DLStudio]
 
             (5) MasterDecoderWithMasking;   [borrowed from Transformers in DLStudio]
 
             (6) PromptResponder
 
    In what follows, I'll introduce each of these components one by one.  Each
    component is a separate inner class of the main module class babyGPT.
 
 
    ArticleGatherer:
 
    About the ArticleGatherer, you supply it with a list of URLs to media news sites.
    It then uses the Newspaper module (which understands the structure of a typical
    news HTML file) to download the articles from each of those URLs.  It is
    important to keep in mind that ArticleGatherer skips over non-HTML article files
    at the media websites. Unfortunately, many popular news websites now hide their
    content behind paywalls implemented with JavaScript.  [Examples of such websites
    include www.nyt.com, www.wsj.com, www.bbc.com, etc.] For obvious reasons, if the
    list of the URLs you provide ArticleGatherer consists of mostly such websites, the
    size of the corpus you create for experimenting with babyGPT could be much to
    small to be any fun.
 
 
    ArticleDataset:
 
    After you have used ArticleGatherer to download the news articles for the
    training corpus, the next thing you are going to need is a dataloader. That's
    exactly what's provided by the ArticleDataset class.  It randomly shuffles all
    the articles gathered and creates a number of dataloading streams equal to the
    batch-size that you are using for training babyGPT. The data input for the i^th
    batch instance is provided by the i^th stream. Logically speaking, you can think
    of each stream as a concatenation of the news articles that were randomly chosen
    for that batch instance.
 
 
    TrainTokenizer:
    
    Tokenizers play a critical role in language modeling because they create a
    bounded vocabulary of the tokens that the language model must understand. This is
    done by using a split-and-merge approach in which you start by considering each
    different word in your corpus as a sequence of the most basic symbols, which can
    be ASCII characters as in the WordPiece tokenizer or the individual bytes, as in
    the BPE (Byte Pair Encoding) tokenizer.  Subsequently, you form subwords by,
    first, merging the most basic constituents like the bytes and, then, merging
    smaller subwords into longer subwords, on the basis of the frequencies of the
    merged subwords vis-a-vis the frequencies of the components that were merged. The
    merging process continues until you have reached the specified vocabulary size.
    What this logic implies is that if a long word in the corpus occurs sufficiently
    frequently, it will be represented by a single token.  On the other hand, a
    relatively short word that occurs rarely in the original corpus could be
    decomposed into shorter tokens.  It is in this manner that, with the WordPiece
    tokenizer, the BERT LLM has a vocabulary of around 30,000 tokens and, with the
    BPE tokenizer, the GPT-3 has a vocabulary of 50,000 tokens. Without such
    tokenization, the size of the vocabulary could grow continuously with the the
    size of the corpus.  As you can imagine, if a language modeler is ingesting
    terabytes of text, the vocabulary of the words it sees could run into millions.
    It is not possible to devise the probability-based logic for next-word prediction
    if your underlying vocabulary is unbounded.
 
    The module comes with a pre-trained tokenizer with a vocab size of around
    50,000 tokens.  I trained this tokenizer using the babyGPT module on the athlete
    news dataset created by Adrien Dubois. The name of the tokenizer JSON in the
    Examples directory is: 104_babygpt_tokenizer_49270.json 
 
 
    TransformerFG:
 
    About the TransformerFG component of babyGPT, as mentioned already, language
    modeling is best carried out with Transformer based implementations. To that
    end, I borrowed TransformerFG from DLStudio's Transformers module.
    TransformerFG is my implementation of the concept of the Transformer as
    proposed by Vaswani et al.  in their seminal paper "Attention is All You
    Need."  The suffix "FG" stands for "First Generation."
 
 
    MasterDecoderWithMasking:
 
    The MasterDecoderWithMasking part of babyGPT has also been borrowed from
    DLStudio's Transformers module.  To see the need for this component, note that
    unsupervised learning that is needed for autoregressive language modeling only
    uses the Decoder side of the Encode-Decoder paper that would otherwise be
    needed for a Transformer-based framework for translating one language into
    another. An example of such a framework is presented in the notes for my Week
    14 lecture at Purdue's Deep Learning class. That framework has two decoder
    implementations: MasterDecoder and MasterDecoderWithMasking.  If you are
    engaged in autoregressive modeling, you have no choice but to use the
    "WithMasking" version of the decoder.  As to the reason for the "Master"
    prefix in the name of the decoder, a language modeling code typically requires
    a number of Transformer layers, with each layer using multiple Attention Heads
    to calculate what's known as Self Attention. In my DLStudio code, I have
    refers to this layered organization of the Transformers as MasterEncoder and
    MasterDecoder, and to each Transformer layer as the BasicEncoder and the
    BasicDecoder.  Note that there's an interesting difference between the decoder
    logic as used in language translation and what you need for unsupervised
    learning in a GPT: When used for language translation, the decoder would also
    calculate Cross Attention, which is the attention between each element of the
    data coursing through the decoder and all the elements at the final output of
    the encoder.  The decoder as used for unsupervised learning in a GPT only
    needs to calculate Self Attention.
 
 
    PromptResponder:
 
    About the final component of babyGPT, PromptResponder, its purpose is to put the
    trained babyGPT model to use by having it respond appropriately to the prompts
    supplied by a user.  Given a prompt in the form of a sentence fragment, the
    PromptResponder uses its next-token prediction ability to keep on generating the
    tokens until it reaches the end-of-sentence token or until it has generated a
    specified number of sentences through this process.
 
 
DEALING WITH THE PROBLEM OF CONTEXT DISRUPTION CAUSED BY THE "<SOS>" TOKEN:
 
    What comes in the way of training babyGPT are the textual discontinuities created
    by how a batch is constructed for each new iteration of training.  As explained
    elsewhere in this doc page, the list of all the documents in the training corpus
    is first randomized and then divided into a number of token streams, with one
    stream for each batch instance. (This randomization of the files and the
    division into token streams is carried out afresh at the beginning of each
    epoch.)  Subsequently, when a fresh batch is needed, for each batch instance you
    "draw" from its corresponding stream a max_seq_length number of tokens. The
    special <SOS> token is placed at the beginning of each such token stream segment
    and another special token <EOS> at the end.
 
    This insertion of the <SOS> and <EOS> tokens disrupts the continuity of the token
    streams as you imagine --- which runs contrary to the main point of the exercise
    which is to learn the continuity properties. Since the narrative continuity
    properties are context dependent, it would be fair to say that the <SOS> token
    causes a context disruption for the token that comes after <SOS> at the beginning
    of each batch instance.  Over the years, various strategies have been proposed to
    circumvent this problem, one of the most recent being the "sliding-window based
    Attention" as presented by Beltagy, Peters, and Cohan in their 2023 paper
    "Longformer: The Long-Document Transformer".  In this approach, a fixed-sized
    window is used to calculate the attention at the token that is at the center of
    the window.  In this manner, what is calculated for Self Attention is the extent
    to which each token attends to the W/2 tokens on each side of the token at the
    center.  As the authors say: "Using multiple stacked layers of such windowed
    attention results in a large receptive field, where top layers have access to all
    input locations and have the capacity to build representations that incorporate
    information across the entire input."
 
    In keeping with the spirit of babyGPT, I have used a much simpler approach to
    deal with the context-disruption problem created by the <SOS> token.  My
    solution is based on the idea I call "Context Buffer".  In the token input
    stream that corresponds to each batch instance, a context buffer is the last n
    tokens that are meant to serve as the context for the first real token in the
    same instance in the next batch.  
 
    To elaborate, let's assume that N is the size of the Context Window for your
    Transformer based processing of text.  N is the maximum length of the input token
    sequence for which you have designed your Transformer implementation.  [This also
    means that your Attention Map will be an array of size NxN.] And let n be the
    smallest number of previous tokens that you think will provide a reasonable
    context for predicting the current token.  So, during each training iteration,
    from each batch instance at the input, we want to save the last n tokens to serve
    as the context buffer for the new token sequence in the same batch instance.
    Therefore, at the next iteration you will feed n+N tokens into the transformer,
    but, as you can imagine, at the output of the transformer, you would only retain
    the N tokens that come after the context-buffer n tokens.
 
    It is this idea of context buffer that is invoked by the code in the second
    script mentioned at the beginning of this section.
 
    It is interesting to note that the above mentioned problem with context
    disruption does NOT arise with sentence-based language modeling (as in BERT)
    since <SOS> is what you would want to use for designating the start of the
    sentence.  (For such learning, you would also use another token, denoted <EOS>
    for "End of Sequence", to mark the end of a sentence.)
 
 
INSTALLATION:
 
    The babyGPT class was packaged using setuptools.  For installation, execute
    the following command in the source directory (this is the directory that
    contains the setup.py file after you have downloaded and uncompressed the
    gzipped tar archive for the module):
 
            sudo python3 setup.py install
 
    On Linux distributions, this will install the module file at a location that
    looks like
 
             /usr/local/lib/python3.10/dist-packages/
 
    If you do not have root access, you have the option of working directly off
    the directory in which you downloaded the software by simply placing the
    following statements at the top of your scripts that use the
    babyGPT class:
 
            import sys
            sys.path.append( "pathname_to_babyGPT_directory" )
 
    To uninstall the module, simply delete the source directory, locate where the
    babyGPT module was installed with "locate
    babyGPT" and delete those files.  As mentioned above, the full
    pathname to the installed version is likely to look like
    /usr/local/lib/python2.7/dist-packages/babyGPT*
 
    If you want to carry out a non-standard install of the babyGPT
    module, look up the on-line information on Disutils by pointing your browser
    to
 
              http://docs.python.org/dist/dist.html
 
USAGE:
 
    If you want to use babyGPT for unsupervised learning of a base model for a text
    corpus, you would need to construct an instance of the main babyGPT class and its
    supporting classes as follows:
 
    baby_gpt = babyGPT(
                        max_seq_length = max_seq_length,
                        batch_size = batch_size,
                        embedding_size = embedding_size,
                        num_basic_decoders = num_basic_decoders,
                        num_atten_heads = num_atten_heads,
                        optimizer_params = optimizer_params,
                        num_warmup_steps = num_warmup_steps,
                        masking = masking,
                        verify_text_corpus = False,
                        path_saved_model = {"decoder" : "./saved_decoder",                                                             
                                            "embedding_generator" : "./saved_embedding_generator",                             
                                           },
                      )
    
    xformer = baby_gpt.TransformerFG( 
                        max_seq_length = max_seq_length,
                        embedding_size = embedding_size,
                        tokenizer_json = tokenizer_json,
                        num_warmup_steps = num_warmup_steps,
                        optimizer_params = optimizer_params,
              )
    
    master_decoder = baby_gpt.MasterDecoderWithMasking(
                        xformer, 
                        num_basic_decoders = num_basic_decoders,
                        num_atten_heads = num_atten_heads,
                        masking = masking
                     )
    
    dataloader = baby_gpt.ArticleDatasetWithBufferedContext(
                        gpt = baby_gpt,
                        tokenizer_json = tokenizer_json,
                        context_window_size = context_window_size,
                        context_buffer_size = context_buffer_size,
                        articles_dir = articles_dir,
                 )
 
    
 
THE Examples DIRECTORY:
 
    This directory contains the following four scripts for working with babyGPT:
 
        1.  run_gatherer.py
 
            This script is for collecting a corpus for experimenting with babyGPT.
            The script requires a list of URLs as article sources as illustrated
            by the following example:
 
                urls = ['https://finance.yahoo.com','http://cnn.com',
                        'https://sports.yahoo.com',
                        'https://purdueexponent.org','https://slate.com',
                        'https://timesofindia.indiatimes.com',
                        'http://cnn.com',
                        'https://slate.com'
                       ]
 
        2.  train_tokenizer.py
 
            If the text corpus you have collected is for a specialized domain (such
            as movies, sports, healthcare, etc.), you are likely to get better
            results from babyGPT if you first train a new tokenizer for that domain.
            You train a new tokenizer merely by invoking this script after you have
            set its variable "articles_dir" so that it points to the corpus 
            directory.
 
 
        3.  create_base_model_with_buffered_context.py
 
            This is the script to run if you want to create a Base Model for your
            corpus.  By Base Model I mean a language model acquired through
            unsupervised learning from a training corpus.  Since this script calls on
            the core language modeling functionality of babyGPT, you have to set a
            relatively large number of parameters in the script.  These parameters
            are shown below:
 
                articles_dir
                tokenizer_json 
                max_seq_length 
                context_window_size
                context_buffer_size
                batch_size 
                embedding_size
                num_atten_heads 
                num_basic_decoders 
                optimizer_params
                num_warmup_steps
 
 
        4.  interact_with_prompts.py
 
            This is the script for interacting with a trained babyGPT model through
            prompts.  The idea is that you supply a small number of words (as, say,
            the beginning of a new thought) as a prompt and the model supplies the
            rest of the words to complete the thought.  At this time, the model
            extends your prompt until it reaches a period (or the end dictated by the
            size of the "max_seq_length" parameter.
 
BUGS:
 
    Please notify the author if you encounter any bugs.  When sending email,
    please place the string 'babyGPT' in the subject line to get past the
    author's spam filter.
 
 
ABOUT THE AUTHOR:
 
    The author, Avinash Kak, is a professor of Electrical and Computer Engineering
    at Purdue University.  For all issues related to this module, contact the
    author at kak@purdue.edu If you send email, please place the string
    "babyGPT" in your subject line to get past the author's spam
    filter.
 
COPYRIGHT:
 
    Python Software Foundation License
 
    Copyright 2025 Avinash Kak
 
@endofdocs

 
Imported Modules
       
matplotlib.animation
blingfire
glob
itertools
json
logging
math
newspaper
torch.nn
numpy
torch.optim
os
matplotlib.pyplot
random
re
signal
string
sys
time
torch
torchvision
torchvision.transforms

 
Classes
       
builtins.object
babyGPT

 
class babyGPT(builtins.object)
    babyGPT(*args, **kwargs)
 

 
  Methods defined here:
__init__(self, *args, **kwargs)
Initialize self.  See help(type(self)) for accurate signature.
run_code_with_buffered_context_for_training_TransformerFG(self, xformer, master_decoder, dataloader, checkpoint_frequency=1000, display_train_loss=False)
Drawn from the training routines in the Transformer module of DLStudio
save_checkpoint_decoder(self, decoder, dir_name, iter_index)
Save the decoder checkpoint
save_checkpoint_embedding_generator(self, embedding_generator, dir_name, iter_index)
save checkpoint for the embedding_generator
save_decoder(self, decoder)
Save the trained decoder to a disk file
save_embedding_generator(self, embedding_generator)

Data descriptors defined here:
__dict__
dictionary for instance variables (if defined)
__weakref__
list of weak references to the object (if defined)

Data and other attributes defined here:
ArticleDatasetWithBufferedContext = <class 'babyGPT.babyGPT.ArticleDatasetWithBufferedContext'>
The parameter 'context_window_size' is related to how many tokens you can feed into the
transformer at one iteration as the training corpus is being scanned.  In my Week 14 lecture
on Transformers, I used the notation 'max_seq_len' for this parameter.
ArticleGatherer = <class 'babyGPT.babyGPT.ArticleGatherer'>
This script is for collecting data for experimenting with the Transformer based
unsupervised learning code in baby_gpt.py.  
 
The articles are downloaded from the URLs that are specified by the argument 'urls' in the
constructor shown below.  See the script "create_base_model.py" in the Examples directory
for how to set the URL strings for this argument.  Here are some examples:
 
    urls = ['https://finance.yahoo.com','http://cnn.com',
             'https://timesofindia.indiatimes.com',
             'https://purdueexponent.org','https://slate.com', 
             'https://sports.yahoo.com']
 
    urls = ['http://cnn.com']
 
    urls = ['https://slate.com']
 
    urls = ['https://timesofindia.indiatimes.com']
AttentionHead = <class 'babyGPT.babyGPT.AttentionHead'>
Borrowed from the Transformers module of DLStudio
BasicDecoderWithMasking = <class 'babyGPT.babyGPT.BasicDecoderWithMasking'>
Borrowed from the Transformers module of DLStudio
EmbeddingGenerator = <class 'babyGPT.babyGPT.EmbeddingGenerator'>
MasterDecoderWithMasking = <class 'babyGPT.babyGPT.MasterDecoderWithMasking'>
Borrowed from the Transformers module of DLStudio
PromptResponder = <class 'babyGPT.babyGPT.PromptResponder'>
Prompting a trained babyGPT models means that you supply a small number of words (as, say, the 
beginning of a new thought) as a prompt and the model supplies the rest of the words to complete 
the thought.  The class comes with two methods, the first for extending your prompt until it 
reaches a period, and the second for going beyond the first period encountered.
 
Any interaction with a trained GPT model has to deal with the following issue:  What to do with
the context buffer that is meant to be a continuation of the last part of the previous "sentence"
fed into the transformer. 
 
Ideally, we should be placing in the context buffer words that create a context for the prompt.
But there is no easy way to that without a more elaborate model. An example of more elaborate
modeling would be to have the input to the transformer consist of, say, an SOS token, a special
context token consisting possibly of integer index values beyond the tokenizer vocab, followed
by a context buffer that would be the last part of the previous sentence, followed, finally, 
by the new input tokens.
 
babyGPT gives you two options regarding what to do with the context buffer for your prompt:
 
        --  all_zeros
 
        --  get_from_prompt
 
With the first option, all of the integer encoding values in the context buffer are set to
the integer zero.  And, with the second option, at this time, the context buffer contains
a portion or all of the prompt itself.  If the tokenized version of the prompt is shorter
than the size of the context buffer, only the context_buffer_size number of elements of the
prompt are retained for the context buffer.  In the opposite case, just the initial 
context_buffer_size number of elements of the prompt are retained.
ScheduledOptim = <class 'babyGPT.babyGPT.ScheduledOptim'>
As in the Transformers module of DLStudio, for the scheduling of the learning rate
during the warm-up phase of training TransformerFG, I have borrowed the class shown below
from the GitHub code made available by Yu-Hsiang Huang at:
 
        https://github.com/jadore801120/attention-is-all-you-need-pytorch
SelfAttention = <class 'babyGPT.babyGPT.SelfAttention'>
Borrowed from the Transformers module of DLStudio
TrainTokenizer = <class 'babyGPT.babyGPT.TrainTokenizer'>
Tokenizers play a critical role in language modeling because they create a
fixed-sized vocabulary for the corpus you are working with --- regardless of
the size of the corpus itself.  Unless your text corpus is based on a set of
documents frozen in time, ordinarily, as the size of a text corpus goes up,
so does the size of the vocabulary --- despite the illusion to the contrary
created by the fixed sizes of the language dictionaries you have seen all
your life.  How we express ourselves is a living thing.  We are constantly
inventing new words and new expressions; these form important components of
what's referred to as the zeitgeist.
 
Having a fixed-sized vocab is important because the loss functions used in
deep-learning network used for language processing are based on
maximum-likelihood prediction of the next token given the tokens seen
previously.  That requires estimating the probabilities associated with all
possible tokens at the next position.  As you can imagine, it would be
impossible to engage in such probabilistic reasoning if you did not know in
advance the size of the vocabulary.
TransformerFG = <class 'babyGPT.babyGPT.TransformerFG'>
I have borrowed from the DLStudio's Transformers module.  "FG" stands for "First Generation" --- which is the Transformer
as originally proposed by Vaswani et al.

 
Functions
       
ctrl_c_handler(signum, frame)
dev()
gen(container)

 
p
        __author__ = 'Avinash Kak (kak@purdue.edu)'
__copyright__ = '(C) 2025 Avinash Kak. Python Software Foundation.'
__date__ = '2025-April-22'
__url__ = 'https://engineering.purdue.edu/kak/distBabyGPT/babyGPT-1.0.6.html'
__version__ = '1.0.6'
 
Author
        Avinash Kak (kak@purdue.edu)