Understand load_in_8bit in AutoModelForCausalLM.from_pretrained() – LLM Tutorial

January 10, 2024

We often load a pretrained LLM as follows:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    base_model,              # path or Hugging Face Hub ID of the pretrained LLM
    load_in_8bit=load_8bit,  # True to load the weights in 8-bit precision
    torch_dtype=torch.float16,
    device_map=device_map,   # often "auto"
    trust_remote_code=True,
)

Here base_model is the path (or Hugging Face Hub ID) of an LLM, and device_map can usually be set to "auto". But what does load_in_8bit do?

Setting load_in_8bit=True loads the model weights in 8-bit precision instead of 16-bit or 32-bit floats. This substantially reduces GPU memory usage: the weights take roughly half the memory of a float16 load.
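As a quick sketch of the savings, you can compare footprints with model.get_memory_footprint(). This assumes a CUDA GPU with bitsandbytes installed; the model ID below is just an example:

import torch
from transformers import AutoModelForCausalLM

base_model = "facebook/opt-1.3b"  # example model ID; any causal LM works

# Baseline: load the weights in float16.
model_fp16 = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.float16, device_map="auto"
)
print(f"fp16: {model_fp16.get_memory_footprint() / 1e9:.2f} GB")

del model_fp16
torch.cuda.empty_cache()

# Same model loaded with 8-bit weights (requires bitsandbytes).
model_int8 = AutoModelForCausalLM.from_pretrained(
    base_model, load_in_8bit=True, device_map="auto"
)
print(f"int8: {model_int8.get_memory_footprint() / 1e9:.2f} GB")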

There is also a load_in_4bit parameter. The Transformers documentation describes the two flags as follows (a 4-bit example follows the list):

load_in_8bit (bool, optional, defaults to False) — If True, will convert the loaded model into mixed-8bit quantized model. To use this feature please install bitsandbytes (pip install -U bitsandbytes).
load_in_4bit (bool, optional, defaults to False) — If True, will convert the loaded model into 4bit precision quantized model. To use this feature install the latest version of bitsandbytes (pip install -U bitsandbytes).
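Note that recent transformers releases steer you toward passing an explicit quantization_config instead of these boolean flags. A minimal 4-bit sketch using BitsAndBytesConfig (the model ID is again just an example) looks like this:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

base_model = "meta-llama/Llama-2-7b-hf"  # example model ID

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # weights stay 4-bit; compute runs in fp16
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)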
