In this blog post, we will dive into the process of pretraining a Large Language Model (LLM) using unlabeled data. We will implement the training loop and basic model evaluation code necessary to train our model from scratch. Additionally, we’ll explore how to leverage openly available pretrained weights from OpenAI to enhance our model’s performance. So let’s get started!
To begin, it’s essential to ensure that we have the right packages installed and up-to-date. Here’s a quick check of the versions of the key libraries we will use:
from importlib.metadata import version
pkgs = ["matplotlib",
"numpy",
"tiktoken",
"torch",
"tensorflow" # For OpenAI's pretrained weights
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
OUTPUT
matplotlib version: 3.9.0
numpy version: 1.26.4
tiktoken version: 0.7.0
torch version: 2.4.0
tensorflow version: 2.16.1
These versions ensure compatibility with our code and facilitate smooth execution of the training and evaluation processes.
1 Evaluating Generative Text Models
In this section, we’ll begin by revisiting how to initialize a GPT model, using the code we covered in the previous blog. We’ll then explore basic evaluation metrics for Large Language Models (LLMs) and apply these metrics to both training and validation datasets.
1.1 Using GPT to Generate Text
Let’s start by initializing a GPT model using the configuration from the previous blog:
import torch
from previous_blogs import GPTModel
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 256, # Shortened context length (orig: 1024)
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
}
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval() # Disable dropout during inference
We use a dropout of 0.1 here, though it’s increasingly common to train LLMs without dropout. Additionally, modern LLMs typically do not use bias vectors in the nn.Linear layers for the query, key, and value matrices, achieved by setting "qkv_bias": False.
To reduce computational resource requirements, we’ve set the context length to 256 tokens, compared to the original 124 million parameter GPT-2 model, which used 1024 tokens. This setup makes it easier for readers to execute the code on a standard laptop. However, you can increase the context_length to 1024 tokens without changing any other code.
Next, let’s generate text using the generate_text_simple function from the previous blog. We also define two utility functions, text_to_token_ids and token_ids_to_text, to convert between text and token representations:
import tiktoken
from previous_blogs import generate_text_simple
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
start_context = "Every effort moves you"
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = generate_text_simple(
model=model,
idx=text_to_token_ids(start_context, tokenizer),
max_new_tokens=10,
context_size=GPT_CONFIG_124M["context_length"]
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
OUTPUT
Every effort moves you rentingetic wasnم refres RexMeCHicular stren
As seen above, the model does not produce coherent text since it hasn’t been trained yet. To evaluate and track the training progress, we need to measure how "good" the generated text is in numerical terms. The next subsection introduces a way to compute a numerical loss for the generated outputs.
1.2 Calculating the Text Generation Loss: Cross-Entropy and Perplexity
Consider the following tensors, representing token IDs for two training examples:
inputs = torch.tensor([[16833, 3626, 6100], # ["every effort moves",
[40, 1107, 588]]) # "I really like"]
targets = torch.tensor([[3626, 6100, 345 ], # [" effort moves you",
[1107, 588, 11311]]) # " really like chocolate"]
Feeding the inputs into the model, we get the logits vector for these input examples:
with torch.no_grad():
    logits = model(inputs)
probas = torch.softmax(logits, dim=-1) # Probability of each token in vocabulary
print(probas.shape) # Shape: (batch_size, num_tokens, vocab_size)
OUTPUT
torch.Size([2, 3, 50257])
As discussed earlier, applying the argmax function converts the probability scores into predicted token IDs:
token_ids = torch.argmax(probas, dim=-1, keepdim=True)
print("Token IDs:\n", token_ids)
OUTPUT
tensor([[[16657],
[ 339],
[42826]],
[[49906],
[29669],
[41751]]])
Comparing these predictions to the target tokens shows a significant mismatch, as the model hasn’t been trained yet. To train the model, we must calculate how far off it is from the correct predictions.
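Concretely, we first extract the probabilities the model assigns to the correct target tokens. The snippet below uses the probas and targets tensors defined above; the indexing picks out one probability per token position:
text_idx = 0
target_probas_1 = probas[text_idx, [0, 1, 2], targets[text_idx]]
print("Text 1:", target_probas_1)

text_idx = 1
target_probas_2 = probas[text_idx, [0, 1, 2], targets[text_idx]]
print("Text 2:", target_probas_2)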
Next, we compute the average log probability:
log_probas = torch.log(torch.cat((target_probas_1, target_probas_2)))
avg_log_probas = torch.mean(log_probas)
print(avg_log_probas)
OUTPUT
tensor(-10.7940)
In deep learning, instead of maximizing the average log-probability, it’s standard to minimize the negative average log-probability value. This value is also called cross-entropy loss:
neg_avg_log_probas = avg_log_probas * -1
print(neg_avg_log_probas)
OUTPUT
tensor(10.7940)
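In practice, we don’t carry out these steps manually. PyTorch’s torch.nn.functional.cross_entropy function applies the softmax, logarithm, and negative averaging internally, operating directly on the logits and target token IDs; it is the same function we use in calc_loss_batch later. A quick sketch using the tensors defined above:
logits_flat = logits.flatten(0, 1)  # shape: (6, 50257)
targets_flat = targets.flatten()    # shape: (6,)
loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
print(loss)  # same value as neg_avg_log_probas above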
A related concept is perplexity, which is simply the exponential of the cross-entropy loss:
perplexity = torch.exp(loss)
print(perplexity)
OUTPUT
tensor(48725.8203)
A lower perplexity indicates that the model predictions are closer to the actual distribution, making perplexity a useful metric for evaluating model quality.
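Intuitively, perplexity can be read as the effective vocabulary size the model is undecided over at each step. For our untrained model it is close to the full vocabulary size, i.e., barely better than guessing uniformly at random:
# ~48,726 vs. a vocabulary of 50,257 tokens: the untrained model is nearly
# as uncertain as picking the next token uniformly at random
print(f"Perplexity: {perplexity.item():.0f}, vocabulary size: {GPT_CONFIG_124M['vocab_size']}")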
1.3 Calculating the Training and Validation Set Losses
We will use a small dataset to train the LLM, specifically a short public domain text. This allows you to run the examples quickly on a standard laptop without the need for extensive computational resources.
First, let’s load the dataset:
import os
import urllib.request
file_path = "the-verdict.txt"
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"
if not os.path.exists(file_path):
    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
else:
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()
We check that the text loaded correctly by printing the first and last 100 characters:
print(text_data[:99])
print(text_data[-99:])
OUTPUT
I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no
it for me! The Strouds stand alone, and happen once--but there's no exterminating our kind of art."
Next, let’s divide the dataset into training and validation sets, then use data loaders to prepare batches for LLM training:
from previous_blogs import create_dataloader_v1
train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]
torch.manual_seed(123)
train_loader = create_dataloader_v1(
train_data,
batch_size=2,
max_length=GPT_CONFIG_124M["context_length"],
stride=GPT_CONFIG_124M["context_length"],
drop_last=True,
shuffle=True,
num_workers=0
)
val_loader = create_dataloader_v1(
val_data,
batch_size=2,
max_length=GPT_CONFIG_124M["context_length"],
stride=GPT_CONFIG_124M["context_length"],
drop_last=False,
shuffle=False,
num_workers=0
)
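As an optional sanity check, we can iterate over both loaders and print the batch shapes to confirm they match the configured context length (a small sketch using the loaders defined above):
print("Train loader:")
for x, y in train_loader:
    print(x.shape, y.shape)  # each training batch should be torch.Size([2, 256])

print("\nValidation loader:")
for x, y in val_loader:
    print(x.shape, y.shape)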
Next, we implement utility functions to calculate the cross-entropy loss for a given batch and for a specified number of batches in a data loader:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
If you have a CUDA-supported GPU, the LLM will train on the GPU without requiring any code changes. Here’s how to calculate the training and validation losses:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) # no assignment model = model.to(device) needed
train_loss = calc_loss_loader(train_loader, model, device)
val_loss = calc_loss_loader(val_loader, model, device)
print(f"Train loss: {train_loss:.2f}")
print(f"Val loss: {val_loss:.2f}")
OUTPUT
Train loss: 10.79
Val loss: 10.79
The values are initially high since the model hasn’t been trained yet. After training, these loss values should decrease significantly.
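These initial values are easy to sanity-check: a model that spreads its probability uniformly over the 50,257-token vocabulary has a cross-entropy of ln(50257), which is roughly what the untrained model achieves:
import math
# Cross-entropy of a uniform distribution over the 50,257-token vocabulary
print(math.log(50257))  # ≈ 10.82, close to the ~10.79 losses above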
2 Training an LLM
In this section, we finally implement the code for training the LLM. We’ll focus on a simple training function.
2.1 Simple Training Function
Let’s start by defining a simple training function:
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen
This function is simple yet effective for educational purposes. It tracks the training and validation losses, evaluates the model at regular intervals, and generates a sample text after each epoch.
2.2 Evaluation and Sample Generation
Next, let’s define the helper functions to evaluate the model and generate sample text:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()
These functions support model evaluation during training and help visualize progress by generating text samples after each epoch.
2.3 Training the Model
Now, let’s train the LLM using the training function defined above:
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
num_epochs = 10
train_losses, val_losses, tokens_seen = train_model_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=num_epochs, eval_freq=5, eval_iter=5,
start_context="Every effort moves you", tokenizer=tokenizer
)
OUTPUT
Ep 1 (Step 000000): Train loss 9.781, Val loss 9.933
Ep 1 (Step 000005): Train loss 8.111, Val loss 8.339
Every effort moves you,,,,,,,,,,,,.
Ep 2 (Step 000010): Train loss 6.661, Val loss 7.048
Ep 2 (Step 000015): Train loss 5.961, Val loss 6.616
Every effort moves you, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and,, and, and,
Ep 3 (Step 000020): Train loss 5.726, Val loss 6.600
Ep 3 (Step 000025): Train loss 5.201, Val loss 6.348
Every effort moves you, and I had been.
Ep 4 (Step 000030): Train loss 4.417, Val loss 6.278
Ep 4 (Step 000035): Train loss 4.069, Val loss 6.226
Every effort moves you know the "I he had the donkey and I had the and I had the donkey and down the room, I had
Ep 5 (Step 000040): Train loss 3.732, Val loss 6.160
Every effort moves you know it was not that the picture--I had the fact by the last I had been--his, and in the "Oh, and he said, and down the room, and in
Ep 6 (Step 000045): Train loss 2.850, Val loss 6.179
Ep 6 (Step 000050): Train loss 2.427, Val loss 6.141
Every effort moves you know," was one of the picture. The--I had a little of a little: "Yes, and in fact, and in the picture was, and I had been at my elbow and as his pictures, and down the room, I had
Ep 7 (Step 000055): Train loss 2.104, Val loss 6.134
Ep 7 (Step 000060): Train loss 1.882, Val loss 6.233
Every effort moves you know," was one of the picture for nothing--I told Mrs. "I was no--as! The women had been, in the moment--as Jack himself, as once one had been the donkey, and were, and in his
Ep 8 (Step 000065): Train loss 1.320, Val loss 6.238
Ep 8 (Step 000070): Train loss 0.985, Val loss 6.242
Every effort moves you know," was one of the axioms he had been the tips of a self-confident moustache, I felt to see a smile behind his close grayish beard--as if he had the donkey. "strongest," as his
Ep 9 (Step 000075): Train loss 0.717, Val loss 6.293
Ep 9 (Step 000080): Train loss 0.541, Val loss 6.393
Every effort moves you?" "Yes--quite insensible to the irony. She wanted him vindicated--and by me!" He laughed again, and threw back the window-curtains, I had the donkey. "There were days when I
Ep 10 (Step 000085): Train loss 0.391, Val loss 6.452
Every effort moves you know," was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed luncheon-table, when, on a later day, I had again run over from Monte Carlo; and Mrs. Gis
As the model trains, you can see the training and validation losses decrease. The generated text becomes more coherent, though overfitting is evident due to the small dataset and extensive training.
2.4 Visualizing Training Progress
Finally, let’s visualize the training and validation losses:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig("loss-plot.pdf")
    plt.show()
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
3 Decoding Strategies to Control Randomness
Inference with a relatively small LLM, like the GPT model we trained above, is computationally inexpensive. Even if you used a GPU for training, inference can be comfortably performed on a CPU. Using the generate_text_simple function from the previous blog, we can generate new text one word (or token) at a time.
As explained in Section 1.2, the next generated token is the one corresponding to the highest probability score among all tokens in the vocabulary. Let’s demonstrate this by moving our model to the CPU and generating text:
model.to("cpu")
model.eval()
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = generate_text_simple(
model=model,
idx=text_to_token_ids("Every effort moves you", tokenizer),
max_new_tokens=25,
context_size=GPT_CONFIG_124M["context_length"]
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
OUTPUT
Every effort moves you know," was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed lun
Even if we execute the generate_text_simple function multiple times, the LLM will always generate the same outputs. This is because the function is deterministic, always selecting the token with the highest probability. To introduce variability and control the randomness of the generated text, we can employ two decoding strategies: temperature scaling and top-k sampling.
3.1 Temperature Scaling
Previously, we always selected the token with the highest probability using torch.argmax. To add variety, we can sample the next token using torch.multinomial(probs, num_samples=1), which samples from the probability distribution provided by the softmax function. In this context, each index’s chance of being picked corresponds to its probability in the input tensor.
Let’s recap how to generate the next token, assuming a very small vocabulary for illustration purposes:
vocab = {
"closer": 0,
"every": 1,
"effort": 2,
"forward": 3,
"inches": 4,
"moves": 5,
"pizza": 6,
"toward": 7,
"you": 8,
}
inverse_vocab = {v: k for k, v in vocab.items()}
# Suppose the input is "every effort moves you", and the LLM
# returns the following logits for the next token:
next_token_logits = torch.tensor(
[4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
probas = torch.softmax(next_token_logits, dim=0)
next_token_id = torch.argmax(probas).item()
# The next generated token is:
print(inverse_vocab[next_token_id])
# Output: forward
torch.manual_seed(123)
next_token_id = torch.multinomial(probas, num_samples=1).item()
print(inverse_vocab[next_token_id])
# Output: forward
Instead of determining the most likely token via torch.argmax, we can use torch.multinomial(probas, num_samples=1) to determine the next token by sampling from the softmax distribution. Here’s how this approach works when we sample the next token 1,000 times using the original softmax probabilities:
def print_sampled_tokens(probas):
    torch.manual_seed(123)  # Manual seed for reproducibility
    sample = [torch.multinomial(probas, num_samples=1).item() for i in range(1_000)]
    sampled_ids = torch.bincount(torch.tensor(sample))
    for i, freq in enumerate(sampled_ids):
        print(f"{freq} x {inverse_vocab[i]}")
print_sampled_tokens(probas)
OUTPUT
73 x closer
0 x every
0 x effort
582 x forward
2 x inches
0 x moves
0 x pizza
343 x toward
The next token distribution favors "forward" as expected, but there’s still some variability. We can further control this distribution using temperature scaling.
Temperature scaling adjusts the logits by dividing them by a number greater than 0, called the temperature. Here’s what happens:
- Temperature > 1: The resulting probabilities are more uniformly distributed, leading to more random and diverse output.
- Temperature < 1: The resulting probabilities become more peaked, making the model more confident in its choices, reducing diversity.
Let’s see this in action:
def softmax_with_temperature(logits, temperature):
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=0)
# Temperature values
temperatures = [1, 0.1, 5] # Original, higher confidence, and lower confidence
# Calculate scaled probabilities
scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]
# Plotting
x = torch.arange(len(vocab))
bar_width = 0.15
fig, ax = plt.subplots(figsize=(5, 3))
for i, T in enumerate(temperatures):
    rects = ax.bar(x + i * bar_width, scaled_probas[i], bar_width, label=f'Temperature = {T}')
ax.set_ylabel('Probability')
ax.set_xticks(x)
ax.set_xticklabels(vocab.keys(), rotation=90)
ax.legend()
plt.tight_layout()
plt.savefig("temperature-plot.pdf")
plt.show()
A temperature of 0.1 produces a much more peaked distribution, so sampling almost always selects the most likely token:
print_sampled_tokens(scaled_probas[1])
OUTPUT
0 x closer
0 x every
0 x effort
985 x forward
0 x inches
0 x moves
0 x pizza
15 x toward
A temperature of 5 results in a more uniform distribution:
print_sampled_tokens(scaled_probas[2])
OUTPUT
165 x closer
75 x every
42 x effort
239 x forward
71 x inches
46 x moves
32 x pizza
227 x toward
103 x you
This approach can lead to nonsensical outputs, such as "every effort moves you pizza" 3.2% of the time (32 out of 1000 times). To balance diversity and coherence, we can combine temperature scaling with top-k sampling.
3.2 Top-k Sampling
Top-k sampling restricts the model to only sample from the top k most likely tokens, reducing the probability of nonsensical outputs while allowing for diverse text generation. Here’s how to implement it:
top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)
print("Top logits:", top_logits)
print("Top positions:", top_pos)
# Output:
# Top logits: tensor([6.7500, 6.2800, 4.5100])
# Top positions: tensor([3, 7, 0])
new_logits = torch.where(
condition=next_token_logits < top_logits[-1],
input=torch.tensor(float('-inf')),
other=next_token_logits
)
print(new_logits)
# Output:
# tensor([4.5100, -inf, -inf, 6.7500, -inf, -inf, -inf, 6.2800, -inf])
topk_probas = torch.softmax(new_logits, dim=0)
print(topk_probas)
# Output:
# tensor([0.0615, 0.0000, 0.0000, 0.5775, 0.0000, 0.0000, 0.0000, 0.3610, 0.0000])
This method ensures the model only considers the top 3 tokens, further controlling randomness.
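With the masked distribution in place, sampling works exactly as before via torch.multinomial, except that only the top-3 tokens can ever be drawn (a small sketch using the topk_probas computed above):
torch.manual_seed(123)
# Only "closer", "forward", or "toward" have nonzero probability here
next_token_id = torch.multinomial(topk_probas, num_samples=1).item()
print(inverse_vocab[next_token_id])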
3.3 Modifying the Text Generation Function
Let’s combine temperature scaling and top-k sampling to modify the generate_text_simple function we used to generate text earlier:
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
Let’s generate text with the modified function:
torch.manual_seed(123)
token_ids = generate(
model=model,
idx=text_to_token_ids("Every effort moves you", tokenizer),
max_new_tokens=15,
context_size=GPT_CONFIG_124M["context_length"],
top_k=25,
temperature=1.4
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
OUTPUT
Every effort moves you stand to work on surprise, a one of us had gone with random-
By tweaking the temperature and top-k parameters, you can effectively control the randomness and diversity of the generated text, balancing between coherent and diverse outputs.
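Note that with the default arguments (temperature=0.0 and top_k=None), the new generate function falls back to greedy decoding and behaves like generate_text_simple, so it can serve as a drop-in replacement:
# Greedy decoding: deterministic, identical output on every call
token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=15,
    context_size=GPT_CONFIG_124M["context_length"]
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))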
4 Loading and Saving Model Weights in PyTorch
Training Large Language Models (LLMs) is computationally intensive, making it essential to save and load model weights efficiently. In PyTorch, the recommended way to save the model weights is by using the torch.save function in conjunction with the .state_dict() method, which captures the model’s parameters.
Here’s how you can save the model weights:
torch.save(model.state_dict(), "model.pth")
To load the saved model weights into a new instance of the GPTModel, follow these steps:
model = GPTModel(GPT_CONFIG_124M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval()
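As a quick check, we can generate a short continuation with the reloaded weights; since greedy decoding is deterministic, the output should match what the trained model produced before saving (a sketch assuming the tokenizer and helper functions from earlier):
model.to(device)
token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
    max_new_tokens=25,
    context_size=GPT_CONFIG_124M["context_length"]
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))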
When training LLMs, it’s common to use adaptive optimizers like Adam or AdamW instead of standard SGD. These optimizers store additional parameters for each model weight, so it’s wise to save the optimizer state along with the model weights if you plan to resume training later:
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}, "model_and_optimizer.pth")
To load both the model and optimizer states:
checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train()
5 Loading Pretrained Weights from OpenAI
Pretraining LLMs from scratch can be prohibitively expensive, but fortunately, OpenAI provides pretrained weights for their models. This allows users to leverage the power of pretrained LLMs without the associated computational costs.
To load the pretrained weights provided by OpenAI, some boilerplate code is necessary. Since OpenAI originally used TensorFlow, you’ll need to install TensorFlow along with the tqdm progress bar library:
# pip install tensorflow tqdm
print("TensorFlow version:", version("tensorflow"))
print("tqdm version:", version("tqdm"))
To download the model weights for the 124 million parameter GPT-2 model:
from gpt_download import download_and_load_gpt2
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
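To get a sense of what was downloaded, we can inspect the two returned objects: settings holds the GPT-2 hyperparameter configuration, and params is a nested dictionary containing the weight tensors:
print("Settings:", settings)
print("Parameter dictionary keys:", params.keys())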
After downloading the weights, initialize a new GPTModel instance. Note that to correctly load the weights, the model configuration must match the original model, including setting the qkv_bias to True and using a 1024 token context length:
NEW_CONFIG = GPT_CONFIG_124M.copy()
NEW_CONFIG.update({"context_length": 1024, "qkv_bias": True})
gpt = GPTModel(NEW_CONFIG)
gpt.eval()
Next, map the OpenAI weights to the corresponding tensors in your GPTModel instance:
import numpy as np
def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))

def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split((params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split((params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(gpt.trf_blocks[b].att.out_proj.weight, params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(gpt.trf_blocks[b].att.out_proj.bias, params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(gpt.trf_blocks[b].ff.layers[0].weight, params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(gpt.trf_blocks[b].ff.layers[0].bias, params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(gpt.trf_blocks[b].ff.layers[2].weight, params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(gpt.trf_blocks[b].ff.layers[2].bias, params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(gpt.trf_blocks[b].norm1.scale, params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(gpt.trf_blocks[b].norm1.shift, params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(gpt.trf_blocks[b].norm2.scale, params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(gpt.trf_blocks[b].norm2.shift, params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
load_weights_into_gpt(gpt, params)
gpt.to(device)
Finally, you can generate new text using the loaded model:
torch.manual_seed(123)
token_ids = generate(
model=gpt,
idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
max_new_tokens=25,
context_size=NEW_CONFIG["context_length"],
top_k=50,
temperature=1.5
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
OUTPUT
Output text:
Every effort moves you toward finding an ideal new way to practice something!
What makes us want to be on top of that?
If the weights were loaded correctly, the model should generate coherent text, confirming the successful weight transfer.
That is it for today…