PyTorch Load and Save .safetensors – PyTorch Tutorial

By | March 18, 2024

Some PyTorch models are saved as .safetensors files. For example:

[Figure: a .safetensors file in PyTorch]

In this tutorial, we will show how to save and load .safetensors files in PyTorch.

Save PyTorch model weights to .safetensors

Here is the example code:

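import torch
from safetensors.torch import save_file

# Each key is a tensor name, each value is the tensor stored under that name
tensors = {
    "embedding": torch.zeros((2, 2)),
    "attention": torch.zeros((2, 3))
}
# Write all tensors into a single .safetensors file
save_file(tensors, "model.safetensors")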

Here, tensors holds all the weights of a model as a dict of name-to-tensor pairs. For a real model, we can get this dict with model.state_dict().

Understand PyTorch model.state_dict() – PyTorch Tutorial

Load a .safetensors model file in PyTorch

We cannot use torch.load() to load a .safetensors file. Instead, we can read it with safe_open() from the safetensors library:

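from safetensors import safe_open

tensors = {}
# framework="pt" returns PyTorch tensors; device="cpu" loads them into RAM
# (use device=0 or "cuda:0" to load them directly onto the first GPU)
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)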

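Here, model.safetensors is an existing .safetensors file. If the file was saved from model.state_dict(), the resulting tensors dict has the same keys and values as model.state_dict(), so it can be passed to model.load_state_dict().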