What are 32-bit full-precision training and 16-bit half-precision training?
In short, the difference is how many bits are used to store each floating-point number, which determines the range and precision of the values the model can represent.
Range
Now, imagine each node/neuron in a neural network: it is important for it to have as wide a range of floating-point representation as possible. The table below shows the smallest and largest positive values each type can represent.
| Type | Smallest positive | Largest positive |
| --- | --- | --- |
| bfloat16 | 1.18e-38 | 3.39e+38 |
| fp16 | 6.10e-05 | 6.55e+04 |
| fp32 | 1.18e-38 | 3.40e+38 |
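You can verify these limits directly with PyTorch's torch.finfo; a minimal sketch:

import torch

# Print the smallest positive normal value and the largest finite value
# for each floating-point dtype in the table above.
for dtype in (torch.bfloat16, torch.float16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} min={info.tiny:.2e} max={info.max:.2e}")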
What is 8-bit quantization?
It is used to reduce the precision of the weights and activations of a neural network from their original floating-point representation to 8-bit integers. Compared to 32-bit, it uses up to 4x less memory.
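As a back-of-the-envelope illustration (the 7B parameter count below is an assumed example, not from the original post), the weights alone shrink from about 28 GB in fp32 to about 7 GB in int8:

# Rough weight-memory estimate for a hypothetical 7B-parameter model.
params = 7_000_000_000           # assumed parameter count, for illustration only
gb_fp32 = params * 4 / 1e9       # 4 bytes per fp32 weight
gb_int8 = params * 1 / 1e9       # 1 byte per int8 weight
print(f"fp32: {gb_fp32:.0f} GB, int8: {gb_int8:.0f} GB")  # fp32: 28 GB, int8: 7 GB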
What is the performance loss of 8-bit quantization?
An 8-bit integer can only represent 256 distinct values, so a wide range of floating-point values has to be squeezed onto a very coarse grid.
| Type | Min | Max |
| --- | --- | --- |
| bfloat16 | 1.18e-38 | 3.39e+38 |
| fp16 | 6.10e-05 | 6.55e+04 |
| fp32 | 1.18e-38 | 3.40e+38 |
| int8 | -128 | 127 |
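To see where the loss comes from, here is a minimal sketch of symmetric per-tensor quantization (a standard textbook scheme, not necessarily the exact method used in the commit linked below): the whole float range of a tensor is squeezed onto 256 integer levels, and dequantizing cannot recover the original values exactly.

import torch

x = torch.randn(5) * 3.0                       # some fp32 weights/activations
scale = x.abs().max() / 127                    # map the largest magnitude to 127
q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
x_hat = q.float() * scale                      # dequantize back to fp32
print(x)                                       # original values
print(x_hat)                                   # values after the int8 round trip
print((x - x_hat).abs().max())                 # worst-case quantization error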
Going from 32-bit to bfloat16 training, results degrade only a little. With 8-bit quantization, however, results can degrade dramatically, so we don't usually use 8-bit to train or fine-tune a model; we only use it for inference, i.e. generating text from prompts.
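For inference, a common way to do this with Hugging Face transformers is to load the weights in 8-bit through bitsandbytes. This is only a sketch: it assumes bitsandbytes is installed, a CUDA GPU is available, and the model id is just a placeholder.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load the weights in 8-bit for inference only (assumes bitsandbytes + a CUDA GPU).
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # placeholder model id, substitute your own
    quantization_config=bnb_config,
    device_map="auto",
)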
Code
from transformers import AutoModelForCausalLM
import torch
# Load pretrained LLaMA model
model = AutoModelForCausalLM.from_pretrained('allenai/llama')
# Prepare model for dynamic quantization (PyTorch's quantization support)
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
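A minimal usage sketch for the quantized model (assuming the same checkpoint also ships a tokenizer; PyTorch dynamic quantization runs on CPU):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('allenai/llama')   # assumes a matching tokenizer exists
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = quantized_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))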
https://github.com/jljacoblo/jacAI/commit/8aa6075e06f8b7478911c05de7d56b3b67128248