When you plan to use PyTorch to build an AI model, you should know how to save and load a PyTorch model. In this tutorial, we will show you how to do both.
Save a PyTorch model
We can use torch.save(object, PATH) to save an object to PATH.
For example:
- torch.save(model.state_dict(), PATH)
This example only saves the model's state_dict() to PATH.
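The save call above can be tried end to end with a tiny stand-in model; nn.Linear here is a hypothetical placeholder for your own model, and model_state.pth is just an example path:

```python
import torch
import torch.nn as nn

# Hypothetical tiny model standing in for your own nn.Module.
model = nn.Linear(4, 2)

# state_dict() is a dict mapping parameter names to tensors.
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# Save only the learnable parameters to disk.
PATH = 'model_state.pth'
torch.save(model.state_dict(), PATH)
```

Saving the state_dict() rather than the whole model keeps the file independent of your class definitions, which makes it the more portable option.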
However, we can also save more information to a file.
For example:
- state = {'epoch': epoch,
- 'epochs_since_improvement': epochs_since_improvement,
- 'loss': loss,
- 'model': model,
- 'optimizer': optimizer}
- filename = 'checkpoint.tar'
- torch.save(state, filename)
Here model is a PyTorch model object.
In this example, we save the epoch, the number of epochs since improvement, the loss, the PyTorch model, and the optimizer to the file checkpoint.tar.
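A runnable sketch of this checkpoint pattern, using a hypothetical nn.Linear model and SGD optimizer in place of the real ones:

```python
import torch
import torch.nn as nn

# Hypothetical model and optimizer standing in for your own.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

state = {'epoch': 10,                      # example values
         'epochs_since_improvement': 2,
         'loss': 0.123,
         'model': model,                   # the whole model object is pickled
         'optimizer': optimizer}           # the whole optimizer object is pickled
filename = 'checkpoint.tar'
torch.save(state, filename)
```

Because the model and optimizer objects are pickled whole, the checkpoint can only be loaded where their class definitions are importable; saving model.state_dict() and optimizer.state_dict() instead avoids that coupling.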
Load a PyTorch model
In PyTorch, we can use the torch.load() function to load a saved object.
As mentioned above, if we only saved a model's state_dict(), we can load the model as follows:
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- checkpoint = 'checkpoint.tar'
- print('loading model: {}...'.format(checkpoint))
- model = Tacotron2(HParams()) # create your pytorch model object
- state = torch.load(checkpoint, map_location=device)
- print(state)
- model.load_state_dict(state)  # the file contains only the state_dict
- model = model.to(device)
To load a model this way, we should:
(1) create a PyTorch model first
- model = Tacotron2(HParams()) # create your pytorch model object
(2) use torch.load() method to load
- state = torch.load(checkpoint, map_location=device)
- print(state)
- model.load_state_dict(state)
(3) move the model to the CPU or GPU
- model = model.to(device)
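The three steps combine into the following runnable sketch; TinyModel is a hypothetical stand-in for Tacotron2(HParams()), and the state_dict is saved first so the example is self-contained:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for your own model class (e.g. Tacotron2(HParams())).
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Pretend a trained model's state_dict was saved earlier.
trained = TinyModel()
torch.save(trained.state_dict(), 'checkpoint.tar')

# (1) create the model, (2) load the saved state_dict, (3) move to device
model = TinyModel()
state = torch.load('checkpoint.tar', map_location=device)
model.load_state_dict(state)
model = model.to(device)
model.eval()  # switch to inference mode before using the model
```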
However, if we have saved a PyTorch model object, we can load the model object directly.
For example:
- checkpoint = torch.load('checkpoint.tar', map_location=device)
- model = checkpoint['model']
- model = model.to(device)
Here checkpoint['model'] is a PyTorch model object.
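A self-contained sketch of this direct-loading approach, again with a hypothetical tiny model; note that recent PyTorch versions require weights_only=False in torch.load() to unpickle whole objects:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Save a checkpoint holding the whole model object (hypothetical tiny model).
torch.save({'model': nn.Linear(4, 2)}, 'checkpoint.tar')

# Load the model object directly -- no need to construct the model first,
# but its class definition must still be importable at load time.
checkpoint = torch.load('checkpoint.tar', map_location=device,
                        weights_only=False)  # required on PyTorch >= 2.6
model = checkpoint['model']
model = model.to(device)
```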
After we have loaded an existing model, we can use it easily.