Save and Load Model in PyTorch: A Complete Guide – PyTorch Tutorial

By | April 28, 2022

When you plan to use PyTorch to build an AI model, you should know how to save and load it. In this tutorial, we will show you how to do it.

Save a PyTorch model

We can use torch.save(object, PATH) to save an object to PATH.

For example:

torch.save(model.state_dict(), PATH)

This example saves only the model's state_dict(), i.e. its learned parameters and buffers, to PATH.
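To see what a state_dict actually holds, here is a minimal sketch. It is a dictionary that maps each parameter and buffer name to a tensor, with no model code in it. The small nn.Sequential model and the temporary PATH are assumptions for illustration; substitute your own model and path.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A tiny stand-in model (an assumption); replace it with your own nn.Module.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# The state_dict maps names such as "0.weight" to tensors.
# ReLU has no parameters, so index 1 does not appear.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

PATH = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
torch.save(model.state_dict(), PATH)

# Loading needs a model with the same architecture to receive the weights.
restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
restored.load_state_dict(torch.load(PATH, map_location="cpu"))
```

Because the file stores only tensors, you must build a model with the same architecture before you can load it.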

We can also save more information in a single file, such as a full training checkpoint.

For example:

state = {'epoch': epoch,
         'epochs_since_improvement': epochs_since_improvement,
         'loss': loss,
         'model': model,
         'optimizer': optimizer}

filename = 'checkpoint.tar'
torch.save(state, filename)

Here model is a PyTorch model object (an nn.Module instance), not just its weights.

In this example, we save the epoch, epochs_since_improvement, the loss, the model and the optimizer to the file checkpoint.tar.
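Pickling whole model and optimizer objects ties the checkpoint to the exact class definitions and module paths. A common alternative, recommended in the official PyTorch saving/loading tutorial, is to store their state_dict()s instead. The sketch below uses a stand-in nn.Linear model, an SGD optimizer, made-up epoch values and the key names model_state_dict / optimizer_state_dict (all assumptions) to write such a checkpoint and resume from it.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for your own model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

state = {
    "epoch": 3,  # illustrative values
    "epochs_since_improvement": 0,
    "loss": loss.item(),
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
filename = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save(state, filename)

# Resuming: rebuild the objects first, then restore their saved states.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(filename, map_location="cpu")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
```

Since this checkpoint holds only tensors and plain Python values, it also loads under weights_only=True, the torch.load default since PyTorch 2.6.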

Load a PyTorch model

In PyTorch, we use the torch.load() function to load a saved file.

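To restore weights, we first create a model object and then call load_state_dict() on it. The code below loads the checkpoint.tar file saved above: because its 'model' entry is a whole model object, we take that object's state_dict() and copy it into the new model. (If PATH holds only a state_dict, as in the first example, pass torch.load(PATH, map_location=device) to load_state_dict() directly.)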

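import torch

# Use the GPU if one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

checkpoint = 'checkpoint.tar'
print('loading model: {}...'.format(checkpoint))
model = Tacotron2(HParams()) # create your pytorch model object (Tacotron2/HParams are the author's own classes)
# map_location places the loaded tensors on the chosen device.
# weights_only=False (PyTorch 1.13+) is required on PyTorch >= 2.6,
# because this file contains pickled model and optimizer objects.
state = torch.load(checkpoint, map_location=device, weights_only=False)
print(state)
model.load_state_dict(state['model'].state_dict())
model = model.to(device)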

This approach takes three steps:

(1) Create the PyTorch model object first.

model = Tacotron2(HParams()) # create your pytorch model object

(2) Use torch.load() to load the checkpoint, then copy the weights into the model with load_state_dict().

state = torch.load(checkpoint, map_location=device, weights_only=False)
print(state)
model.load_state_dict(state['model'].state_dict())

(3) Move the model to the CPU or GPU.

model = model.to(device)
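The three steps can be sketched end to end as follows. A hypothetical TinyNet class stands in for Tacotron2(HParams()), and for brevity the file holds only a state_dict, so load_state_dict() receives the loaded dictionary directly.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a model class such as Tacotron2."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.fc = nn.Linear(4, hidden)
        self.out = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.out(torch.relu(self.fc(x)))


# Pretend an earlier run saved only the weights.
PATH = os.path.join(tempfile.mkdtemp(), "tinynet.pth")
torch.save(TinyNet().state_dict(), PATH)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# (1) Create a model object with the same architecture.
model = TinyNet()
# (2) Load the weights; map_location remaps tensors saved on another device.
state_dict = torch.load(PATH, map_location=device)
model.load_state_dict(state_dict)
# (3) Move the model to the chosen device.
model = model.to(device)
```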

If we saved the whole model object instead, we can take it out of the checkpoint directly.

For example:

checkpoint = torch.load('checkpoint.tar', map_location=device, weights_only=False)
model = checkpoint['model']
model = model.to(device)

Here checkpoint['model'] is a PyTorch model object. Because it was pickled, its class (e.g. Tacotron2) must be importable when the file is loaded.
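A minimal round trip of a whole model object might look like this. It assumes PyTorch 1.13 or later (where torch.load accepts weights_only) and uses a hypothetical TinyNet class; pickle stores only a reference to that class, so it must be defined or importable at load time.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model class; it must be importable at load time."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


filename = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({"epoch": 1, "model": TinyNet()}, filename)

# Whole nn.Module objects are pickled, so weights_only must be False
# (its default became True in PyTorch 2.6). Only load trusted files this way.
checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
model = checkpoint["model"]
print(type(model).__name__)
```

Unpickling can run arbitrary code, so weights_only=False should only be used for files you created or trust.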

Once the model is loaded, we can use it for inference or continue training it.
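For inference, it is usual to switch the model to evaluation mode and turn off gradient tracking. A minimal sketch, with a stand-in nn.Sequential (an assumption) in place of a loaded model:

```python
import torch
import torch.nn as nn

# Stand-in for a model restored with one of the methods above.
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))

model.eval()  # dropout and batch-norm layers switch to inference behaviour
with torch.no_grad():  # no autograd graph is built for prediction
    x = torch.randn(3, 4)
    y1 = model(x)
    y2 = model(x)  # identical to y1, since dropout is disabled in eval mode

print(y1.shape)  # torch.Size([3, 2])
```

Call model.train() again before resuming training, so that dropout and batch normalization return to training behaviour.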
