PyTorch Save Models with Limited Model Number – PyTorch Tutorial

By | December 21, 2023

To save a PyTorch model, we can use the torch.save() function. Here is a tutorial:

Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial

However, we often save models periodically during training. For example, if you save a model every 1,000 steps, a long training run will leave many model files on disk.

How can we limit the number of saved models? In this tutorial, we will use an example to show you how to do it.

Here is the example code:

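import os
import glob


def limit_model_count(model_dir_path, regex="G_*.pth", max_model_num=5):
    # Accept either a directory or the path of a model file inside it
    if os.path.isfile(model_dir_path):
        model_dir_path = os.path.dirname(model_dir_path)

    def step_of(path):
        # Treat the digits in the file name as the training step, e.g. G_1000.pth -> 1000
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else None

    # Get all models matching the pattern (despite its name, `regex` is a glob pattern
    # such as "G_*.pth"); skip files without a step number (e.g. G_best.pth) so that
    # int("") cannot raise ValueError and such files are never deleted
    model_list = [f for f in glob.glob(os.path.join(model_dir_path, regex)) if step_of(f) is not None]
    # Newest (largest step) first
    model_list.sort(key=step_of, reverse=True)
    # Delete everything beyond the newest max_model_num models
    for f in model_list[max_model_num:]:
        os.remove(f)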
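Why sort with a numeric key instead of plain string order? Python compares strings character by character, so "G_200.pth" ranks above "G_1000.pth". The short demo below uses a hypothetical step_key helper that, like the key above, joins the digits it finds (here only in the file name) into one integer, and shows the difference:

```python
import os


def step_key(path):
    # Hypothetical helper: join the digits in the file name into one integer step
    return int("".join(filter(str.isdigit, os.path.basename(path))))


files = ["log_emotion/G_1000.pth", "log_emotion/G_200.pth", "log_emotion/G_30000.pth"]

# Plain string sorting compares character by character, so G_200 lands above G_1000
print(sorted(files, reverse=True))
# ['log_emotion/G_30000.pth', 'log_emotion/G_200.pth', 'log_emotion/G_1000.pth']

# Sorting by the integer step puts the newest models first
print(sorted(files, key=step_key, reverse=True))
# ['log_emotion/G_30000.pth', 'log_emotion/G_1000.pth', 'log_emotion/G_200.pth']
```

With string order, limit_model_count() would keep G_200.pth and delete the newer G_1000.pth.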

We can call this function right after saving each new model, so that only the newest max_model_num models are kept.
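One way to wire save-then-prune into a training loop is sketched below. This is an assumption-laden variant, not the tutorial's code: prune_checkpoints, train and fake_save are hypothetical names, the step number is extracted with a real regular expression instead of by joining every digit, and fake_save stands in for torch.save so the sketch runs without PyTorch.

```python
import os
import re
import tempfile


def prune_checkpoints(model_dir, prefix="G", keep=5):
    """Keep the `keep` files named <prefix>_<step>.pth with the largest steps; return deleted paths."""
    # Match the file name only and capture the step number with a regular expression
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)\.pth$")
    checkpoints = []
    for name in os.listdir(model_dir):
        match = pattern.match(name)
        if match:
            checkpoints.append((int(match.group(1)), os.path.join(model_dir, name)))
    # Largest step (newest) first
    checkpoints.sort(reverse=True)
    removed = []
    for _, path in checkpoints[keep:]:
        os.remove(path)
        removed.append(path)
    return removed


def train(model_dir, total_steps, save_every, save_fn, keep=5):
    """Hypothetical training loop that saves G and D every `save_every` steps, then prunes."""
    for step in range(1, total_steps + 1):
        # ... forward pass, loss.backward() and optimizer.step() would go here ...
        if step % save_every == 0:
            for prefix in ("G", "D"):
                save_fn(os.path.join(model_dir, f"{prefix}_{step}.pth"))
                prune_checkpoints(model_dir, prefix=prefix, keep=keep)


def fake_save(path):
    # Stand-in for torch.save(model.state_dict(), path) so the sketch runs without PyTorch
    with open(path, "wb") as f:
        f.write(b"")


with tempfile.TemporaryDirectory() as model_dir:
    train(model_dir, total_steps=10_000, save_every=1_000, save_fn=fake_save, keep=3)
    print(sorted(os.listdir(model_dir)))
    # ['D_10000.pth', 'D_8000.pth', 'D_9000.pth', 'G_10000.pth', 'G_8000.pth', 'G_9000.pth']
```

In real training, save_fn could be, for example, `lambda path: torch.save(model.state_dict(), path)`. Pruning each prefix separately is a design choice. The tutorial's "[GD]_*.pth" call counts G and D files together, so max_model_num = 10 keeps about 5 steps of each.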

For example:

limit_model_count("log_emotion/", regex = "[GD]_*.pth", max_model_num = 10)

After running this code, at most 10 files matching G_*.pth or D_*.pth remain in log_emotion/. These are the files with the largest step numbers; older models are deleted.