Offload LLM custom back propagation

Offload LLM custom back propagation

What is Model Offloading?

Model offloading, as described in DeepSpeed's ZeRO paper, refers to the technique of offloading parts of the model’s parameters, gradients, and optimizer states from GPU memory to CPU memory or NVMe storage during training. This allows for the training of models that are larger than the GPU memory, by leveraging the typically larger memory capacity available on the CPU or fast storage devices.

What is NanoGPT?

NanoGPT was developed by Andrej Karpathy, a former AI lead at Tesla.

The simplest, fastest repository for training/finetuning medium-sized GPTs.

Andrej Karpathy

It is basically only 2 files, less than 200 lines of code, and all of it is coded from scratch. It is great for learning the ins and outs of a large language model.

Convert NanoGPT to support offload

# %%
from torch import nn
import torch.nn.functional as F
import torch
from torch.autograd import Function
import numpy as np
import time, os

from nanoGPT.model import GPT, Block, GPTConfig

# Path to a pre-tokenized character-level dataset (nanoGPT prepare-script output);
# train.bin / val.bin are flat streams of uint16 token ids.
data_dir = '/media/j/wdata/git/PYTHON_IMPORT/nanoGPT/data/all_news_char'
# np.memmap maps the files lazily so the full token stream never has to fit in RAM.
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split, block_size, batch_size, device_type, device):
    """Sample a random batch of (input, target) token sequences.

    Args:
        split: 'train' selects train_data, anything else selects val_data.
        block_size: sequence length of each sample.
        batch_size: number of sequences in the batch.
        device_type: 'cuda' enables pinned-memory async transfer.
        device: destination device for the tensors.

    Returns:
        (x, y) int64 tensors of shape (batch_size, block_size); y is x
        shifted one token to the right (next-token prediction targets).
    """
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        # BUG FIX: the original had a garbled `x, y =,` here (syntax error);
        # restore the standard nanoGPT fallback for non-CUDA devices.
        x, y = x.to(device), y.to(device)
    return x, y

# %%

if __name__ == '__main__':
  device = 'cuda'
  # Tiny char-level GPT: 4 layers so the model can be split into 4 parts below.
  model_args = dict(n_layer=4, n_head=64, n_embd=768, block_size=4096, bias=True, vocab_size=128, dropout=0.1)
  gptconf = GPTConfig(**model_args)

  # Build four copies of the full model and null out everything a given part
  # does not own, yielding a crude 4-stage pipeline:
  #   part 1: wte/wpe/drop + block 0
  #   part 2: block 1 (+ a deliberately huge dummy parameter to inflate memory)
  #   part 3: block 2
  #   part 4: block 3 + ln_f + lm_head
  model_part_1 = GPT(gptconf)
  model_part_1.transformer.h[1] = None
  model_part_1.transformer.h[2] = None
  model_part_1.transformer.h[3] = None
  model_part_1.transformer.ln_f = None
  model_part_1.lm_head = None

  model_part_2 = GPT(gptconf)
  # Extra 500M-element parameter purely to make the offloading savings visible.
  model_part_2.transformer.rrr = torch.nn.Parameter(torch.randn(50000, 10000))
  model_part_2.transformer.wte = None
  model_part_2.transformer.wpe = None
  model_part_2.transformer.drop = None
  model_part_2.transformer.h[0] = None
  model_part_2.transformer.h[2] = None
  model_part_2.transformer.h[3] = None
  model_part_2.transformer.ln_f = None
  model_part_2.lm_head = None

  model_part_3 = GPT(gptconf)
  model_part_3.transformer.wte = None
  model_part_3.transformer.wpe = None
  model_part_3.transformer.drop = None
  model_part_3.transformer.h[0] = None
  model_part_3.transformer.h[1] = None
  model_part_3.transformer.h[3] = None
  model_part_3.transformer.ln_f = None
  model_part_3.lm_head = None

  model_part_4 = GPT(gptconf)
  model_part_4.transformer.wte = None
  model_part_4.transformer.wpe = None
  model_part_4.transformer.drop = None
  model_part_4.transformer.h[0] = None
  model_part_4.transformer.h[1] = None
  model_part_4.transformer.h[2] = None

  # Choose a loss function. NOTE(review): GPT.forward already returns the loss
  # when targets are supplied, so this object is unused below; kept for parity
  # with the original script.
  loss_fn = torch.nn.CrossEntropyLoss()

  # AdamW hyper-parameters.
  weight_decay = 1e-1
  learning_rate = 9e-3
  beta1 = 0.9
  beta2 = 0.99

  params = list(model_part_1.parameters()) + list(model_part_2.parameters()) + list(model_part_3.parameters()) + list(model_part_4.parameters())

  # nanoGPT convention: weight-decay only matrices (dim >= 2), never biases/norms.
  decay_params = [p for p in params if p.dim() >= 2]
  nodecay_params = [p for p in params if p.dim() < 2]
  # BUG FIX: this list literal was never closed in the original (syntax error).
  optim_groups = [
      {'params': decay_params, 'weight_decay': weight_decay},
      {'params': nodecay_params, 'weight_decay': 0.0}
  ]
  num_decay_params = sum(p.numel() for p in decay_params)
  num_nodecay_params = sum(p.numel() for p in nodecay_params)
  print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
  print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
  # Create AdamW optimizer and use the fused version if it is available
  optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(beta1, beta2))

  def log_mem(tag):
    # Current GPU allocation in MiB; lets us watch memory fall as parts offload.
    print(f"{tag} {torch.cuda.memory_allocated() / 1024**2:0.2f}")

  start_time = time.time()

  # Training loop. Only one model part lives on the GPU at a time: each part
  # is moved to CUDA just before it runs and back to the CPU right after.
  # NOTE(review): the original source was garbled here (stray 'cuda')/'cpu')
  # fragments); the .to() placement below is reconstructed from those
  # fragments — confirm against the original article. Moving parameters
  # between devices mid-graph relies on Module.to() behavior interacting
  # benignly with autograd's saved tensors; verify on your torch version.
  for i in range(50):
    x, targets = get_batch('train', 128, 50, device, device)
    log_mem('x, targets')

    # ---- forward, stage by stage ----
    model_part_1.to('cuda')
    log_mem('forward part 1 loaded')
    logits_1, _ = model_part_1(x)
    model_part_1.to('cpu')
    torch.cuda.empty_cache()
    log_mem('forward part 1 offloaded')

    model_part_2.to('cuda')
    log_mem('forward part 2 loaded')
    logits_2, _ = model_part_2(logits_1)
    model_part_2.to('cpu')
    torch.cuda.empty_cache()
    log_mem('forward part 2 offloaded')

    model_part_3.to('cuda')
    log_mem('forward part 3 loaded')
    logits_3, _ = model_part_3(logits_2)
    model_part_3.to('cpu')
    torch.cuda.empty_cache()
    log_mem('forward part 3 offloaded')

    model_part_4.to('cuda')
    log_mem('forward part 4 loaded')
    logits_4, loss = model_part_4(logits_3, targets)
    model_part_4.to('cpu')
    torch.cuda.empty_cache()
    log_mem('forward part 4 offloaded')

    # Zero the gradients (the comment existed in the original but the call
    # was missing; without it AdamW would step on stale gradients).
    optimizer.zero_grad(set_to_none=True)

    # ---- custom backward: chain the gradients one stage at a time ----
    model_part_4.to('cuda')
    log_mem('backward part 4 loaded')
    logits_4.grad = torch.autograd.grad(outputs=loss, inputs=logits_4, retain_graph=True, create_graph=False)[0]
    logits_3.grad = torch.autograd.grad(outputs=logits_4, inputs=logits_3, grad_outputs=logits_4.grad, retain_graph=True, create_graph=False)[0]
    model_part_4.to('cpu')
    torch.cuda.empty_cache()
    log_mem('backward part 4 offloaded')

    model_part_3.to('cuda')
    logits_2.grad = torch.autograd.grad(outputs=logits_3, inputs=logits_2, grad_outputs=logits_3.grad, retain_graph=True, create_graph=False)[0]
    model_part_3.to('cpu')
    torch.cuda.empty_cache()
    log_mem('backward part 3 offloaded')

    model_part_2.to('cuda')
    logits_1.grad = torch.autograd.grad(outputs=logits_2, inputs=logits_1, grad_outputs=logits_2.grad, retain_graph=True, create_graph=False)[0]
    model_part_2.to('cpu')
    torch.cuda.empty_cache()
    log_mem('backward part 2 offloaded')

    # Final stage: push the gradient into part 1's embedding weights.
    # NOTE(review): only wpe/wte receive parameter gradients in this demo; no
    # other parameters are ever passed to autograd.grad, so optimizer.step()
    # leaves them untouched. This matches the original code's behavior.
    model_part_1.to('cuda')
    model_part_1.transformer.wpe.weight.grad, model_part_1.transformer.wte.weight.grad = torch.autograd.grad(outputs=logits_1, inputs=[model_part_1.transformer.wpe.weight, model_part_1.transformer.wte.weight], grad_outputs=logits_1.grad, retain_graph=False, create_graph=False)
    model_part_1.to('cpu')
    torch.cuda.empty_cache()
    log_mem('backward part 1 offloaded')

    # Update the parameters (missing in the garbled original).
    optimizer.step()

  print(f'Time taken: {time.time() - start_time} seconds')