Some PyTorch models are saved in the .safetensors format.
In this tutorial, we will show how to save and load .safetensors files.
Save PyTorch model weights to .safetensors
Here is the example code:
```python
import torch
from safetensors.torch import save_file

tensors = {
    "embedding": torch.zeros((2, 2)),
    "attention": torch.zeros((2, 3)),
}
save_file(tensors, "model.safetensors")
```
Here, tensors is a dict that maps each weight name to its tensor. For a real model, we can get such a dict with model.state_dict().
Load a .safetensors model file in PyTorch
We cannot use torch.load() to load a .safetensors file, because it is not a pickle-based checkpoint. Instead, we can read it with safe_open():
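```python
from safetensors import safe_open

tensors = {}
# framework="pt" returns PyTorch tensors; device="cpu" loads them into CPU memory.
# Use device=0 instead to load them directly onto the first GPU.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)
```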
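Here, model.safetensors is an existing .safetensors file. If it was saved from model.state_dict(), the resulting tensors dict holds the same name-to-tensor mapping, so it can be passed to model.load_state_dict(tensors).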