When training a GAN or any other network without a validation set, it is hard to know which checkpoint to keep. This tutorial introduces a simple way to handle that.
Averaging the weights of several saved checkpoints can improve inference performance. We will implement this strategy in PyTorch.
Here is an example:
import torch

def average_model(average_model_file, model_path_list):
    avg = None
    num = len(model_path_list)
    for path in model_path_list:
        print('Processing {}'.format(path))
        # load model
        states = torch.load(path, map_location=torch.device('cpu'))
        states = states['model'] if 'model' in states else states
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(average_model_file))
    torch.save(avg, average_model_file)
    return average_model_file
Then we can use it as follows:
average_model_file = "final_model.pt"
model_path_list = ["100.pt", "200.pt", "300.pt"]
average_model(average_model_file, model_path_list)
Finally, we can use final_model.pt for inference. Note that the file contains the averaged state dict itself, not a dictionary with a 'model' key.
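To make the last step concrete, here is a minimal, self-contained sketch of the whole flow: save a few checkpoints, average them, and load the result into a model for inference. It is an illustration under assumptions, not the only way to do it. The make_model helper and its tiny nn.Linear network are hypothetical stand-ins for your real model, the checkpoints are written to a temporary directory, and average_state_dicts is a compact variant of the function above that casts every tensor to float32 before summing.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical tiny network standing in for the real model.
    return nn.Linear(4, 2)


def average_state_dicts(paths):
    """Average the tensors stored under each key across checkpoint files."""
    avg = None
    for path in paths:
        states = torch.load(path, map_location="cpu")
        if avg is None:
            # Assumption: cast to float32 and copy, so the running sum
            # does not modify the loaded tensors in place.
            avg = {k: v.clone().to(torch.float32) for k, v in states.items()}
        else:
            for k in avg:
                avg[k] += states[k]
    return {k: v / len(paths) for k, v in avg.items()}


with tempfile.TemporaryDirectory() as tmp:
    # Save three differently initialized checkpoints to average.
    paths = []
    for i in range(3):
        torch.manual_seed(i)
        path = os.path.join(tmp, "{}.pt".format(i))
        torch.save(make_model().state_dict(), path)
        paths.append(path)

    averaged = average_state_dicts(paths)

    # Reference value: the plain mean of the three saved weight matrices.
    expected_weight = torch.stack(
        [torch.load(p, map_location="cpu")["weight"] for p in paths]
    ).mean(dim=0)

# Load the averaged weights into a fresh model and run inference.
model = make_model()
model.load_state_dict(averaged)
model.eval()
with torch.no_grad():
    output = model(torch.ones(1, 4))
```

With a real project, replace make_model with your own model class and pass the averaged file from torch.load to load_state_dict in the same way. The architecture must match the one that produced the checkpoints, otherwise load_state_dict will report missing or unexpected keys.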