diff --git a/hf-bitsandbytes-integration.md b/hf-bitsandbytes-integration.md index 512761dc63..31f6f29662 100644 --- a/hf-bitsandbytes-integration.md +++ b/hf-bitsandbytes-integration.md @@ -28,27 +28,36 @@ thumbnail: /blog/assets/96_hf_bitsandbytes_integration/Thumbnail_blue.png guest + + +
+ stas + Stas Bekman +
+
![thumbnail](assets/96_hf_bitsandbytes_integration/Thumbnail_blue.png)

-## Introduction
+# Introduction

Language models are becoming larger all the time. At the time of this writing, PaLM has 540B parameters, OPT, GPT-3, and BLOOM have around 176B parameters, and we are trending towards even larger models. Below is a qualitative diagram showing the size of some recent language models.

![LLM](assets/96_hf_bitsandbytes_integration/LLM.png)

-Therefore, these models are hard to run on easily accessible devices. For example, just to do inference on BLOOM-175B, you would need to have 8x 80GB A100 GPUs (~$15k each). To fine-tune BLOOM-175B, you'd need 72 of these GPUs! Much larger models, like PaLM would require even more resources.
-Because these huge models require so many GPUs to run, we need to find ways to reduce these requirements while preserving the model's performance. Various technologies have been developed that try to shrink the model size, you may have heard of quantization and distillation, and there are many others.
-At Hugging Face and BigScience, after completing the training of BLOOM-176B, one of the approaches we took was to collaborate with `bitsandbytes` and integrate the technology described in the recent "GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale" paper into Hugging Face Transformers. We chose to integrate it since no post-training quantization is required to run this method, and you can reduce the memory footprint of any model by 2x by adding just a few lines of code.
+Therefore, these models are hard to run on easily accessible devices. For example, just to do inference on BLOOM-176B, you would need to have 8x 80GB A100 GPUs (~$15k each). To fine-tune BLOOM-176B, you'd need 72 of these GPUs! Much larger models, like PaLM, would require even more resources.
+
+Because these huge models require so many GPUs to run, we need to find ways to reduce these requirements while preserving the model's predictive performance. Various technologies have been developed that try to shrink the model size; you may have heard of quantization and distillation, and there are many others.
+
+At Hugging Face and BigScience, after completing the training of BLOOM-176B, one of the approaches we took was to collaborate with `bitsandbytes` and integrate the technology described in the recent "LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale" paper into Hugging Face Transformers. We chose to integrate it since no further quantization is required after loading a model checkpoint, and you can reduce the memory footprint of any model by 2x by adding just a few lines of code.

This article focuses on giving a high-level overview of this quantization technology, outlining the difficulties in incorporating it into the `transformers` library, and drawing up the long-term goals of this partnership.

What elements affect a model's size? What makes BLOOM 350GB? Let's begin by gradually going over a few basic premises.

-## Common data types used in Machine Learning
+# Common data types used in Machine Learning

We start with the basic understanding of different floating point data types, which are also referred to as "precision" in the context of Machine Learning.

@@ -56,14 +65,14 @@ The size of a model is determined by the number of its parameters, and their pre

![Summary](assets/96_hf_bitsandbytes_integration/tf32-Mantissa-chart-hi-res-FINAL.png)

-Float32 (FP32) stands for the standard 32-bit floating point representation.
With this data type it is possible to represent a wide range of floating numbers. In FP32, 8 bits are reserved for the "exponent", 23 bits for the "mantissa" and 1 bit for the sign of the number. In addition to that, most of the hardware supports FP32 operations and instructions.
+Float32 (FP32) stands for the standardized IEEE 32-bit floating point representation. With this data type it is possible to represent a wide range of floating numbers. In FP32, 8 bits are reserved for the "exponent", 23 bits for the "mantissa" and 1 bit for the sign of the number. In addition to that, most of the hardware supports FP32 operations and instructions.

In the float16 (FP16) data type, 5 bits are reserved for the exponent and 10 bits are reserved for the mantissa. This makes the representable range of FP16 numbers much lower than FP32. This exposes FP16 numbers to the risk of overflowing (trying to represent a number that is very large) and underflowing (representing a number that is very small).

-For example, if you do `10k * 10k` you end up with `100k` which is not possible to represent in FP16, as the largest number possible is `64k`. And thus you'd end up with `NaN` (Not a Number) result and all the prior work is destroyed.
-Usually, scaling is used to overcome this issue, but it doesn't always work well.
+For example, if you do `10k * 10k` you end up with `100M` which is not possible to represent in FP16, as the largest number possible is `64k`. And thus you'd end up with a `NaN` (Not a Number) result and if you have sequential computation like in neural networks, all the prior work is destroyed.
+Usually, loss scaling is used to overcome this issue, but it doesn't always work well.

A new format, bfloat16 (BF16), was created to avoid these constraints. In BF16, 8 bits are reserved for the exponent (which is the same as in FP32) and 7 bits are reserved for the fraction.

@@ -72,12 +81,16 @@ This means that in BF16 we can retain the same dynamic range as FP32. But we los

In the Ampere architecture, NVIDIA also introduced [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) (TF32) precision format, combining the dynamic range of BF16 and precision of FP16 to only use 19 bits. It's currently only used internally during certain operations.

+XXX: for all images that aren't original should add the link to the source page. original don't need to say original ;)
+XXX: actually, perhaps we should use just one image https://blogs.nvidia.com/wp-content/uploads/2020/05/tf32-Mantissa-chart-hi-res-FINAL.png from https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ as it nicely aligns all the 4 formats and it's easier for the user to see the important difference at once.
-> proposed a small refactoring above!
+
+
In the machine learning jargon FP32 is called full precision (4 bytes), while BF16 and FP16 are referred to as half-precision (2 bytes). On top of that, the int8 (INT8) data type consists of an 8-bit representation that can store 2^8 different values (between [0, 255] or [-128, 127] for signed integers).

-While, ideally the training and inference should be done in FP32, it is usually quite expensive and therefore a mixed precision approach is used where some parts of the training loop is done in FP32 while the other in either BF16 or FP16. The lower precision also runs faster.
+While ideally the training and inference should be done in FP32, it is two times slower than FP16/BF16, and therefore a mixed precision approach is used where the weights are held in FP32 as a precise "main weights" reference, while computations in the forward and backward passes are done in FP16/BF16 to enhance training speed. The FP16/BF16 gradients are then used to update the FP32 main weights.

-During training, the master weights are always stored in FP32, but in practice, the half-precision weights often provide similar quality during inference as their FP32 counterpart. This means we can use the half-precision weights and use half the GPUs to accomplish the same outcome.
+During training, the main weights are always stored in FP32, but in practice, the half-precision weights often provide similar quality during inference as their FP32 counterpart -- a precise reference of the model is only needed when it receives multiple gradient updates. This means we can use the half-precision weights and use half the GPUs to accomplish the same outcome.

![Model-storage](assets/96_hf_bitsandbytes_integration/Model-storage.png)

@@ -86,9 +99,9 @@ To calculate the model size in bytes, one multiplies the number of parameters by

But what if we can store those weights with less memory using a different data type? A methodology called quantization has been used widely in Deep Learning.

-## Introduction to model quantization
+# Introduction to model quantization

-Experientially we have discovered that instead of using the 4-byte FP32 precision, we can get an almost identical inference outcome with 2-byte BF16/FP16 half-precision, which halves the model size. It'd be amazing to cut it further, but the inference quality outcome starts to drop dramatically at lower precision.
+Experimentally, we have discovered that instead of using the 4-byte FP32 precision, we can get an almost identical inference outcome with 2-byte BF16/FP16 half-precision, which halves the model size. It'd be amazing to cut it further, but the inference quality outcome starts to drop dramatically at lower precision.

To remediate that, we introduce 8-bit quantization. This method uses a quarter precision, thus needing only 1/4th of the model size! But it's not done by just dropping another half of the bits.

@@ -110,43 +123,41 @@ To retrieve the latest, one can just divide in full precision the int8 number wi

![out-quant.gif](assets/96_hf_bitsandbytes_integration/out-quant.gif)

-For an unsigned int8, we would subtract the minimum and scale by the absolute maximum. This is close to what zero-point quantization does. It's is similar to a min-max scaling but the latter maintains the value scales in such a way that the value “0” is always represented by an integer without any quantization error.
-
-These tricks can be combined in several ways, for example, row-wise or vector-wise quantization, when it comes to matrix multiplication for more accurate results.
+For an unsigned int8, we would subtract the minimum and scale by the absolute maximum. This is close to what zero-point quantization does. This is similar to a min-max scaling but the latter maintains the value scales in such a way that the value “0” is always represented by an integer without any quantization error.
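To make the absmax idea above concrete, here is a minimal, illustrative sketch in plain PyTorch (a toy example written for this post, not the `bitsandbytes` implementation; the helper names and the sample vector are made up):

```py
import torch

def absmax_quantize(x: torch.Tensor):
    # Scale so that the largest absolute value maps to 127.
    scale = 127 / x.abs().max()
    x_int8 = (x * scale).round().clamp(-127, 127).to(torch.int8)
    return x_int8, scale

def absmax_dequantize(x_int8: torch.Tensor, scale: torch.Tensor):
    # Divide in full precision by the same factor to recover approximate values.
    return x_int8.to(torch.float32) / scale

x = torch.tensor([1.2, -0.5, -4.3, 1.2, -3.1, 0.8, 2.4, 5.4])
x_int8, scale = absmax_quantize(x)
print(x_int8)                            # roughly tensor([28, -12, -101, 28, -73, 19, 56, 127], dtype=torch.int8)
print(absmax_dequantize(x_int8, scale))  # close to x, up to a small quantization error
```

Zero-point quantization additionally shifts the values so that zero is mapped to an integer exactly; the vector-wise variants described next apply the same recipe per row or per column of a matrix.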
If you want to read more details about how classic quantization techniques work, we recommend reading this [blog post](https://intellabs.github.io/distiller/algo_quantization.html).

-If you want to read more details about how classic quantization techniques work, we recommend reading this [blog post](https://intellabs.github.io/distiller/algo_quantization.html) or the GPT3.int8() paper (XXX: link).
+These tricks can be combined in several ways, for example, row-wise or vector-wise quantization, when it comes to matrix multiplication for more accurate results. Looking at the matrix multiplication, A\*B=C, instead of regular quantization that normalizes by an absolute maximum value per tensor, vector-wise quantization finds the absolute maximum of each row of A and each column of B. Then we normalize A and B by dividing by these vectors. We then multiply A\*B to get C. Finally, to get back the FP16 values, we denormalize by computing the outer product of the absolute maximum vectors of A and B. More details on this technique can be found in the LLM.int8() paper (XXX: link here).

-While these basic techniques enable us to quanitize Deep Learning models, they usually lead to a drop in accuracy for larger models. The bnb-int8 implementation that we integrated into Hugging Face Transformers and Accelerate libraries is the first technique that does not degrade performance even for large models with 176B parameters, such as BLOOM.
+While these basic techniques enable us to quantize Deep Learning models, they usually lead to a drop in accuracy for larger models. The LLM.int8() implementation that we integrated into Hugging Face Transformers and Accelerate libraries is the first technique that does not degrade performance even for large models with 176B parameters, such as BLOOM.

-## A gentle summary of mixed int8 matrix multiplication for Large Language Models
+# A gentle summary of LLM.int8(): zero degradation matrix multiplication for Large Language Models

-Authors have demonstrated that it is crucial to comprehend the scale-dependent emergent properties of transformers in order to understand why traditional quantization fails for large models. They demonstrate that performance deterioration is caused by outlier features, which will be explained next.
+Authors of LLM.int8() have demonstrated that it is crucial to comprehend the scale-dependent emergent properties of transformers in order to understand why traditional quantization fails for large models. They demonstrate that performance deterioration is caused by outlier features, which we explain in the next section.

The LLM.int8() algorithm itself can be explained as follows.

-In essence, 8-bit Matrix Multiplication at Scale for Transformers seeks to complete the matrix multiplication computation in three steps:
+In essence, LLM.int8() seeks to complete the matrix multiplication computation in three steps (sketched in code below):

1. From the input hidden states, extract the outliers (i.e. values that are larger than a certain threshold) by column.
-2. Perform the matrix multiplication of the outliers in fp16 and the non-outliers in int8.
-3. Dequantize the non-outlier results and retrieve the full result in fp16.
+2. Perform the matrix multiplication of the outliers in FP16 and the non-outliers in int8.
+3. Dequantize the non-outlier results and add both outlier and non-outlier results together to receive the full result in FP16.
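As a rough illustration of these three steps, here is a small PyTorch sketch (a toy re-implementation written for this post, not the actual `bitsandbytes` CUDA kernels; the function name and the example threshold default are only for illustration):

```py
import torch

def mixed_precision_matmul(X: torch.Tensor, W: torch.Tensor, threshold: float = 6.0):
    # 1. Columns of the hidden states X that contain at least one outlier value.
    outlier_cols = (X.abs() >= threshold).any(dim=0)

    # 2a. Outlier features: a regular matmul in the original (16-bit) precision.
    Y_outliers = X[:, outlier_cols] @ W[outlier_cols, :]

    # 2b. Everything else: vector-wise absmax quantization to int8.
    X_sub, W_sub = X[:, ~outlier_cols], W[~outlier_cols, :]
    cx = X_sub.abs().amax(dim=1, keepdim=True)            # one scale per row of X
    cw = W_sub.abs().amax(dim=0, keepdim=True)            # one scale per column of W
    X_i8 = (X_sub * 127 / cx).round().to(torch.int8)
    W_i8 = (W_sub * 127 / cw).round().to(torch.int8)
    Y_i32 = X_i8.to(torch.int32) @ W_i8.to(torch.int32)   # int8 values, int32 accumulation

    # 3. Dequantize with the outer product of the scales and add both parts.
    Y_rest = Y_i32.to(X.dtype) * (cx @ cw) / (127 * 127)
    return Y_outliers + Y_rest

# Toy check with FP32 tensors standing in for FP16 activations and weights.
X = torch.randn(4, 16)
X[:, 3] += 8.0                                             # make one hidden dimension an outlier
W = torch.randn(16, 8)
print((mixed_precision_matmul(X, W) - X @ W).abs().max())  # small quantization error
```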
These steps can be summarized in the following animation:

![Mixed-int8.gif](assets/96_hf_bitsandbytes_integration/Mixed-int8.gif)

-### The importance of outlier features
+## The importance of outlier features

-A value that is outside the range of some numbers' global distribution is generally referred to as an outlier. Outlier detection has been widely used and covered in the current literature, and having prior knowledge of the distribution of your features helps with the task of outlier detection. More specifically, authors have observed that classic quantization at scale fails for transformer-based models >6B parameters.
+A value that is outside the range of some numbers' global distribution is generally referred to as an outlier. Outlier detection has been widely used and covered in the current literature, and having prior knowledge of the distribution of your features helps with the task of outlier detection. More specifically, authors have observed that classic quantization at scale fails for transformer-based models >6B parameters. While large outlier features are also present in smaller models, the authors observe that beyond a certain threshold these outliers form highly systematic patterns across transformers which are present in every layer of the transformer. For more details on these phenomena see the LLM.int8() paper.

-For the majority of models, hidden state features in transformers increase in magnitude with model size. As mentioned earlier, 8-bit precision is extremely constrained, therefore quantizing a vector with several big values can produce wildly erroneous results. Additionally, because of a built-in characteristic of the transformer-based architecture that links all the elements together, these errors tend to compound as they get propagated across multiple layers. Therefore, mixed-precision decomposition has been developed to facilitate efficient quantization with such extreme outliers. It is discussed next.
+As mentioned earlier, 8-bit precision is extremely constrained, therefore quantizing a vector with several big values can produce wildly erroneous results. Additionally, because of a built-in characteristic of the transformer-based architecture that links all the elements together, these errors tend to compound as they get propagated across multiple layers. Therefore, mixed-precision decomposition has been developed to facilitate efficient quantization with such extreme outliers. It is discussed next.

-### Inside the MatMul
+## Inside the MatMul

-Once the hidden states are computed we extract the outliers using a custom threshold (6.0 in our example) and we decompose the matrix into two parts as explained above.
+Once the hidden states are computed we extract the outliers using a custom threshold and we decompose the matrix into two parts as explained above. The authors find that extracting all outliers with a magnitude of 6 or greater in this way recovers full inference performance.
The outlier part is done in fp16 so it is a classic matrix multiplication, whereas the 8-bit matrix multiplication is done by quantizing the weights and hidden states into 8-bit precision using vector-wise quantization -- that is, row-wise quantization for the hidden state and column-wise quantization for the weight matrix. After this step, the results are dequantized and returned in half-precision in order to add them to the first matrix multiplication.

![Matmul.png](assets/96_hf_bitsandbytes_integration/Matmul.png)

-### What does 0 degradation mean?
+## What does 0 degradation mean?

How can we properly evaluate the performance degradation of this method? How much quality do we lose in terms of generation when using 8-bit models?

@@ -181,9 +192,9 @@ For BLOOM-176:

We indeed observe 0 performance degradation for those models since the absolute difference of the metrics are all below the standard error (except for BLOOM-int8 which is slightly better than the native model on lambada). For a more detailed performance evaluation against state-of-the-art approaches, take a look at the paper!

-### Is it faster than native models?
+## Is it faster than native models?

-We also benchmarked the int8 inference speed of several models. Although we are close to having the same speed as the native model for large models (tested on BLOOM-176), the inference speed seems to be much slower than the native model on smaller models.
+While the authors state that the main purpose of the LLM.int8() method is to make large models more accessible without performance degradation, we also benchmarked the inference speed of int8 models of different sizes. Although we are close to having the same speed as the native model for large models (tested on BLOOM-176B), the inference speed seems to be much slower than the native model on smaller models. These issues are currently as expected and might be improved with [updates](https://github.com/TimDettmers/bitsandbytes/issues/6#issuecomment-1211345635) to the bitsandbytes software.

| Model | Number of parameters | Hardware | Time per token in milliseconds for Batch Size 1 | Time per token in milliseconds for Batch Size 8 | Time per token in milliseconds for Batch Size 32 |
| -------------- | -------------------- | ------------ | ----------------------------------------------- | ----------------------------------------------- | ------------------------------------------------ |
@@ -198,17 +209,17 @@ We also benchmarked the int8 inference speed of several models. Although we are
| T5-3b-int8 | 3B | 1xT4 15GB | 312 | 39.1 | 10.2 |

-For a more technical deep dive into the method, we highly suggest checking out Tim Dettmers' blog : (link)
+For a more technical deep dive into the method, we highly suggest checking out Tim Dettmers' blog post about [LLM.int8()](https://timdettmers.com/2022/10/16/llm-int8/).

-### Hugging Face `transformers` integration nuances
+# Which technology to use for `transformers` integration?

-Next let's discuss the specifics of the Hugging Face `transformers` integration. Let's look at the usage and the common culprit you may encounter while trying to set things up.
+How were these technologies incorporated into the `transformers` library? What were the difficulties we faced and the main techniques we employed in the integration project? Let's examine everything in the next sections!

-### Usage
+## How to use it in the `bitsandbytes` library?
-The module responsible for the whole magic described in this blog post is called `Linear8bitLt` and you can easily import it from the `bitsandbytes` library. It is derived from a classic `torch.nn` Module and can be easily used and deployed in your architecture with the code described below.
+The module responsible for the whole magic described in this blog post is called `Linear8bitLt` and you can easily import it from the `bitsandbytes` library. It is derived from a classic `torch.nn` Module and can be easily used and deployed in your architecture with the commands described below.

-Here is a step-by-step example of the following use case: let's say you want to convert a shallow model in int8 using `bitsandbytes`.
+Let's walk step by step through a specific use case: let's say you want to convert a shallow model in int8 using `bitsandbytes`.

1. First we need the correct imports below!

```py
import torch
import torch.nn as nn

@@ -220,7 +231,7 @@ import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
```

-2. Then you can define your own FP16 model. This detail is very important as you absolutely need a FP16 model to make it work. You might be able to use FP32 or BF16 model weights and cast those to FP16 (but at your own risk, since after conversion a model may fail to work - e.g. fp16 easily overflows with large numbers).
+2. Then you can define your own FP16 model. This detail is very important as you absolutely need an FP16 model to make it work. You can also train or load an FP32 or BF16 model and cast it directly to FP16.

```py
fp16_model = nn.Sequential(
@@ -229,14 +240,13 @@ fp16_model = nn.Sequential(
).to(torch.float16)
```

-3. Next, you train the model and save the result:
+3. Let's say you have trained your model on your favorite dataset and task! Now it's time to save the model:

```py
-[... train the model ...]
torch.save(fp16_model.state_dict(), "model.pt")
```

-4. Next we define an int8 model:
+4. Now that your `state_dict` is saved, let us define an int8 model:

```py
int8_model = nn.Sequential(
@@ -245,13 +255,13 @@
)
```

-Here it is very important to add the flag `has_fp16_weights`. By default, this is set to `True` because loading a model with `has_fp16_weights=True` is not very well supported yet.
+Here it is very important to add the flag `has_fp16_weights`. By default, this is set to `True`, which is used to train in mixed Int8/FP16 precision. However, we are interested in memory-efficient inference, for which we need to use `has_fp16_weights=False`.

-5. Finally the fp16 weights are loaded into the 8-bit model!
+5. Now it's time to load your model in 8-bit!

```py
int8_model.load_state_dict(torch.load("model.pt"))
-int8_model = int8_model.to(0) # Quantization happens here
+int8_model = int8_model.to(torch.device('cuda', 0))  # Quantization happens here
```

Note that the quantization step is done in the second line once the model is set on the GPU.
If you print `int8_model[0].weight` before calling the `.to` function you get: @@ -269,7 +279,7 @@ tensor([[ 0.0031, -0.0438, 0.0494, ..., -0.0046, -0.0410, 0.0436], dtype=torch.float16) ``` -Whereas, if you print it after the `.to` call, you get: +Whereas if you print it after the second line's call you get: ``` int8_model[0].weight @@ -280,16 +290,15 @@ tensor([[ 3, -47, 54, ..., -5, -44, 47], ..., [ 82, 74, 65, ..., -49, -53, -109], [ -21, -42, 68, ..., 13, 57, -12], - [ -4, 88, -1, ..., -43, -78, 121]], - device='cuda:0', dtype=torch.int8, requires_grad=True) + [ -4, 88, -1, ..., -43, -78, 121]], device='cuda:0', + dtype=torch.int8, requires_grad=True) ``` -The weights values are "truncated" as we have seen when explaining quantization in the previous sections. Also, the values seem to be distributed between `[-128, 127]`. - +The weights values are "truncated" as we have seen when explaining quantization in the previous sections. Also, the values seem to be distributed between [-127, 127]. You might also wonder how to retrieve the FP16 weights in order to perform the outlier MatMul in fp16? You can simply do: ```py -(int8_model[0].weight.CB * int8_model[0].weight.SCB) / 127 +(int8_model[0].weight.CB * int8_model[0].weight.SCB)/127 ``` And you will get: @@ -305,32 +314,29 @@ tensor([[ 0.0028, -0.0459, 0.0522, ..., -0.0049, -0.0428, 0.0462], device='cuda:0') ``` -Which is quite close to the original FP16 values (2 print outs up)! +Which is close enough to the original FP16 values! 6. Now you can safely infer using your model by making sure your model is on the correct GPU: ```py input_ = torch.randn(8, 64, dtype=torch.float16) -hidden_states = int8_model(input_.to(0)) +hidden_states = int8_model(input_.to(torch.device('cuda', 0))) ``` -Check out [this gist](https://gist.github.com/younesbelkada/9035e247b066d1cf18682e9e4c21032d) for the full minimal code! +Check out [this gist](https://gist.github.com/younesbelkada/9035e247b066d1cf18682e9e4c21032d) for the full minimal code! Now the time has come to understand how to integrate that into the `transformers` library! As a side note, you should be aware that these modules differ slightly from the `nn.Linear` modules in that their parameters come from the `bnb.nn.Int8Params` class rather than the `nn.Parameter` class. You'll see later that this presented an additional obstacle on our journey! -Now the time has come to understand how to integrate that into the `transformers` library! - - -### `accelerate` is all you need +## `accelerate` is all you need -When working with huge models, the `accelerate` library includes a number of helpful utilities. The `init_empty_weights` method is especially helpful because any model, regardless of size, may be initialized with this method as a context manager without allocating any memory for the model weights. +When working with huge models, the `accelerate` library includes a number of helpful utilities. The `init_empty_weights` method is especially helpful because any model, regardless of size, may be initialized with this method as a context manager with 0 cost, aka **no memory**. ```py import torch.nn as nn from accelerate import init_empty_weights with init_empty_weights(): - model = nn.Sequential([nn.Linear(100000, 100000) for _ in range(1000)]) # This will consume ~0 RAM! + model = nn.Sequential([nn.Linear(100000, 100000) for _ in range(1000)]) # This will take 0 RAM! 
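    # Illustrative check (added for this post, not part of the original snippet): parameters
    # created under init_empty_weights live on PyTorch's `meta` device, so no storage is
    # allocated for them yet.
    # Note that nn.Sequential expects modules as positional arguments, so in practice the
    # list comprehension above would be unpacked with a leading `*`.
    print(next(model.parameters()).device)  # -> device(type='meta')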
```

The initialized model will be put on PyTorch's `meta` device, an underlying mechanism to represent shape and dtype without allocating memory for storage. How cool is that?

@@ -349,7 +355,7 @@ kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), **kwargs)
```

-Now that this is fixed, we can easily leverage this context manager and play with it to replace all `nn.Linear` modules to `bnb.nn.Linear8bitLt` at no memory cost using a custom function!
+Now that this is fixed, we can easily leverage this context manager and play with it to replace all `nn.Linear` modules to `bnb.nn.Linear8bitLt` with no cost using a custom function!

```py
def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
@@ -375,11 +381,11 @@ We also discard the replacement for some modules (here the `lm_head`) since we w

But it isn't over yet! The function above is executed under the `init_empty_weights` context manager which means that the new model will be still in the `meta` device. For models that are initialized under this context manager, `accelerate` later manually loads the parameters of each module and sets it on the correct device.

-In `bitsandbytes`, setting a `Linear8bitLt` module's device is a crucial step (the code snippet below is from [here](https://github.com/TimDettmers/bitsandbytes/blob/bd515328d70f344f935075f359c5aefc616878d5/bitsandbytes/nn/modules.py#L94)) as we have seen in our toy script. If you look more closely, this happens when `.to` or `.cuda` is called:
+In `bitsandbytes`, setting a `Linear8bitLt` module's device is a crucial step (line below from [here](https://github.com/TimDettmers/bitsandbytes/blob/bd515328d70f344f935075f359c5aefc616878d5/bitsandbytes/nn/modules.py#L94)) as we have seen in our toy script. If you look more closely, this happens when `.to` or `.cuda` is called:

```py
-## we store the 8-bit rows-major weight
-## we convert this weight to the turning/ampere weight during the first inference pass
+# we store the 8-bit rows-major weight
+# we convert this weight to the turing/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
setattr(self, 'CB', CB)
setattr(self, 'SCB', SCB)
```

+These lines of code convert the regular Int8 tensor into specialized Int8 tensor formats for Turing or Ampere GPUs.

Here, setting a parameter's device step is extremely crucial since the quantization statistics fails when calling it twice. We had to come up with an implementation of `accelerate`'s `set_module_tensor_to_device` function (termed as `set_module_8bit_tensor_to_device`) to make sure we don't call it twice. Let's discuss this in detail in the section below!

-### Be very careful on how to set devices with `accelerate`
+## Be very careful on how to set devices with `accelerate`

Here we played a very delicate balancing act with the `accelerate` library! Once you load your model and set it on the correct devices, sometimes you still need to call `set_module_tensor_to_device` to dispatch the model with hooks on all devices. This is done inside the `dispatch_model` function from `accelerate`, which involves potentially calling `.to` several times and is something we want to avoid. 2 Pull Requests were needed to achieve what we wanted!
The initial PR proposed [here](https://github.com/huggingface/accelerate/pull/539/) broke some tests but [this PR](https://github.com/huggingface/accelerate/pull/576/) successfully fixed everything! -### Wrapping it all up +## Wrapping it all up Therefore the ultimate recipe is: 1. Initialize a model in the `meta` device with the correct modules @@ -410,13 +417,13 @@ All said and done, this integration adventure was very fun; from deep diving and Now time to see how to benefit from this integration and how to successfully use it in `transformers`! -## How to use it in `transformers` +# How to use it in `transformers` -### Hardware requirements +## Hardware requirements 8-bit tensor cores are not supported on the CPU. bitsandbytes can be run on 8-bit tensor core-supported hardware, which are Turing and Ampere GPUs (RTX 20s, RTX 30s, A40-A100, T4+). For example, Google Colab GPUs are usually NVIDIA T4 GPUs, and their latest generation of GPUs does support 8-bit cores. Our demos are based on Google Colab so check them out below! -### Installation +## Installation Just install the latest version of the libraries using the commands below (make sure that you are using python>=3.8) and run the commands below to try out @@ -426,7 +433,7 @@ pip install bitsandbytes pip install git+https://github.com/huggingface/transformers.git ``` -### Example demos - running T5 11b on a Google Colab +## Example demos - running T5 11b on a Google Colab Check out the Google Colab demos for running 8bit models on a BLOOM-3B model! @@ -439,32 +446,32 @@ Or this demo for BLOOM-3B: [![Open In Colab: BLOOM-3b demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing) -## Scope of improvements +# Scope of improvements This approach, in our opinion, greatly improves access to very large models. With no performance degradation, it enables users with less compute to access models that were previously inaccessible. We've found several areas for improvement that can be worked on in the future to make this method even better for large models! -### Inference speed and slowing down on smaller models +## Inference speed and slowing down on smaller models For very large language models, we have observed that we nearly maintain the same inference speed using the native model as opposed to the mixed-8bit model (see attached experiments on BLOOM-176B). -However, due to the numerous internal casting steps, together with the outlier detection procedure that take place inside each 8bit-Linear layer, this method can significantly slow down inference speed on small models (models with less than 6b parameters). +However, due to the overhead of quantization, together with the outlier detection procedure that take place inside each 8bit-Linear layer, this method can significantly slow down inference speed on small models (models with less than 6b parameters). -One could attempt to improve that in the future and see how the inference speed can be decreased, probably by avoiding the casting operations or writing more efficient CUDA kernels. +One could attempt to improve that in the future and see how the inference speed can be decreased, probably by making the outlier extraction more efficient or parallelizing the outlier and non-outlier matrix multiplication which are currently done sequentially. 
-### Saving 8-bit state dicts on the Hub +## Saving 8-bit state dicts on the Hub 8-bit state dicts cannot currently be loaded directly into the 8-bit model after being pushed on the Hub. This is due to the fact that the statistics (remember `weight.CB` and `weight.SCB`) computed by the model are not currently stored or taken into account inside the state dict, and the `Linear8bitLt` module does not support this feature yet. We think that having the ability to save that and push it to the Hub might contribute to greater accessibility. -### CPU support +## CPU support CPU devices do not support 8-bit cores, as was stated at the beginning of this blogpost. Can we, however, get past that? Running this module on CPUs would also significantly improve usability and accessibility. -### Scaling up on other modalities +## Scaling up on other modalities Currently, language models dominate very large models. Leveraging this method on very large vision, audio, and multi-modal models might be an interesting thing to do for better accessibility in the coming years as these models become more accessible. -## Credits +# Credits Huge thanks to the following who contributed to improve the readability of the article as well as contributed in the integration procedure in `transformers` (listed in alphabetic order): JustHeuristic (Yozh),