The Cache Mechanism in the Transformer

Author: Liu Shaokong (an NLP algorithm engineer)

The encoder part is relatively simple: only a batch-related mask (for the padded positions) needs to be considered when performing self-attention. Here we focus on how the attention in each decoder layer works in training mode and in inference mode.

In training mode, the decoder uses the teacher_forcing mechanism to build its input. Concretely, the original input_target_sequence is shifted one position to the right, which can also be understood as prepending a decode_start_token to the original input_target_sequence.
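A minimal sketch of this right shift, assuming an integer decode_start_token id and a batch of target token ids (the names here are illustrative):

```python
import torch

# Teacher forcing: the decoder input is the target sequence shifted right by one
# position, i.e. a decode_start_token is prepended and the last token is dropped.
def shift_right(input_target_sequence: torch.Tensor, decode_start_token: int) -> torch.Tensor:
    # input_target_sequence: (batch_size, tgt_len) of token ids
    batch_size = input_target_sequence.size(0)
    start = torch.full((batch_size, 1), decode_start_token,
                       dtype=input_target_sequence.dtype)
    return torch.cat([start, input_target_sequence[:, :-1]], dim=1)

tgt = torch.tensor([[11, 12, 13, 14]])
print(shift_right(tgt, decode_start_token=0))  # tensor([[ 0, 11, 12, 13]])
```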

Let's first examine the self_attention of the decoder. Its mask consists of two parts: tgt_mask and self_attention_mask. tgt_mask depends on the length of each tgt sequence, while self_attention_mask is a triangular matrix (the diagonal and the lower triangle are 0, and the upper triangle is a large negative number such as -1e9). This guarantees that when the token at a given position is computed, the tokens after that position have no effect on its output. The principle is as follows:

softmax(q · Kᵀ + (-1e9)) · V ≈ 0

Here K and V are the sets of k and v corresponding to the positions after the current query position q. Because -1e9 is added to those scores before the softmax, their attention weights become (numerically) zero, so the values at those positions contribute nothing to the output of the current position.
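A tiny numerical sketch of this principle (the scores are made up for illustration): adding -1e9 to the scores of the future positions before the softmax drives their weights to essentially zero.

```python
import torch
import torch.nn.functional as F

# Scores q·k for positions 0..3; positions 2 and 3 lie after the current position.
scores = torch.tensor([1.2, 0.7, -0.3, 0.5])
mask = torch.tensor([0.0, 0.0, -1e9, -1e9])   # -1e9 on the future positions
weights = F.softmax(scores + mask, dim=-1)
print(weights)  # ~[0.62, 0.38, 0.00, 0.00]: the future values get zero weight
```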

In the training phase, the self_attention of each layer computes the output for every position in parallel, thanks to teacher_forcing and the mask (tgt_mask + self_attention_mask). (Parallel here means the computation is the same as the encoder's self_attention: all positions are fed in at one time instead of one token at a time.)
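A sketch of this parallel, masked self_attention for a single head (shapes and names are illustrative; multi-head splitting, layer norm and dropout are omitted):

```python
import math
import torch
import torch.nn.functional as F

def causal_self_attention(q, k, v):
    # q, k, v: (batch_size, tgt_len, d_k), projections of the shifted decoder input
    tgt_len, d_k = q.size(1), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)              # (batch, tgt_len, tgt_len)
    causal = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)  # 1 above the diagonal
    scores = scores + causal * (-1e9)                              # mask out future positions
    return F.softmax(scores, dim=-1) @ v                           # all positions in one pass
```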

Next is the cross_attention part. Its Q is obtained from the output of self_attention through a q_proj transformation matrix, while K and V are obtained from the output of the encoder through two transformation matrices, k_proj and v_proj, respectively; (Q, K, V) are then used to compute the output at each position.
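A single-head sketch of these projections (module and variable names are assumptions, not the code of any particular library):

```python
import torch
import torch.nn as nn

d_model = 512
q_proj = nn.Linear(d_model, d_model)  # applied to the decoder self_attention output
k_proj = nn.Linear(d_model, d_model)  # applied to the encoder output
v_proj = nn.Linear(d_model, d_model)  # applied to the encoder output

decoder_hidden = torch.randn(2, 7, d_model)   # (batch, tgt_len, d_model)
encoder_output = torch.randn(2, 9, d_model)   # (batch, src_len, d_model)

Q = q_proj(decoder_hidden)
K = k_proj(encoder_output)
V = v_proj(encoder_output)
out = torch.softmax(Q @ K.transpose(-2, -1) / d_model ** 0.5, dim=-1) @ V  # (2, 7, d_model)
```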

The decoder layers are stacked on top of each other. The output of the last layer at each position is transformed by a softmax_embedding matrix into a vector whose size equals the vocabulary size of the tgt language, and the loss over a batch can then be computed. This loss must also take the length of each sentence in the batch into account, i.e. it needs to be multiplied by a tgt_mask.
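A sketch of the final projection and the length-masked loss, assuming a tgt_mask of 1s for real tokens and 0s for padding (names are illustrative):

```python
import torch
import torch.nn.functional as F

def masked_loss(decoder_output, softmax_embedding, labels, tgt_mask):
    # decoder_output: (batch, tgt_len, d_model); softmax_embedding: (d_model, vocab_size)
    # labels, tgt_mask: (batch, tgt_len)
    logits = decoder_output @ softmax_embedding                    # (batch, tgt_len, vocab_size)
    loss = F.cross_entropy(logits.transpose(1, 2), labels,
                           reduction="none")                       # (batch, tgt_len)
    return (loss * tgt_mask).sum() / tgt_mask.sum()                # average over real tokens
```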

We can see that in the training phase, thanks to the teacher_forcing and mask mechanisms, the decoder-side input_tokens of a batch can be fed in at one time, and the loss of the batch is then obtained through the loss function. The intermediate self_attention and cross_attention results of each decoder layer are not needed later, so they do not have to be saved.

Next, let's look at how each attention part is calculated when the transformer runs in inference mode.

Since all the information on the encoder side is known in advance, its input and computation are the same as in the training phase and involve only self-attention in its general form.

On the decoder side, only one token is input at a time (one token per sentence in the batch, so batch_size tokens in total). Inside a decoder_layer, self_attention and cross_attention are computed in turn. Suppose n tokens have already been decoded when this token comes in; then self_attention only needs the q of the token at the current decoding position and the (K, V) of the previous n tokens. The (K, V) of the previous n tokens can therefore be cached: at this step we only need to compute the (q, k, v) of the current token in this layer, use q to perform the attention calculation against the cached (K, V) of the previous n tokens, and, after the calculation, append the (k, v) of the current position to (K, V) so they serve as the (K, V) for the next decoding step.
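A single-head sketch of this cached self_attention step (names are illustrative; a real multi-head implementation also carries a head dimension in the cache):

```python
import torch
import torch.nn.functional as F

def cached_self_attention(q, k, v, cache_k, cache_v):
    # q, k, v: (batch, 1, d_k) for the current decoding position
    # cache_k, cache_v: (batch, n, d_k) for the n tokens decoded so far
    K = torch.cat([cache_k, k], dim=1)              # append current k -> (batch, n + 1, d_k)
    V = torch.cat([cache_v, v], dim=1)              # append current v
    scores = q @ K.transpose(-2, -1) / K.size(-1) ** 0.5
    out = F.softmax(scores, dim=-1) @ V             # (batch, 1, d_k)
    return out, K, V                                # K, V become the cache for the next step
```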

Self_attention does not affect the information already generated at earlier positions, and it only passes the hidden_state vector of the current decoding position on to the following cross_attention. The (K, V) of cross_attention are generated during the first decoding step and reused in all subsequent steps, so they can also be cached. These (K, V) are obtained from the encoder output (encoder_hidden_states) through the k_proj and v_proj matrices of this layer's cross_attention, so in subsequent decoding they do not change as the decoded sequence grows.
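A sketch of how such a cross_attention cache could be initialized once and then reused (function and argument names are assumptions):

```python
def init_cross_attention_cache(encoder_hidden_states, k_proj, v_proj):
    # Computed once at the first decoding step; the result depends only on the
    # encoder output, so it is reused unchanged at every later step.
    return k_proj(encoder_hidden_states), v_proj(encoder_hidden_states)
```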

In summary, during training the decoder of the transformer model does not need to save the intermediate results of each layer; it only needs to output, at each position, the token classification scores (over the vocabulary) to compute the loss (cross_entropy) against the true_label. In inference mode, since only one token is input at a time, the (K, V) corresponding to the already decoded tokens can be cached and used directly in the self_attention and cross_attention calculations. The (K, V) of each layer's self_attention grow as the decoded length increases, while the (K, V) of each layer's cross_attention are computed at the first decoding step (converted from the encoder outputs) and do not change as the decoded length increases.

Based on the above observations:

1. In practice, we added a cross_attention cache mechanism for inference to the open-source framework THUMT.

2. In the open-source project fastt5, the transformer (t5) model is split into 3 onnx models (encoder.onnx, decoder_init.onnx, decoder.onnx), where decoder_init.onnx only handles the first decoding step, that is, it generates the K and V of cross_attention as well as the K and V of self_attention.

Therefore, if the above 3 onnx models are reduced to 2, the self_attn_kv and cross_attn_kv of the decoder part can be created or computed at the time the encoder produces its outputs.

Here decoder_seq_length is set to 0, and the corresponding changes are made in the subsequent decoder decoding steps.

The shapes of self_attn_values and cross_attn_values differ from the shapes of their corresponding keys only in the last dimension (value_channels).

The keys and values related to self_attn can be created with torch.ones(), while the keys and values of cross_attn need to be computed by extracting the k_proj and v_proj parameters of each decoder layer (the model structure needs to be adjusted slightly).
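A minimal sketch of how these caches might be created on the encoder side when the three models are merged into two; the shapes, attribute names (cross_attn_k_proj, cross_attn_v_proj) and the per-head layout are assumptions based on the description above, not the actual fastt5 code:

```python
import torch

def init_decoder_caches(encoder_hidden_states, decoder_layers,
                        num_heads, key_channels, value_channels):
    batch_size = encoder_hidden_states.size(0)
    self_attn_kv, cross_attn_kv = [], []
    for layer in decoder_layers:
        # self_attn cache: decoder_seq_length starts at 0, so the placeholders
        # created with torch.ones() are empty along the length dimension.
        self_k = torch.ones(batch_size, num_heads, 0, key_channels)
        self_v = torch.ones(batch_size, num_heads, 0, value_channels)
        # cross_attn cache: computed once from the encoder output via the
        # layer's k_proj / v_proj (hypothetical attribute names; per-head
        # reshaping is omitted here).
        cross_k = layer.cross_attn_k_proj(encoder_hidden_states)
        cross_v = layer.cross_attn_v_proj(encoder_hidden_states)
        self_attn_kv.append((self_k, self_v))
        cross_attn_kv.append((cross_k, cross_v))
    return self_attn_kv, cross_attn_kv
```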

This article is reproduced from: https://www.52nlp.cn/transformer%E4%B8%AD%E7%9A%84%E7%BC%93%E5%AD%98%E6%9C%BA%E5%88%B6
This site only republishes the article; the copyright belongs to the original author.