To save a PyTorch model, we can use the torch.save() function. Here is the tutorial:
Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial
However, we may want to save models periodically during training. For example, if you save a model every 1,000 steps, many model files will pile up on disk.
How can we limit the number of saved models? In this tutorial, we will use an example to show you how to do it.
Here is the example code:
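```python
import os
import glob

def limit_model_count(model_dir_path, regex="G_*.pth", max_model_num=5):
    # If a file path is passed, use the directory that contains it
    if os.path.isfile(model_dir_path):
        model_dir_path = os.path.dirname(model_dir_path)

    # Find all model files that match the pattern
    # (regex is a glob pattern such as "G_*.pth", not a regular expression)
    model_list = glob.glob(os.path.join(model_dir_path, regex))

    def get_step(path):
        # Use only the digits in the file name, so digits in the directory
        # name do not affect the order; names without digits sort last
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    # Sort by step number, newest (largest step) first
    model_list.sort(key=get_step, reverse=True)

    # Delete every model beyond the newest max_model_num
    for f in model_list[max_model_num:]:
        os.remove(f)
```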
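The function sorts by the number in each file name instead of by the file name itself. The short check below, using hypothetical checkpoint names of the form G_<step>.pth, shows why: plain string sorting compares character by character and would rank an older checkpoint above the newest one.

```python
# Hypothetical checkpoint names saved at steps 1000, 9000 and 10000
names = ["G_1000.pth", "G_9000.pth", "G_10000.pth"]

# Plain string sorting: "G_9000.pth" ranks above "G_10000.pth" because '9' > '1'
by_name = sorted(names, reverse=True)
print(by_name)  # ['G_9000.pth', 'G_10000.pth', 'G_1000.pth']

def step_of(name):
    # Join all digits in the name and read them as one integer
    return int("".join(filter(str.isdigit, name)))

# Numeric sorting: the largest step (the newest model) comes first
by_step = sorted(names, key=step_of, reverse=True)
print(by_step)  # ['G_10000.pth', 'G_9000.pth', 'G_1000.pth']
```

With the string order, keeping the first N files would delete the newest checkpoint, which is exactly what the numeric key avoids.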
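We can call this function right after saving a new model to keep the number of saved models under the limit.
For example, the call below keeps only the 10 files in log_emotion/ with the largest step numbers among those matching G_*.pth or D_*.pth: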
```python
limit_model_count("log_emotion/", regex="[GD]_*.pth", max_model_num=10)
```
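Because the regex argument is passed to glob.glob, it is a shell-style wildcard pattern: [GD] matches a single G or D character and * matches any run of characters. The sketch below applies the same matching rules with fnmatch.fnmatchcase to a hypothetical list of file names, then sorts the matches the same way as limit_model_count. It shows that G and D files share one limit, so max_model_num counts both kinds together.

```python
from fnmatch import fnmatchcase

# Hypothetical contents of a checkpoint directory
files = ["G_1000.pth", "D_1000.pth", "G_2000.pth", "D_2000.pth",
         "G_3000.pth", "D_3000.pth", "config.json", "eval_3000.pth"]

# Wildcard matching (case-sensitive): only names starting with "G_" or "D_"
# and ending with ".pth" match, so config.json and eval_3000.pth are skipped
matched = [f for f in files if fnmatchcase(f, "[GD]_*.pth")]

# Sort by the digits in each name, largest step first
matched.sort(key=lambda f: int("".join(filter(str.isdigit, f))), reverse=True)

# With a limit of 4, G and D files share the budget:
# the two newest steps survive for each prefix
kept = matched[:4]
print(kept)  # ['G_3000.pth', 'D_3000.pth', 'G_2000.pth', 'D_2000.pth']
```

If each prefix should have its own limit, calling limit_model_count once with "G_*.pth" and once with "D_*.pth" would keep them separate.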
After running this code, we will see: