Offload LLM using Facebook FairScale

What is FairScale

FairScale is an open-source library from Facebook that helps researchers and engineers optimize the training of large-scale neural networks.

I am mainly interested in its pipeline parallelism, model sharding, and checkpointing features.

The limits of using FairScale to offload

One limitation I ran into when using FairScale with the LLaMA language model is that the original pre-trained LLaMA weights are laid out differently from the ones distributed by Hugging Face, even though it is the same model.

The reason is that Hugging Face wraps the model behind its own Transformers API, so to use its pre-trained weights with the original model code you need to do some conversion.

This means models from the Hugging Face Transformers API are not directly compatible with the FairScale-based LLaMA code.
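
To see the mismatch concretely, the hypothetical sketch below simply prints the parameter names of both formats; the checkpoint path and model ID are placeholders, and the keys listed in the comments are illustrative examples. Note that the transformers repository also ships a conversion script (convert_llama_weights_to_hf.py) for mapping the original checkpoints into its own format.

import torch
from transformers import AutoModelForCausalLM

# Placeholder path / model ID: point these at wherever your weights actually live.
original = torch.load("llama-2-7b/consolidated.00.pth", map_location="cpu")
print(sorted(original.keys())[:5])
# e.g. 'layers.0.attention.wk.weight', 'layers.0.attention.wq.weight', ... plus
# 'tok_embeddings.weight', 'norm.weight', 'output.weight'

hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
print(sorted(hf_model.state_dict().keys())[:5])
# e.g. 'lm_head.weight', 'model.embed_tokens.weight',
# 'model.layers.0.self_attn.q_proj.weight', ...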

Also, the offloading feature from FairScale (OffloadModel) currently only supports models expressed as torch.nn.Sequential. So if your layers are not wrapped in nn.Sequential, you need to do some tinkering.
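
For reference, here is a minimal, self-contained sketch of what OffloadModel expects: a toy stack of blocks wrapped in nn.Sequential (the layer sizes and slice count are arbitrary, and a CUDA device is assumed). The constructor arguments are the same ones used for LLaMA in the next section.

import torch
from fairscale.experimental.nn.offload import OffloadModel

# A toy stack of blocks expressed as nn.Sequential, the shape OffloadModel expects.
layers = torch.nn.Sequential(
    *[torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU()) for _ in range(8)]
)

offloaded = OffloadModel(
    model=layers,                        # the nn.Sequential to shard
    device=torch.device("cuda"),         # device where each slice is executed
    offload_device=torch.device("cpu"),  # device where idle slices are parked
    num_slices=8,                        # one slice per block here
    checkpoint_activation=False,
    num_microbatches=1,
)

x = torch.randn(4, 256, device="cuda")
y = offloaded(x)                         # slices are copied to the GPU one at a time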

Convert to nn.Sequential


class Transformer(torch.nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        # Build the transformer blocks as an nn.Sequential so OffloadModel can slice them.
        self.layers = torch.nn.Sequential()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlockLowMemory(layer_id, params))

        # Wrap the blocks in OffloadModel: every block becomes its own slice that lives
        # on the CPU and is copied to the GPU only while it is being executed.
        self.layers = OffloadModel(
            model=self.layers,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=params.n_layers,
            checkpoint_activation=False,
            num_microbatches=1,
        )

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads,
            self.params.max_seq_len * 2,
            params.rope_theta,
        )

Whole script with forward call

# This file is a modified version of the original file from the LLaMA repository
from dataclasses import dataclass
import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
)
from fairscale.experimental.nn.offload import OffloadModel

from llama.model import ModelArgs, precompute_freqs_cis, RMSNorm, TransformerBlock

from typing import List, Optional



class TransformerBlockLowMemory(TransformerBlock):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)

    def forward(self, condese):
        # `condese` ("condensed") is a single 1-D tensor packing everything the block
        # needs: [seqlen, start_pos, hidden states, rotary frequencies, mask].
        # OffloadModel hands exactly one tensor to each slice, hence the packing.
        # Concatenating with the complex freqs_cis promotes the whole tensor to a
        # complex dtype, so the real-valued pieces are recovered with .real.
        seqlen = int(condese[0])
        start_pos = int(condese[1])
        x = condese[2:2 + seqlen * self.dim].view(1, seqlen, self.dim).real.half()
        # 64 corresponds to head_dim // 2, the last dimension of freqs_cis.
        freqs_cis = condese[2 + seqlen * self.dim:2 + seqlen * self.dim + seqlen * 64].view(seqlen, 64)
        mask = condese[2 + seqlen * self.dim + seqlen * 64:]
        if mask.sum() == 0:
            mask = None
        else:
            mask = mask.view(1, 1, seqlen, seqlen).real.half()

        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))

        # Re-pack the output in the same layout so the next slice can unpack it.
        condese = torch.tensor([seqlen, start_pos], dtype=h.dtype).to(h.device)
        if mask is not None:
            condese_mask = mask.view(-1)
        else:
            condese_mask = torch.tensor([0]).to(h.device)
        condese = torch.cat((condese, out.view(-1), freqs_cis.view(-1), condese_mask))
        return condese



class Transformer(torch.nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.Sequential()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlockLowMemory(layer_id, params))

        self.layers = OffloadModel(
            model=self.layers,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=params.n_layers,
            checkpoint_activation=False,
            num_microbatches=1,
        )

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads,
            self.params.max_seq_len * 2,
            params.rope_theta,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = mask.to(torch.float32).triu(diagonal=start_pos+1).type_as(h)

        # Pack [seqlen, start_pos, hidden states, rotary frequencies, mask] into a
        # single 1-D tensor so it can be passed through the OffloadModel slices.
        condese = torch.tensor([seqlen, start_pos], dtype=h.dtype).to(h.device)
        if mask is not None:
            condese_mask = mask.view(-1)
        else:
            condese_mask = torch.tensor([0]).to(h.device)
        condese = torch.cat((condese, h.view(-1), freqs_cis.view(-1), condese_mask))

        h = self.layers(condese)

        # Unpack the hidden states that the last block re-packed into the flat tensor.
        h = h[2:2 + seqlen * self.params.dim].view(1, seqlen, self.params.dim).real.half()

        h = self.norm(h)
        output = self.output(h).float()
        return output
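
The part that needs the most tinkering is that each OffloadModel slice receives and returns a single tensor, which is why the script packs seqlen, start_pos, the hidden states, the rotary frequencies, and the attention mask into one flat tensor (condese) and unpacks it again inside every block. Below is a stripped-down sketch of the same pack/unpack idea with made-up shapes and none of the LLaMA specifics.

import torch

def pack(x: torch.Tensor, seqlen: int, start_pos: int) -> torch.Tensor:
    # Prepend the scalar metadata, then flatten the activation into the same 1-D tensor.
    meta = torch.tensor([seqlen, start_pos], dtype=x.dtype, device=x.device)
    return torch.cat((meta, x.reshape(-1)))

def unpack(packed: torch.Tensor, dim: int):
    # Recover the metadata first, then reshape the remaining values back into (1, seqlen, dim).
    seqlen = int(packed[0])
    start_pos = int(packed[1])
    x = packed[2:2 + seqlen * dim].view(1, seqlen, dim)
    return x, seqlen, start_pos

x = torch.randn(1, 4, 8)
packed = pack(x, seqlen=4, start_pos=0)
y, seqlen, start_pos = unpack(packed, dim=8)
assert torch.equal(x, y)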