Offloading an LLM with Custom Backpropagation

What is Model Offloading?

Model offloading, as described in DeepSpeed's ZeRO-Offload and ZeRO-Infinity papers, is the technique of moving parts of a model's parameters, gradients, and optimizer states from GPU memory to CPU memory or NVMe storage during training. This makes it possible to train models that are larger than GPU memory by using the typically larger capacity of CPU RAM or fast storage devices.

What is NanoGPT?

NanoGPT was developed by Andrej Karpathy, former Director of AI at Tesla.

“The simplest, fastest repository for training/finetuning medium-sized GPTs.”

— Andrej Karpathy

It consists of essentially two files, a training loop (train.py) and a GPT model definition (model.py), each roughly 300 lines of code and all written from scratch. That makes it a great way to learn the ins and outs of a large language model.

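Converting NanoGPT to support offloading

The code below splits a 4-layer NanoGPT into four stages, one transformer block each. Each stage is moved to the GPU only when its forward or backward pass needs it, and is moved back to the CPU to free GPU memory.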

# %%
from torch import nn
import torch.nn.functional as F
import torch
from torch.autograd import Function
import numpy as np
import time, os

from nanoGPT.model import GPT, Block, GPTConfig


# Directory holding the tokenized dataset files train.bin and val.bin.
# Can be overridden with the NANOGPT_DATA_DIR environment variable; the default
# is the author's local path.
data_dir = os.environ.get(
    'NANOGPT_DATA_DIR',
    '/media/j/wdata/git/PYTHON_IMPORT/nanoGPT/data/all_news_char',
)
# Memory-map both splits read-only as flat uint16 arrays instead of reading them
# into RAM; get_batch slices random windows out of them.
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split, block_size, batch_size, device_type, device):
    """Sample a random batch of input windows and their next-token targets.

    Args:
        split: 'train' reads from ``train_data``; any other value reads from
            ``val_data``.
        block_size: length of every sampled window.
        batch_size: number of windows to sample.
        device_type: when equal to 'cuda', the CPU tensors are pinned and
            copied to ``device`` with ``non_blocking=True``.
        device: destination device of both returned tensors.

    Returns:
        ``(x, y)`` int64 tensors of shape ``(batch_size, block_size)``, where
        ``y[b, t]`` is the token that follows ``x[b, t]`` in the data stream.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack(
        [torch.from_numpy(source[s:s + block_size].astype(np.int64)) for s in starts]
    )
    shifted = torch.stack(
        [torch.from_numpy(source[s + 1:s + 1 + block_size].astype(np.int64)) for s in starts]
    )
    if device_type != 'cuda':
        return inputs.to(device), shifted.to(device)
    # Pinned host memory is what allows the host-to-device copy to be asynchronous.
    inputs = inputs.pin_memory().to(device, non_blocking=True)
    shifted = shifted.pin_memory().to(device, non_blocking=True)
    return inputs, shifted



# %%

def _gpu_mib():
    """Return the CUDA memory currently allocated by tensors, in MiB."""
    return torch.cuda.memory_allocated() / 1024**2


def _make_stage(gptconf, block_index, *, is_first, is_last):
    """Build a GPT and set to None every submodule this pipeline stage does not run.

    Every stage keeps exactly one transformer block, ``h[block_index]``. The
    first stage also keeps ``wte``, ``wpe`` and ``drop``; the last stage also
    keeps ``ln_f`` and ``lm_head``.
    NOTE(review): relies on the project's GPT.forward skipping None submodules;
    confirm against nanoGPT/model.py.
    """
    stage = GPT(gptconf)
    if not is_first:
        stage.transformer.wte = None
        stage.transformer.wpe = None
        stage.transformer.drop = None
    for j in range(len(stage.transformer.h)):
        if j != block_index:
            stage.transformer.h[j] = None
    if not is_last:
        stage.transformer.ln_f = None
        stage.lm_head = None
    return stage


if __name__ == '__main__':
    # Train a GPT split into one stage per transformer block. A stage's weights
    # are on the GPU only while that stage runs its forward or backward pass;
    # the optimizer step runs while every stage is back on the CPU.
    device = 'cuda'
    model_args = dict(n_layer=4, n_head=64, n_embd=768, block_size=4096, bias=True, vocab_size=128, dropout=0.1)
    gptconf = GPTConfig(**model_args)

    n_stages = model_args['n_layer']
    last = n_stages - 1
    stages = [
        _make_stage(gptconf, k, is_first=(k == 0), is_last=(k == last))
        for k in range(n_stages)
    ]

    # Optimizer hyper-parameters
    weight_decay = 1e-1
    learning_rate = 9e-3
    beta1 = 0.9
    beta2 = 0.99

    params = [p for stage in stages for p in stage.parameters()]
    # Tensors with >= 2 dims are weight-decayed; 1-D tensors are not.
    decay_params = [p for p in params if p.dim() >= 2]
    nodecay_params = [p for p in params if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # Every stage has been moved back to the CPU when optimizer.step() runs, so
    # the AdamW update operates on CPU parameters and gradients.
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(beta1, beta2))

    start_time = time.time()

    for step in range(50):
        x, targets = get_batch('train', 128, 50, device, device)
        print(f'x, targets {_gpu_mib():0.2f}')

        # ---- Forward, one stage on the GPU at a time ----
        # Each stage after the first gets a *detached* copy of the previous
        # output that requires grad. This cuts the autograd graph at stage
        # boundaries, so backward can also run stage by stage, and the
        # boundary tensor's .grad is the upstream gradient for the previous
        # stage.
        stage_inputs = []
        stage_outputs = []
        hidden = x
        for k, stage in enumerate(stages):
            stage.to(device)
            print(f"fwd stage {k + 1} .to('{device}') {_gpu_mib():0.2f}")
            if k > 0:
                hidden = hidden.detach().requires_grad_(True)
            stage_inputs.append(hidden)
            if k == last:
                out, loss = stage(hidden, targets)
            else:
                out, _ = stage(hidden)
            stage_outputs.append(out)
            # NOTE: offloading here relies on Module.to() swapping param.data in
            # place (PyTorch's default conversion), so the parameters autograd
            # saved are on the GPU again once the stage is moved back before
            # its backward pass.
            stage.to('cpu')
            torch.cuda.empty_cache()
            print(f"fwd stage {k + 1} .to('cpu') {_gpu_mib():0.2f}")
            hidden = out

        optimizer.zero_grad()

        # ---- Backward, last stage to first, one stage on the GPU at a time ----
        # Tensor.backward() accumulates into .grad of every parameter in the
        # stage (torch.autograd.grad would not) and of the stage's detached input.
        for k in range(last, -1, -1):
            stages[k].to(device)
            print(f"back stage {k + 1} .to('{device}') {_gpu_mib():0.2f}")
            if k == last:
                loss.backward()
            else:
                stage_outputs[k].backward(stage_inputs[k + 1].grad)
            stages[k].to('cpu')
            torch.cuda.empty_cache()
            print(f"back stage {k + 1} .to('cpu') {_gpu_mib():0.2f}")

        # Update the parameters (all stages are on the CPU at this point).
        optimizer.step()

        print(f'step {step}: loss {loss.item():.4f}')

    print(f'Time taken: {time.time() - start_time} seconds')
# %%

Full source code: https://github.com/jljacoblo/jacAI/commit/8aa6075