When you plan to use PyTorch to build an AI model, you should know how to save and load a PyTorch model. In this tutorial, we will show you how to do it.
Save a PyTorch model
We can use torch.save(object, PATH) to save an object to PATH.
For example:
torch.save(model.state_dict(), PATH)
This example only saves the model's state_dict() (its learnable parameters) to PATH.
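As a minimal sketch, assuming a small hypothetical model class SmallNet and a file name model.pth (both are only illustrative, not from a real project), saving just the weights looks like this:

import torch
import torch.nn as nn

# a small hypothetical model, only for illustration
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = SmallNet()
# save only the learnable parameters (weights and biases)
torch.save(model.state_dict(), 'model.pth')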
However, we can also save more information in a single checkpoint file.
For example:
state = {'epoch': epoch, 'epochs_since_improvement': epochs_since_improvement, 'loss': loss, 'model': model, 'optimizer': optimizer}
filename = 'checkpoint.tar'
torch.save(state, filename)
Here model is a PyTorch model object.
In this example, we save the epoch, the number of epochs since the last improvement, the loss, the PyTorch model and the optimizer to the checkpoint.tar file.
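A common alternative, which the official PyTorch documentation recommends, is to store the state_dict() of the model and optimizer instead of the whole objects. A sketch of this, reusing the hypothetical SmallNet from above (the dictionary keys are also just illustrative):

import torch

model = SmallNet()                                   # hypothetical model from the sketch above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epoch = 10                                           # placeholder training-progress values
epochs_since_improvement = 2
loss = 0.123

state = {
    'epoch': epoch,
    'epochs_since_improvement': epochs_since_improvement,
    'loss': loss,
    'model_state_dict': model.state_dict(),          # weights only
    'optimizer_state_dict': optimizer.state_dict(),  # optimizer buffers (e.g. Adam moments)
}
torch.save(state, 'checkpoint.tar')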
Load a PyTorch model
In PyTorch, we can use the torch.load() function to load a saved object.
If the checkpoint contains the model (as in the checkpoint.tar example above), we can create a new model and copy the saved parameters into it as follows:
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

checkpoint = 'checkpoint.tar'
print('loading model: {}...'.format(checkpoint))

model = Tacotron2(HParams())  # create your PyTorch model object
state = torch.load(checkpoint, map_location=device)
print(state)  # optional: inspect the checkpoint contents
model.load_state_dict(state['model'].state_dict())
model = model.to(device)
To load a model in this way, we should:
(1) Create a PyTorch model object first
model = Tacotron2(HParams()) # create your pytorch model object
(2) Use torch.load() to load the checkpoint and restore the parameters with load_state_dict()
state = torch.load(checkpoint, map_location=device)
print(state)
model.load_state_dict(state['model'].state_dict())
(3) Move the model to the CPU or GPU
model = model.to(device)
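If the checkpoint instead stores plain state_dicts (as in the saving sketch earlier, with the hypothetical keys 'model_state_dict' and 'optimizer_state_dict'), the loading side would look roughly like this:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SmallNet()                                   # hypothetical model class from above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load('checkpoint.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1                # resume training from the next epoch
model = model.to(device)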
However, if we have saved a full PyTorch model object (as in the checkpoint example above), we can load it directly.
For example:
checkpoint = torch.load('checkpoint.tar', map_location=device)
model = checkpoint['model']
model = model.to(device)
Here checkpoint['model'] is a PyTorch model object.
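One caveat: recent PyTorch releases (2.6 and later) default torch.load() to weights_only=True, which refuses to unpickle arbitrary objects such as a whole model. If your checkpoint stores the model object itself and you trust the file, you may need to pass the flag explicitly, roughly like this:

# only do this for checkpoint files you trust
checkpoint = torch.load('checkpoint.tar', map_location=device, weights_only=False)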
After we have loaded an existing model, we can use it for inference or further training.
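For example, to run inference with the loaded model we would usually switch it to evaluation mode and disable gradient tracking; a minimal sketch with a dummy input (the input shape is only illustrative):

model.eval()                                  # switch dropout/batchnorm layers to eval behaviour
with torch.no_grad():                         # no gradients are needed for inference
    dummy_input = torch.randn(1, 10).to(device)
    output = model(dummy_input)
print(output)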