Convert TensorFlow Pretrained Bert Model to PyTorch Model – PyTorch Tutorial

By | June 6, 2022

Pretrained BERT models are often released as TensorFlow checkpoints, which cannot be loaded directly in PyTorch. In this tutorial, we will show you how to convert a TensorFlow pretrained BERT model to a PyTorch model, so that you can load and use BERT in PyTorch.

Tensorflow Pretrained Bert Model

We will use the TensorFlow chinese_L-12_H-768_A-12 pretrained BERT model in this tutorial.

Its directory structure is:

(Figure: Tensorflow Pretrained Bert Model)
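Before converting, it can help to check that the model directory is complete. The sketch below is only an illustration: the helper name missing_checkpoint_files is hypothetical, and the list of expected file names is an assumption based on how Google's BERT checkpoints are usually packaged (bert_config.json, vocab.txt, and bert_model.ckpt.* files).

```python
import os

# Assumed file names; adjust them if your download looks different.
EXPECTED_FILES = (
    "bert_config.json",
    "vocab.txt",
    "bert_model.ckpt.index",
)


def missing_checkpoint_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in model_dir."""
    return [
        name
        for name in expected
        if not os.path.isfile(os.path.join(model_dir, name))
    ]


# Example call (path taken from this tutorial):
# print(missing_checkpoint_files("./bert/chinese_L-12_H-768_A-12"))
```

An empty list means all the files the helper looks for are present.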

How to convert a TensorFlow BERT model to a PyTorch model?

The conversion is easy: we can use the Hugging Face transformers library to do it.

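Step 1: Install PyTorch

Step 2: Install TensorFlow

In this tutorial, we will use TensorFlow 1.14. The command below downloads the package from the Aliyun mirror into a local directory; you still need to install it afterwards with pip install.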

pip download -i https://mirrors.aliyun.com/pypi/simple/ tensorflow==1.14.0 --trusted-host mirrors.aliyun.com -d D:\python-packages\tensorflow\tensorflow

Step 3: Install transformers

If you get a PEP 517 error, you can read:

Fix ERROR: Could not build wheels for tokenizers which use PEP 517 and cannot be installed directly – Bert Tutorial
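Once steps 1 to 3 are done, a quick check that the three packages can be found by Python may save time before running the conversion. This is a minimal sketch that uses only the standard library; the helper name missing_packages is hypothetical.

```python
import importlib.util


def missing_packages(names=("torch", "tensorflow", "transformers")):
    """Return the package names that Python cannot find in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# An empty list means all three packages are installed.
print(missing_packages())
```

The check only looks the packages up; it does not import them, so it runs quickly even when TensorFlow is installed.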

Step 4: Use convert_bert_original_tf_checkpoint_to_pytorch.py to convert

We can edit the default arguments in this file to run the conversion. For example:

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Parameters (required=False, so the defaults below are used when no argument is passed)
    parser.add_argument(
        "--tf_checkpoint_path", default="./bert/chinese_L-12_H-768_A-12/bert_model.ckpt", type=str, required=False, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default="./bert/chinese_L-12_H-768_A-12/bert_config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default="./pytorch_bert/pytorch_bert_model.bin", type=str, required=False, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

In this code, notice that:

(1) ./bert/chinese_L-12_H-768_A-12/bert_model.ckpt is the TensorFlow BERT checkpoint (the common prefix of the bert_model.ckpt.* files).

(2) pytorch_bert_model.bin is the converted PyTorch BERT model, which you can use in PyTorch.
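As an alternative to editing the script, the same function can be called from Python. The sketch below assumes a transformers 4.x release in which the script is importable as transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch (older releases keep it elsewhere, so treat the import path as an assumption). The helpers build_conversion_args and convert are hypothetical names. The sketch also creates the output directory first, because saving the .bin file fails if ./pytorch_bert does not exist.

```python
import os


def build_conversion_args(tf_dir, out_dir, out_name="pytorch_bert_model.bin"):
    """Create out_dir and return (tf_checkpoint_path, bert_config_file, pytorch_dump_path)."""
    os.makedirs(out_dir, exist_ok=True)
    return (
        os.path.join(tf_dir, "bert_model.ckpt"),
        os.path.join(tf_dir, "bert_config.json"),
        os.path.join(out_dir, out_name),
    )


def convert(tf_dir, out_dir):
    """Convert the TensorFlow checkpoint in tf_dir and return the output path."""
    # Imported lazily: needs transformers, torch and tensorflow installed.
    # The module path is an assumption that holds for transformers 4.x.
    from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch,
    )

    ckpt, config_file, dump_path = build_conversion_args(tf_dir, out_dir)
    convert_tf_checkpoint_to_pytorch(ckpt, config_file, dump_path)
    return dump_path


# Example call (paths taken from this tutorial):
# convert("./bert/chinese_L-12_H-768_A-12", "./pytorch_bert")
```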

Step 5: Copy bert_config.json and vocab.txt to pytorch_bert

We will see:

(Figure: files in the pytorch_bert directory)
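The copy in step 5 can be done by hand, or with a few lines of Python. This is a minimal sketch using the standard library; the helper name copy_support_files is hypothetical, and the paths in the example call are the ones used in this tutorial.

```python
import os
import shutil


def copy_support_files(src_dir, dst_dir, names=("bert_config.json", "vocab.txt")):
    """Copy the config and vocabulary files next to the converted model."""
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for name in names:
        target = os.path.join(dst_dir, name)
        shutil.copy2(os.path.join(src_dir, name), target)
        copied.append(target)
    return copied


# Example call:
# copy_support_files("./bert/chinese_L-12_H-768_A-12", "./pytorch_bert")
```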

Then, we can load it in PyTorch as follows:

from transformers import BertTokenizer
from transformers import BertModel, BertConfig

# The tokenizer reads vocab.txt from the pytorch_bert directory
tokenizer = BertTokenizer.from_pretrained('./pytorch_bert/')
config = BertConfig.from_json_file('./pytorch_bert/bert_config.json')
# Load the weights file written in step 4 (from_tf=False: it is a PyTorch file, not a TensorFlow checkpoint)
bert_model = BertModel.from_pretrained('./pytorch_bert/pytorch_bert_model.bin', from_tf=False, config=config)
print(bert_model)
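To go one step further than printing the model, the sketch below encodes one sentence and returns the hidden states of the last layer. It is only an illustration: the function name encode_sentence is hypothetical, the default weights file name assumes the output name from step 4, and the code needs torch and transformers installed plus the files prepared in steps 4 and 5.

```python
import os


def encode_sentence(text, model_dir="./pytorch_bert/", weights="pytorch_bert_model.bin"):
    """Return the last-layer hidden states of the converted BERT model for text."""
    # Imported lazily so that this file can be loaded without torch installed.
    import torch
    from transformers import BertConfig, BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(model_dir)
    config = BertConfig.from_json_file(os.path.join(model_dir, "bert_config.json"))
    model = BertModel.from_pretrained(os.path.join(model_dir, weights), config=config)
    model.eval()  # disable dropout for inference

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # The first element holds the last hidden states: (batch, tokens, hidden_size)
    return outputs[0]


# Example call:
# hidden = encode_sentence("今天天气很好")
# print(hidden.shape)
```

For this model the last dimension of the result should equal the hidden size in bert_config.json, which is 768 for an H-768 model.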
