Original link: https://www.bodunhu.com/blog/posts/tensorir-transformation/
In the previous post, we explored how to write primitive functions in TensorIR. Here, we will see how to transform TensorIR into other (potentially more performant) variants. The content is derived from the MLC course taught by Tianqi Chen.
Batched BMM ReLU
A batched matrix multiplication followed by a ReLU operation can be expressed in numpy as:
```python
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)
```
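As a quick sanity check (my addition, not part of the original post), this reference should agree with a one-line numpy expression; the test arrays below are just illustrative:

```python
import numpy as np

# hypothetical check: random inputs, compared against batched matmul + ReLU in numpy
A = np.random.rand(16, 128, 128).astype("float32")
B = np.random.rand(16, 128, 128).astype("float32")
C = np.empty((16, 128, 128), dtype="float32")

lnumpy_mm_relu_v2(A, B, C)
# A @ B performs the batched matmul over the leading dimension of 16
np.testing.assert_allclose(C, np.maximum(A @ B, 0), rtol=1e-4)
```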
Translating the numpy code into TensorIR we get:
```python
@tvm.script.ir_module
class MyBmmRule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
                 W: T.Buffer[(16, 128, 128), "float32"],
                 Y: T.Buffer[(16, 128, 128), "float32"]):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # we must allocate the intermediate buffer here!
        Y_ = T.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("M"):
                vn = T.axis.spatial(16, n)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y_[vn, vi, vj] = T.float32(0)
                Y_[vn, vi, vj] += A[vn, vi, vk] * W[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("R"):
                vn = T.axis.spatial(16, n)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                Y[vn, vi, vj] = T.max(Y_[vn, vi, vj], T.float32(0))
```
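If you want to confirm that this TVM script behaves the same as the numpy reference, a sketch along these lines should work (assuming TVM is built with the LLVM backend; the variable names here are mine):

```python
import numpy as np
import tvm

# build the IRModule for CPU and run it on random inputs
rt_mod = tvm.build(MyBmmRule, target="llvm")
a_np = np.random.rand(16, 128, 128).astype("float32")
w_np = np.random.rand(16, 128, 128).astype("float32")
a_nd = tvm.nd.array(a_np)
w_nd = tvm.nd.array(w_np)
y_nd = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))

rt_mod["bmm_relu"](a_nd, w_nd, y_nd)
np.testing.assert_allclose(y_nd.numpy(), np.maximum(a_np @ w_np, 0), rtol=1e-4)
```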
Our ultimate goal is to transform the TensorIR above to the following form:
```python
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
                 B: T.Buffer[(16, 128, 128), "float32"],
                 C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))
```
Before we perform the transformation, let's understand what the transformed TensorIR is doing by examining a few of its loop nests.

First, take a look at:
```python
for i1, i2_0 in T.grid(128, 16):
    for ax0_init in T.vectorized(8):
        with T.block("Y_init"):
            n, i = T.axis.remap("SS", [i0, i1])
            j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
            Y[n, i, j] = T.float32(0)
```
This code block initializes the Y matrix to 0. It does so by initializing every 8 consecutive elements in each row of Y with a vectorized operation (which can be faster).
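In plain numpy terms, the initialization loop nest is roughly doing the following (my sketch; i0 stands for the batch index coming from the enclosing T.parallel(16) loop):

```python
# rough numpy picture of the "Y_init" block: for each (i0, i1),
# zero the row of Y in contiguous chunks of 8 elements
for i1 in range(128):           # i
    for i2_0 in range(16):      # outer part of the split j loop
        Y[i0, i1, i2_0 * 8 : i2_0 * 8 + 8] = 0.0   # 8 lanes written together
```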
The next loop is a bit tricky:
```python
for ax1_0 in T.serial(32):
    for ax1_1 in T.unroll(4):
        for ax0 in T.serial(8):
            with T.block("Y_update"):
                n, i = T.axis.remap("SS", [i0, i1])
                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
```
This loop nest performs the actual matrix multiplication of A and B: we multiply a row of A with a column of B and sum the products into a single number. Here, i is mapped to i1, which means we access A one row at a time. k = T.axis.reduce(128, ax1_0 * 4 + ax1_1) means we step through one row of A and one column of B sequentially during the multiplication, while applying unrolling in the hope of better access efficiency (\(128 = 32 \times 4\)). j = T.axis.spatial(128, i2_0 * 8 + ax0) really just means each column is accessed sequentially, nothing special.
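Written back out as numpy-style loops (my sketch; n = i0, i = i1, and the column chunk i2_0 come from the enclosing loops), the update block is roughly:

```python
# rough numpy picture of the "Y_update" block
for ax1_0 in range(32):        # outer part of the split k loop
    for ax1_1 in range(4):     # inner part of the split k loop, unrolled
        for ax0 in range(8):   # inner part of the split j loop
            j = i2_0 * 8 + ax0
            k = ax1_0 * 4 + ax1_1
            Y[i0, i1, j] += A[i0, i1, k] * B[i0, k, j]
```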
Perform Transformation
To perform a transformation on any TensorIR, it is important to follow the steps listed below:
- Get block
- Get loops
- Organize loops by split, reorder, compute_at/reverse_compute_at
- Decompose reduction
- Vectorize/unroll/parallel
Applying steps 1, 2, and 3, we first get the block and its loops from the original TensorIR, then split the loops:
```python
sch = tvm.tir.Schedule(MyBmmRule)
# Step 1. Get blocks
block_M = sch.get_block("M", func_name="bmm_relu")
# Step 2. Get loops
n, i, j, k = sch.get_loops(block_M)
# Step 3. Organize loops
k0, k1 = sch.split(k, factors=[32, 4])
j0, j1 = sch.split(j, factors=[16, 8])
```
The reason we split k and j this way is: we already mentioned that the k dimension is accessed sequentially with an unrolling factor of 4 applied, and when matrix Y is initialized, a vectorized operation over 8 elements is applied along dimension j, i.e. to every 8 consecutive elements in one row (TVM buffers are row-major, so this can be faster).
But the next question is: how do we reorder the split loops? I spent a lot of time trying to figure that out. It turns out the simplest way is to write out the implementation in numpy and proceed from there. Remember, we have already split k and j, which are used during the matrix multiplication, so our new matrix multiplication in numpy would be:
```python
# for a fixed batch n and row i (outer loops omitted here)
for j0 in range(16):
    for k0 in range(32):
        for k1 in range(4):
            for j1 in range(8):
                Y[i, 8 * j0 + j1] += A[i, 4 * k0 + k1] * B[4 * k0 + k1, 8 * j0 + j1]
```
Because we move to the next column in B only after traversing the previous column, we put j1 in the innermost loop. Therefore, the reordering in TensorIR would be:
```python
sch.reorder(j0, k0, k1, j1)
```
We can print out the transformed TensorIR:
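The module can be inspected at any point with print(sch.mod.script()). To reach the target form, the remaining steps from the list above (compute_at/reverse_compute_at, decompose reduction, vectorize/unroll/parallel) could proceed roughly as follows; this is only my sketch using standard tvm.tir.Schedule primitives, not necessarily the exact schedule behind the target module:

```python
# inspect the module after the reorder
print(sch.mod.script())

# Step 3 (continued): compute the ReLU block at the column-chunk loop j0
block_R = sch.get_block("R", func_name="bmm_relu")
sch.reverse_compute_at(block_R, j0)

# Step 4: split the initialization of Y out of the reduction at loop k0
init_block = sch.decompose_reduction(block_M, k0)

# Step 5: vectorize / unroll / parallelize
sch.vectorize(sch.get_loops(init_block)[-1])   # 8-wide vectorized init
sch.unroll(k1)                                 # unroll the inner k loop by 4
sch.vectorize(sch.get_loops(block_R)[-1])      # 8-wide vectorized ReLU store
sch.parallel(n)                                # parallelize the batch loop
```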