An Introduction to Onnx Inference in Python – Onnx Tutorial

By | October 8, 2023

We have already learned how to convert a PyTorch model to ONNX in this tutorial:

Completed Guide to Export PyTorch Models to ONNX – PyTorch Tutorial

In this tutorial, we will introduce how to run inference with an ONNX model in Python.

Preliminary

First, we import the required packages.

import onnx
import onnxruntime

Check whether the ONNX model is valid

Before using an ONNX model, we should make sure it is valid. We can use the code below to check it.

onnx_model = onnx.load("test.onnx")
onnx.checker.check_model(onnx_model)

This code checks whether test.onnx is valid. If it is not, onnx.checker.check_model() raises onnx.checker.ValidationError.

Load the ONNX model for inference

Here is an example:

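num_thread = 1
sessionOptions = onnxruntime.SessionOptions()
# Number of threads onnxruntime may use inside a single operator
sessionOptions.intra_op_num_threads = num_thread

# Load test.onnx into an inference session with these options
encoder_ort_session = onnxruntime.InferenceSession("test.onnx", sess_options=sessionOptions)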

We use SessionOptions() with intra_op_num_threads = 1 so that onnxruntime uses only one thread to run the computation inside each operator of the ONNX model.

Understand onnxruntime.SessionOptions() intra_op_num_threads – Onnx Tutorial

In this code, onnxruntime.InferenceSession() loads the ONNX model. Then we can start to run inference.

To run inference, we should know the inputs and outputs of this ONNX model.

Build inputs

We can build the ONNX model inputs based on the inputs of the original PyTorch model.

Here is an example:

encoder_ort_inputs = {
    encoder_ort_session.get_inputs()[0].name: x_tst,
    encoder_ort_session.get_inputs()[1].name: x_tst_lengths,
    encoder_ort_session.get_inputs()[2].name: speed,
    encoder_ort_session.get_inputs()[3].name: sentiment_prob,
}

Note that x_tst, x_tst_lengths, speed and sentiment_prob are NumPy ndarrays.

Build outputs

The output list determines which outputs we get. It contains output names defined in test.onnx, for example:

encoder_ort_outputs = ['z_p', 'y_mask']

Finally, we can make an inference as follows:

encoder_ort_outs = encoder_ort_session.run(encoder_ort_outputs, encoder_ort_inputs)

We can also set encoder_ort_outputs = None, which means we will get all outputs of test.onnx.

The results come back as a list in the same order as the names in encoder_ort_outputs, so we can get:

z_p = encoder_ort_outs[0]

y_mask = encoder_ort_outs[1]

Both of them are NumPy ndarrays.
