What is FairScale?
FairScale is an open-source library from Facebook that helps researchers and engineers optimize the training of large-scale neural networks.
I am mainly interested in its pipeline parallelism, model sharding, and checkpointing features.
The limits of using FairScale to offload
One of the limitations I found when using FairScale with the LLaMA language model is that LLaMA's pre-trained weights are laid out differently from the checkpoint that Hugging Face distributes, even though it is the same model.
The reason is that Hugging Face exposes the model through its own Transformers API, so to reuse the pre-trained weights you need to do some conversion.
In other words, models loaded through the Hugging Face Transformers API are not directly compatible with FairScale-based code such as the original LLaMA implementation.
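For illustration, here is a rough and incomplete sketch of how the parameter names differ between the two checkpoint formats. The mapping is my own reading of the two model definitions, not an official conversion table.

# Illustrative only: a handful of parameter names in the original LLaMA
# checkpoint (left) versus the Hugging Face Transformers checkpoint (right).
# This mapping is an assumption based on the two model definitions.
ORIGINAL_TO_HF = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
    "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
    "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight",
    "layers.0.attention_norm.weight": "model.layers.0.input_layernorm.weight",
}
# Renaming keys alone is not enough: the Hugging Face conversion also permutes
# the Q/K projection weights for its rotary-embedding layout, so those tensors
# have to be rearranged as well.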
Also, the offloading feature in FairScale currently only supports models expressed as a torch.nn.Sequential(). So if your layers are not already in an nn.Sequential, you need to do some tinkering.
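To make that constraint concrete, here is a minimal sketch of wrapping a plain nn.Sequential in OffloadModel. The toy layers and shapes are made up; the constructor arguments mirror the ones used in the LLaMA code below.

import torch
from fairscale.experimental.nn.offload import OffloadModel

# A toy stack of layers; OffloadModel expects an nn.Sequential it can cut into slices.
net = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 256),
)

offloaded = OffloadModel(
    model=net,                           # the nn.Sequential to shard
    device=torch.device("cuda"),         # device that runs the active slice
    offload_device=torch.device("cpu"),  # device where idle slices are parked
    num_slices=3,
    checkpoint_activation=False,
    num_microbatches=1,
)

out = offloaded(torch.randn(8, 256, device="cuda"))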
Convert to nn.Sequential
class Transformer(torch.nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )
        # Build the transformer blocks as an nn.Sequential (instead of the original
        # nn.ModuleList) so that OffloadModel can split them into slices.
        self.layers = torch.nn.Sequential()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlockLowMemory(layer_id, params))
        self.layers = OffloadModel(
            model=self.layers,
            device=torch.device("cuda"),         # device that runs the active slice
            offload_device=torch.device("cpu"),  # device where idle slices are parked
            num_slices=params.n_layers,          # one slice per transformer block
            checkpoint_activation=False,
            num_microbatches=1,
        )
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads,
            self.params.max_seq_len * 2,
            params.rope_theta,
        )
Whole script with forward call
# This file is a modified version of the original model file from the LLaMA repository.
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
)
from fairscale.experimental.nn.offload import OffloadModel

from llama.model import ModelArgs, precompute_freqs_cis, RMSNorm, TransformerBlock


class TransformerBlockLowMemory(TransformerBlock):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)

    def forward(self, condese):
        # Unpack the flattened tensor: [seqlen, start_pos, hidden states, freqs_cis, mask].
        seqlen = int(condese[0])
        start_pos = int(condese[1])
        x = condese[2 : 2 + seqlen * self.dim].view(1, seqlen, self.dim).real.half()
        freqs_cis = condese[2 + seqlen * self.dim : 2 + seqlen * self.dim + seqlen * 64].view(seqlen, 64)
        mask = condese[2 + seqlen * self.dim + seqlen * 64 :]
        if mask.sum() == 0:
            mask = None
        else:
            mask = mask.view(1, 1, seqlen, seqlen).real.half()
        # The usual transformer block computation.
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        # Re-pack everything into a single flat tensor for the next slice.
        condese = torch.tensor([seqlen, start_pos], dtype=h.dtype).to(h.device)
        if mask is not None:
            condese_mask = mask.view(-1)
        else:
            condese_mask = torch.tensor([0]).to(h.device)
        condese = torch.cat((condese, out.view(-1), freqs_cis.view(-1), condese_mask))
        return condese
class Transformer(torch.nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )
        self.layers = torch.nn.Sequential()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlockLowMemory(layer_id, params))
        self.layers = OffloadModel(
            model=self.layers,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=params.n_layers,
            checkpoint_activation=False,
            num_microbatches=1,
        )
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads,
            self.params.max_seq_len * 2,
            params.rope_theta,
        )
    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = mask.to(torch.float32).triu(diagonal=start_pos + 1).type_as(h)

        # Pack seqlen, start_pos, the embeddings, freqs_cis, and the mask into a single
        # flat tensor, because OffloadModel only passes one tensor between slices.
        condese = torch.tensor([seqlen, start_pos], dtype=h.dtype).to(h.device)
        if mask is not None:
            condese_mask = mask.view(-1)
        else:
            condese_mask = torch.tensor([0]).to(h.device)
        condese = torch.cat((condese, h.view(-1), freqs_cis.view(-1), condese_mask))

        h = self.layers(condese)
        # Unpack the hidden states produced by the last block.
        h = h[2 : 2 + seqlen * self.params.dim].view(1, seqlen, self.params.dim).real.half()
        h = self.norm(h)
        output = self.output(h).float()
        return output
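The awkward part of the script is the condese tensor: because OffloadModel only forwards a single tensor from slice to slice, each block's inputs (seqlen, start_pos, the hidden states, freqs_cis, and the attention mask) are flattened into one 1-D tensor and reconstructed inside every block. Below is a minimal sketch of that packing scheme with hypothetical pack/unpack helpers (they are not part of the script above); the hard-coded 64 is presumably dim // n_heads // 2, the complex rotary dimension, which works out to 64 for the 7B configuration.

import torch

def pack(h, start_pos, freqs_cis, mask):
    # h: (1, seqlen, dim) half, freqs_cis: (seqlen, 64) complex, mask: (1, 1, seqlen, seqlen) or None.
    seqlen = h.shape[1]
    header = torch.tensor([seqlen, start_pos], dtype=h.dtype, device=h.device)
    flat_mask = mask.view(-1) if mask is not None else torch.tensor([0], device=h.device)
    # Concatenating with the complex freqs_cis promotes the whole packed tensor to a
    # complex dtype, which is why the unpacking code above calls .real.half().
    return torch.cat((header, h.view(-1), freqs_cis.view(-1), flat_mask))

def unpack(condese, dim):
    seqlen = int(condese[0].real)
    start_pos = int(condese[1].real)
    offset = 2
    h = condese[offset : offset + seqlen * dim].view(1, seqlen, dim).real.half()
    offset += seqlen * dim
    freqs_cis = condese[offset : offset + seqlen * 64].view(seqlen, 64)  # stays complex
    offset += seqlen * 64
    mask = condese[offset:]
    mask = None if mask.sum() == 0 else mask.view(1, 1, seqlen, seqlen).real.half()
    return h, start_pos, freqs_cis, mask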