Pretrained BERT models are usually trained in TensorFlow, and their checkpoints cannot be loaded directly in PyTorch. In this tutorial, we will show you how to convert a TensorFlow pretrained BERT model to a PyTorch model. Then, you can load and use BERT in PyTorch.
TensorFlow Pretrained BERT Model
We will use the TensorFlow chinese_L-12_H-768_A-12 pretrained BERT model in this tutorial.
Its structure is:
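If you unzip the official release, the folder typically looks like this (a sketch of the usual layout; exact file names can vary slightly between releases):

chinese_L-12_H-768_A-12/
    bert_model.ckpt.data-00000-of-00001
    bert_model.ckpt.index
    bert_model.ckpt.meta
    bert_config.json
    vocab.txt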
How to convert a TensorFlow BERT model to a PyTorch model?
It is easy to convert: we can use the Hugging Face transformers library to do it.
Step 1: Install PyTorch
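For example, a plain pip install is usually enough (a minimal sketch; pick the build matching your CUDA setup from the official PyTorch instructions if you need GPU support):

pip install torch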
Step 2: Install TensorFlow
In this tutorial, we will use TensorFlow 1.14:
pip download -i https://mirrors.aliyun.com/pypi/simple/ tensorflow==1.14.0 --trusted-host mirrors.aliyun.com -d D:\python-packages\tensorflow\tensorflow
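Note that pip download only saves the package files into that directory; to actually install TensorFlow from them, you can run something like this (a sketch, assuming the download directory above):

pip install --no-index --find-links=D:\python-packages\tensorflow\tensorflow tensorflow==1.14.0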
Step 3. Install transformers
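For example (a typical install command; you can add the same mirror options as in Step 2 if needed):

pip install transformers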
If you get a PEP 517 error, you can read:
Step 4. Use convert_bert_original_tf_checkpoint_to_pytorch.py to convert
This script ships with the transformers source code (in recent releases it is under src/transformers/models/bert/). We can edit its default argument values and then run it. For example:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path",
        default="./bert/chinese_L-12_H-768_A-12/bert_model.ckpt",
        type=str,
        required=False,
        help="Path to the TensorFlow checkpoint path.",
    )
    parser.add_argument(
        "--bert_config_file",
        default="./bert/chinese_L-12_H-768_A-12/bert_config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default="./pytorch_bert/pytorch_bert_model.bin",
        type=str,
        required=False,
        help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
In this code, we should notice:
(1) ./bert/chinese_L-12_H-768_A-12/bert_model.ckpt is the TensorFlow BERT checkpoint.
(2) pytorch_bert_model.bin is the converted PyTorch BERT model; this is the file you will load in PyTorch.
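Alternatively, instead of editing the defaults, you can pass the same paths on the command line when running the script (a sketch using the paths above; it assumes the script file is available locally, e.g. copied from the transformers source tree, and a shell that accepts backslash line continuations; on Windows, put everything on one line):

python convert_bert_original_tf_checkpoint_to_pytorch.py \
    --tf_checkpoint_path ./bert/chinese_L-12_H-768_A-12/bert_model.ckpt \
    --bert_config_file ./bert/chinese_L-12_H-768_A-12/bert_config.json \
    --pytorch_dump_path ./pytorch_bert/pytorch_bert_model.bin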
Step 5. Copy bert_config.json and vocab.txt to the pytorch_bert directory
After copying, based on the paths used in the steps above, the pytorch_bert directory should contain:
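pytorch_bert/
    pytorch_bert_model.bin
    bert_config.json
    vocab.txt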
Then, we can load it in PyTorch as follows:
from transformers import BertTokenizer
from transformers import BertModel, BertConfig

# The tokenizer is loaded from the directory that contains vocab.txt
token = BertTokenizer.from_pretrained('./pytorch_bert/')
# Build the config from bert_config.json, then load the converted weights
config = BertConfig.from_json_file('./pytorch_bert/bert_config.json')
bert_model = BertModel.from_pretrained('./pytorch_bert/pytorch_bert_model.bin', from_tf=False, config=config)
print(bert_model)
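To quickly check that the converted model works, we can encode a sentence and run a forward pass. This is a minimal sketch; the example sentence is arbitrary, and the exact return type of BertModel depends on your transformers version (older releases return a tuple):

import torch

# Encode an arbitrary Chinese sentence (any text covered by vocab.txt works)
ids = token.encode("今天天气很好", add_special_tokens=True)
input_ids = torch.tensor([ids])  # shape: (1, sequence_length)

with torch.no_grad():
    outputs = bert_model(input_ids)

# In older transformers versions the outputs are a tuple:
# outputs[0] is the sequence output, outputs[1] is the pooled output
sequence_output = outputs[0]
print(sequence_output.shape)  # e.g. torch.Size([1, 8, 768])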