Understand BaseModelOutputWithPoolingAndCrossAttentions with Examples – PyTorch Tutorial

By | June 13, 2023

When we run a BERT model with Hugging Face transformers, the output is often a BaseModelOutputWithPoolingAndCrossAttentions object. In this tutorial, we will discuss what it contains.

For example:

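import torch

# model: a loaded BERT model; encoded_input: tokenizer output created with return_tensors="pt"
with torch.no_grad():
    model_output = model(**encoded_input)
    print(model_output)
    print("model_output.last_hidden_state shape = ", model_output.last_hidden_state.shape)
    print("model_output.pooler_output shape = ", model_output.pooler_output.shape)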

We may get:

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.1958, -0.2254,  0.5879,  ...,  0.5189, -0.9704,  0.7224],
         [-0.5836, -0.6851,  0.3376,  ...,  0.7085, -0.5533,  0.2590],
         [-0.1957, -0.2260,  0.5873,  ...,  0.5188, -0.9710,  0.7216],
         [-0.1958, -0.2260,  0.5873,  ...,  0.5188, -0.9710,  0.7216],
         [-1.0514, -0.4288,  0.8458,  ...,  1.1722, -0.6951,  0.8225],
         [-0.1958, -0.2254,  0.5879,  ...,  0.5189, -0.9704,  0.7224]],

        [[-0.5236,  0.2747,  0.7207,  ...,  0.7099, -0.6590,  0.6492],
         [-0.9260,  0.0429, -0.1059,  ...,  1.0130,  0.2954,  0.5721],
         [-0.6988,  0.3200,  0.4998,  ...,  1.3675, -0.5426,  0.1605],
         [-0.5236,  0.2747,  0.7207,  ...,  0.7099, -0.6590,  0.6492],
         [-0.1056, -0.1332, -0.0261,  ...,  1.3496, -0.6363,  0.5059],
         [-0.0954, -0.1176, -0.0697,  ...,  1.3522, -0.6045,  0.5295]]]), pooler_output=tensor([[-0.7585, -0.1595,  0.4985,  ..., -0.2657, -0.0202,  0.3537],
        [-0.7301, -0.5412,  0.3729,  ..., -0.1573, -0.1320,  0.3476]]), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)
model_output.last_hidden_state shape =  torch.Size([2, 13, 1024])
model_output.pooler_output shape =  torch.Size([2, 1024])

We can see that model(**encoded_input) returns a BaseModelOutputWithPoolingAndCrossAttentions object.
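To reproduce this without downloading a checkpoint, here is a minimal sketch that builds a tiny, randomly initialized BertModel. The config sizes (hidden_size=32, 2 layers, 2 heads, vocab_size=100) and the random token ids are assumptions chosen only so it runs quickly offline; they are not the model used above.

```python
import torch
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

# Tiny, randomly initialized BERT (assumed sizes, only so the example runs offline)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)
model.eval()  # disable dropout

# A fake batch: 2 sequences of 13 random token ids each
input_ids = torch.randint(0, config.vocab_size, (2, 13))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)

print(type(model_output).__name__)           # BaseModelOutputWithPoolingAndCrossAttentions
print(model_output.last_hidden_state.shape)  # torch.Size([2, 13, 32])
print(model_output.pooler_output.shape)      # torch.Size([2, 32])
```

The shapes follow the same pattern as the real model above, only with hidden_size = 32 instead of 1024.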


It is defined in transformers.modeling_outputs and documented here: https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithCrossAttentions

It contains six fields:

last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
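BaseModelOutputWithPoolingAndCrossAttentions is a subclass of transformers' ModelOutput, so its fields can be read by attribute, by string key, or by integer index, and fields left as None are skipped. A small sketch that builds the object by hand with dummy zero tensors (the shapes are arbitrary):

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

# Build the output object directly with dummy tensors (arbitrary shapes)
out = BaseModelOutputWithPoolingAndCrossAttentions(
    last_hidden_state=torch.zeros(2, 13, 8),
    pooler_output=torch.zeros(2, 8),
)

# Fields that are None are not stored as keys
print(list(out.keys()))   # ['last_hidden_state', 'pooler_output']
print(out.hidden_states)  # None

# The same tensor is reachable by attribute, by key, or by integer index
assert out.last_hidden_state is out["last_hidden_state"]
assert out.pooler_output is out[1]

# to_tuple() also drops the None fields
print(len(out.to_tuple()))  # 2
```

Because None fields are skipped, an integer index such as out[1] depends on which fields are set, so attribute access is the safer choice in practice.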

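last_hidden_state and pooler_output are the most important. The other fields are None in the output above because BERT only returns them when asked, for example with output_hidden_states=True or output_attentions=True in the forward call.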
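To see an optional field populated, we can pass output_hidden_states=True. A minimal sketch, again assuming a tiny randomly initialized BertModel rather than a real checkpoint:

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random BERT (assumed sizes, no download needed)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 13))

with torch.no_grad():
    out = model(input_ids=input_ids, output_hidden_states=True)

# hidden_states: the embedding output plus one tensor per encoder layer
print(len(out.hidden_states))      # 3
print(out.hidden_states[0].shape)  # torch.Size([2, 13, 32])
# The last entry of hidden_states is last_hidden_state
print(torch.equal(out.hidden_states[-1], out.last_hidden_state))  # True
```

hidden_states has num_hidden_layers + 1 entries because the embedding output is included as the first element.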

last_hidden_state: (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

pooler_output: (torch.FloatTensor of shape (batch_size, hidden_size)) — Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.

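In other words, pooler_output is computed from the [CLS] token: it is last_hidden_state[:, 0] passed through a linear layer and a tanh activation, so it is not identical to the raw [CLS] hidden state.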
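We can check this relation directly. The sketch below uses a tiny randomly initialized BertModel (assumed sizes, not a real checkpoint); in the transformers implementation the pooler is exposed as model.pooler, with its linear layer at model.pooler.dense.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random BERT (assumed sizes, no download needed)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 13))

with torch.no_grad():
    out = model(input_ids=input_ids)
    # Raw [CLS] vector: first token of every sequence, shape (batch_size, hidden_size)
    cls_hidden = out.last_hidden_state[:, 0]
    # Pooler: linear layer followed by tanh, applied to the [CLS] vector
    manual_pooled = torch.tanh(model.pooler.dense(cls_hidden))

print(torch.allclose(manual_pooled, out.pooler_output, atol=1e-6))  # True
print(torch.allclose(cls_hidden, out.pooler_output))
```

The second comparison shows that the raw [CLS] vector and pooler_output are different tensors, so the two should not be used interchangeably.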

In this example, batch_size = 2, sequence_length = 13 and hidden_size = 1024:

model_output.last_hidden_state shape =  torch.Size([2, 13, 1024])
model_output.pooler_output shape =  torch.Size([2, 1024])
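These shapes can be read as (batch_size, sequence_length, hidden_size): 2 input sentences, 13 tokens per sequence (including special tokens and any padding), and hidden_size = 1024. A small sketch that derives the expected shapes from config values; the JSON string is an excerpt of the config.json listed in this article, and batch_size and sequence_length are taken from the output above.

```python
import json

# Excerpt of the model's config.json (only the keys used here)
config_text = """
{
  "hidden_size": 1024,
  "num_attention_heads": 16,
  "num_hidden_layers": 24
}
"""
config = json.loads(config_text)

batch_size, sequence_length = 2, 13  # from the example output above
hidden_size = config["hidden_size"]

last_hidden_state_shape = (batch_size, sequence_length, hidden_size)
pooler_output_shape = (batch_size, hidden_size)
head_dim = hidden_size // config["num_attention_heads"]  # size of each attention head
num_hidden_states = config["num_hidden_layers"] + 1      # length of hidden_states if requested

print(last_hidden_state_shape)  # (2, 13, 1024)
print(pooler_output_shape)      # (2, 1024)
print(head_dim)                 # 64
print(num_hidden_states)        # 25
```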

hidden_size is defined in the model's config.json file. For example, for GanymedeNil text2vec-large-chinese, the content may be:

  "_name_or_path": "hfl/chinese-lert-large",
  "architectures": [
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "pooler_fc_size": 1024,
  "pooler_num_attention_heads": 16,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.26.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128