How to Save Best Model? Average Multiple Models for Inference – Pytorch Tutorial

By | October 26, 2023

When we train a GAN or another network without a validation set, which checkpoint should we save as the final model? In this tutorial, we will introduce a simple answer: average the weights of several saved checkpoints.

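The paper "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time" shows that averaging the weights of multiple fine-tuned models can improve accuracy. Because the result is still a single model, inference costs no more than with one model. We will implement a simple version of this strategy in PyTorch: save several checkpoints during training, average their weights, and use the averaged model for inference.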
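To see what the averaging does before any files are involved, here is a minimal sketch that works on in-memory models. It assumes three small nn.Linear models with constant weights 1, 2 and 3 as stand-ins for checkpoints (a hypothetical setup for illustration only). The helper average_state_dicts, also hypothetical, computes the element-wise mean of every tensor, i.e. theta_avg = (theta_1 + theta_2 + theta_3) / 3, and the result is loaded into a fresh model with the same architecture.

```python
import torch
import torch.nn as nn


def average_state_dicts(state_dicts):
    """Return the element-wise mean of a list of state_dicts with identical keys."""
    avg = {}
    for k in state_dicts[0]:
        # stack the same tensor from every model, then average over the new first dimension
        stacked = torch.stack([sd[k].float() for sd in state_dicts])
        avg[k] = stacked.mean(dim=0)
    return avg


# three toy "checkpoints" whose weights and biases are all 1, 2 and 3 (hypothetical stand-ins)
models = []
for value in (1.0, 2.0, 3.0):
    m = nn.Linear(4, 2)
    with torch.no_grad():
        m.weight.fill_(value)
        m.bias.fill_(value)
    models.append(m)

avg_state = average_state_dicts([m.state_dict() for m in models])

# load the averaged weights into a fresh model of the same architecture
soup = nn.Linear(4, 2)
soup.load_state_dict(avg_state)
print(soup.weight[0, 0].item())  # 2.0
```

Averaging only makes sense when all models share the same architecture, so every state_dict has the same keys and tensor shapes.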

Here is an example:

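import torch


def average_model(average_model_file, model_path_list):
    """Average the weights of several checkpoints and save them to one file."""
    avg = None
    num = len(model_path_list)
    for path in model_path_list:
        print('Processing {}'.format(path))
        # load the checkpoint onto the CPU so no GPU is needed
        states = torch.load(path, map_location=torch.device('cpu'))
        # unwrap checkpoints saved as {'model': state_dict, ...}
        states = states['model'] if 'model' in states else states

        if avg is None:
            avg = states
        else:
            # accumulate the sum of every tensor, key by key
            for k in avg.keys():
                avg[k] += states[k]
    # divide each summed tensor by the number of checkpoints
    for k in avg.keys():
        if avg[k] is not None:
            # in PyTorch 1.6 and later, /= fails on integer tensors (e.g. BatchNorm's
            # num_batches_tracked), so use true_divide, which returns a float tensor
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(average_model_file))
    torch.save(avg, average_model_file)
    return average_model_file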

Then we can use it as follows:

average_model_file = "final_model.pt"
model_path_list = ["100.pt", "200.pt", "300.pt"]
average_model(average_model_file, model_path_list)
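The checkpoint list above is written by hand. One way to produce and collect such files is sketched below, assuming checkpoints are saved every 100 steps as {'model': state_dict} under names like 100.pt (the format that average_model unwraps). Both helpers, save_checkpoint and latest_checkpoints, are hypothetical names for illustration, and the training loop is a placeholder.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, step, ckpt_dir):
    """Save {'model': state_dict} as <ckpt_dir>/<step>.pt (hypothetical naming scheme)."""
    path = os.path.join(ckpt_dir, '{}.pt'.format(step))
    torch.save({'model': model.state_dict()}, path)
    return path


def latest_checkpoints(ckpt_dir, k):
    """Return paths of the k checkpoints with the largest step numbers, oldest first."""
    steps = []
    for name in os.listdir(ckpt_dir):
        stem, ext = os.path.splitext(name)
        if ext == '.pt' and stem.isdigit():
            steps.append(int(stem))
    # sort numerically so that "1000.pt" comes after "300.pt"
    steps.sort()
    return [os.path.join(ckpt_dir, '{}.pt'.format(s)) for s in steps[-k:]]


# usage: pretend to train for 500 steps, saving a checkpoint every 100 steps
ckpt_dir = tempfile.mkdtemp()
model = nn.Linear(4, 2)
for step in range(1, 501):
    # a real optimizer step would go here
    if step % 100 == 0:
        save_checkpoint(model, step, ckpt_dir)

model_path_list = latest_checkpoints(ckpt_dir, 3)
print([os.path.basename(p) for p in model_path_list])  # ['300.pt', '400.pt', '500.pt']
```

The returned list can be passed directly as model_path_list. Averaging the last few checkpoints is a common choice because they come from the same, late stage of training.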

Finally, we can load final_model.pt into the model and use it for inference.
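A minimal sketch of that last step is shown below. It assumes the averaged file holds a plain state_dict (which is what average_model saves) and uses nn.Linear as a stand-in architecture; load_averaged_model is a hypothetical helper name, and in practice you would pass the same network class that produced the checkpoints.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_averaged_model(model, average_model_file):
    """Load averaged weights into `model` and switch it to inference mode."""
    state = torch.load(average_model_file, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    # eval() disables dropout and makes BatchNorm use its stored running statistics
    model.eval()
    return model


# usage with a stand-in file; normally final_model.pt comes from average_model
average_model_file = os.path.join(tempfile.mkdtemp(), 'final_model.pt')
torch.save(nn.Linear(4, 2).state_dict(), average_model_file)

model = load_averaged_model(nn.Linear(4, 2), average_model_file)
with torch.no_grad():
    output = model(torch.randn(1, 4))
print(output.shape)  # torch.Size([1, 2])
```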
