Save and Load Model in PyTorch: A Complete Guide – PyTorch Tutorial

By | April 28, 2022

When you plan to use PyTorch to build an AI model, you should know how to save and load a PyTorch model. In this tutorial, we will show you how to do it.

Save a PyTorch model

We can use torch.save(object, PATH) to save an object to PATH.

For example:

torch.save(model.state_dict(), PATH)

This example only saves a PyTorch model's state_dict() to PATH.
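If we save only the state_dict() like this, we can restore it later into a model with the same architecture. Here is a minimal sketch, assuming model is an already-created instance of that architecture and PATH points to the saved file:

import torch

# Assumption: `model` is an instance of the same architecture that produced PATH
state_dict = torch.load(PATH, map_location='cpu')  # load the saved parameters
model.load_state_dict(state_dict)                  # copy them into the model
model.eval()  # switch to evaluation mode for inference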

However, we can also save more information to a file.

For example:

state = {'epoch': epoch,
         'epochs_since_improvement': epochs_since_improvement,
         'loss': loss,
         'model': model,
         'optimizer': optimizer}

filename = 'checkpoint.tar'
torch.save(state, filename)

Here model is a PyTorch model object.

In this example, we will save the epoch, the loss, the PyTorch model and the optimizer to a checkpoint.tar file.
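A common variant is to store state_dict() objects instead of the whole model and optimizer objects, so the checkpoint does not depend on pickling the full classes. A minimal sketch, assuming epoch, loss, model and optimizer already exist:

state = {'epoch': epoch,
         'loss': loss,
         'model_state_dict': model.state_dict(),          # model weights only
         'optimizer_state_dict': optimizer.state_dict()}  # optimizer state only

torch.save(state, 'checkpoint.tar')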

Load a PyTorch model

In PyTorch, we can use the torch.load() function to load a saved model.

Using the checkpoint.tar file saved above, we can create a model object first and then restore its weights as follows:

import torch

# pick the device to load the model onto
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

checkpoint = 'checkpoint.tar'
print('loading model: {}...'.format(checkpoint))
model = Tacotron2(HParams())  # create your pytorch model object
state = torch.load(checkpoint, map_location=device)
print(state)  # inspect what was saved in the checkpoint
model.load_state_dict(state['model'].state_dict())
model = model.to(device)

To load a model in this way, we should:

(1) create a PyTorch model object first

model = Tacotron2(HParams()) # create your pytorch model object

(2) use the torch.load() function to load the saved checkpoint and restore the model weights

state = torch.load(checkpoint, map_location=device)
print(state)
model.load_state_dict(state['model'].state_dict())

(3) move the model to the CPU or GPU

model = model.to(device)

However, if we have saved a whole PyTorch model object, as in the checkpoint.tar example above, we can load the model object directly.

For example:

checkpoint = torch.load(checkpoint)
model = checkpoint['model']
model = model.to(device)

Here checkpoint['model'] is a PyTorch model object.
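Because the checkpoint also stores the epoch and the optimizer, we can use the same file to resume training. A minimal sketch, assuming the checkpoint.tar saved above:

checkpoint = torch.load('checkpoint.tar', map_location=device)
model = checkpoint['model'].to(device)  # the saved model object
optimizer = checkpoint['optimizer']     # the saved optimizer object
start_epoch = checkpoint['epoch'] + 1   # continue from the next epoch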

After we have loaded an existing model, we can use it for inference or continue training.
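For example, we can run a forward pass for inference. A minimal sketch, where example_input is a hypothetical tensor whose shape depends on your model:

model.eval()  # disable dropout and batch norm updates
with torch.no_grad():  # no gradients are needed for inference
    output = model(example_input)  # example_input: a tensor shaped for your model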
