Dependencies
This extension is written against the OpenCL 3.0 C Language specification, V3.0.10.
This extension requires support for subgroups.
This extension uses many of the terms and concepts from the cl_intel_subgroup_matrix_multiply_accumulate extension.
Overview
The goal of this extension is to allow programmers to access specialized hardware to compute the product of an M x K matrix with a K x N matrix and then add an M x N matrix accumulation value. This is a commonly used building block to compute the product of two large matrices.
The functionality described in this extension is very similar to the functionality described in the cl_intel_subgroup_matrix_multiply_accumulate extension, with one key difference: in this extension, work items across two subgroups cooperate to perform the operation.
This is done by splitting the M x K matrix source across two participating subgroups: the first M/2 rows of the matrix source are provided by the first subgroup, and the remaining M/2 rows are provided by the second subgroup.
Splitting the matrix source improves performance by halving the amount of data each subgroup must load for the first matrix source.
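The C sketch below is a rough illustration of the split. It computes result = a * b + acc on the host for 8-bit integer sources and 32-bit accumulation. The M x K matrix a is supplied as two halves of M/2 rows each, one half per participating subgroup. The function name split_mad_ref_i8, the row-major layout, and the plain-array interface are assumptions made for illustration. This is a reference model of the arithmetic only, not the extension's API or its data layout.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical reference model: result = a * b + acc, where the M x K
 * matrix a arrives as two halves of M/2 rows (one per subgroup).
 * All matrices are row-major. */
void split_mad_ref_i8(int M, int N, int K,
                      const int8_t *a_first,   /* rows 0 .. M/2-1, M/2 x K */
                      const int8_t *a_second,  /* rows M/2 .. M-1, M/2 x K */
                      const int8_t *b,         /* K x N */
                      const int32_t *acc,      /* M x N */
                      int32_t *result)         /* M x N */
{
    assert(M % 2 == 0);
    int half = M / 2;
    for (int m = 0; m < M; ++m) {
        /* Pick the half of a that holds row m. */
        const int8_t *a_row = (m < half) ? a_first + m * K
                                         : a_second + (m - half) * K;
        for (int n = 0; n < N; ++n) {
            int32_t sum = acc[m * N + n];
            for (int k = 0; k < K; ++k)
                sum += (int32_t)a_row[k] * (int32_t)b[k * N + n];
            result[m * N + n] = sum;
        }
    }
}
```

Either subgroup alone holds only half of the rows of a. The full product needs both halves, which is why the two subgroups must cooperate.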
New OpenCL C Functions
// 8-bit matrices:
int2 intel_sub_group_i8_i8_split_matrix_mad_k32(int a, int8 b, int2 acc); // M = 2
int4 intel_sub_group_i8_i8_split_matrix_mad_k32(int2 a, int8 b, int4 acc); // M = 4
int8 intel_sub_group_i8_i8_split_matrix_mad_k32(int4 a, int8 b, int8 acc); // M = 8
int2 intel_sub_group_i8_u8_split_matrix_mad_k32(int a, uint8 b, int2 acc); // ...
int4 intel_sub_group_i8_u8_split_matrix_mad_k32(int2 a, uint8 b, int4 acc);
int8 intel_sub_group_i8_u8_split_matrix_mad_k32(int4 a, uint8 b, int8 acc);
int2 intel_sub_group_u8_i8_split_matrix_mad_k32(uint a, int8 b, int2 acc);
int4 intel_sub_group_u8_i8_split_matrix_mad_k32(uint2 a, int8 b, int4 acc);
int8 intel_sub_group_u8_i8_split_matrix_mad_k32(uint4 a, int8 b, int8 acc);
int2 intel_sub_group_u8_u8_split_matrix_mad_k32(uint a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_split_matrix_mad_k32(uint2 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_split_matrix_mad_k32(uint4 a, uint8 b, int8 acc);
// bfloat16 matrices:
float2 intel_sub_group_bf16_bf16_split_matrix_mad_k16(int a, int8 b, float2 acc);
float4 intel_sub_group_bf16_bf16_split_matrix_mad_k16(int2 a, int8 b, float4 acc);
float8 intel_sub_group_bf16_bf16_split_matrix_mad_k16(int4 a, int8 b, float8 acc);
// fp16 matrices:
float2 intel_sub_group_f16_f16_split_matrix_mad_k16(int a, int8 b, float2 acc);
float4 intel_sub_group_f16_f16_split_matrix_mad_k16(int2 a, int8 b, float4 acc);
float8 intel_sub_group_f16_f16_split_matrix_mad_k16(int4 a, int8 b, float8 acc);
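Each name above has the shape intel_sub_group_<a type>_<b type>_split_matrix_mad_k<K>, consistent with the non-overloaded naming recorded in the Issues section. The C sketch below recovers those parts from a name using the hypothetical helper parse_split_mad_name. It only illustrates the naming pattern and is not an API provided by the extension.

```c
#include <assert.h>
#include <stdio.h>

/* Splits a built-in name of the form
 *   intel_sub_group_<a type>_<b type>_split_matrix_mad_k<K>
 * into its parts. Returns 1 on success, 0 otherwise. */
int parse_split_mad_name(const char *name, char a_type[8], char b_type[8], int *K)
{
    int consumed = 0;
    /* %7[^_] reads a type token such as "u8" or "bf16" up to the next '_';
     * %n records how many characters were matched. */
    int n = sscanf(name, "intel_sub_group_%7[^_]_%7[^_]_split_matrix_mad_k%d%n",
                   a_type, b_type, K, &consumed);
    /* Require all three fields and no trailing characters. */
    return n == 3 && name[consumed] == '\0';
}
```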
Modifications to the OpenCL C Specification
Add a new Section 6.13.X - Subgroup Split Matrix Multiply Accumulate Instructions
This section describes a family of built-in functions that multiply two matrix sources a and b and then add a matrix accumulation value to produce a matrix result value.
Work items from two subgroups cooperate to perform this operation.
a is the first matrix operand and has M rows and K columns. Each subgroup provides half of the rows of the a matrix.
b is the second matrix operand and has K rows and N columns.
acc is the matrix accumulation value and has M rows and N columns.
The result value also has M rows and N columns.
All work items in both subgroups cooperate to perform this operation.
These functions must be encountered by all work items in both subgroups executing the kernel.
The dimensions of the two source matrices and the elements of each source matrix are described by the built-in function name and its arguments.
As an example, given the function:
int2 intel_sub_group_u8_i8_split_matrix_mad_k32(uint a, int8 b, int2 acc);
The a operand is the first source matrix operand and has M rows and K columns. This matrix operand is split across two participating subgroups: work items from each participating subgroup provide half of the row data for this matrix.
- The value of M is determined by the number of vector components in the source operand a. Since each subgroup provides half of the row data for this matrix, multiply the number of components in a by two to compute the number of rows M. In the example above, a is a scalar uint argument, so the a matrix operand has M equal to 2 rows.
- The value of K is described by the function name. In this case, K is 32, so the a matrix operand has K equal to 32 columns.
- The matrix component data type is also described by the function name. In this case, the a matrix component data type is u8, indicating that the elements of the a matrix operand are unsigned 8-bit integers.
- Each work item contributes part of this matrix. In this case, the elements of the a matrix are 8-bit integers, and each work item contributes 32 bits (the size of a uint) of data per row of this matrix. Each work item therefore contributes four 8-bit integer values per row.
- Since K is 32, and each work item contributes four 8-bit values per row, the number of work items in the subgroup must be equal to 8.
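A small C sketch can check the arithmetic in the list above. required_work_items applies the rule that each work item supplies 32 bits per row. locate_a_element reports which subgroup, vector component, work item, and element slot hold a given element a[row][k]. Both helpers are hypothetical. The row split (first M/2 rows from the first subgroup) follows the text. Two further placements are assumptions not stated by this extension: that component c holds the subgroup's c-th row, and that work item w holds the consecutive columns starting at w * (32 / element bits).

```c
#include <assert.h>

/* Work items needed so that each supplies 32 bits of one row of a:
 * K * elem_bits / 32, which is 8 for K = 32 (8-bit) and K = 16 (16-bit). */
int required_work_items(int K, int elem_bits)
{
    return K * elem_bits / 32;
}

/* Hypothetical description of where element a[row][k] is supplied. */
typedef struct {
    int subgroup;   /* 0 = first participating subgroup, 1 = second */
    int component;  /* vector component of that subgroup's a argument */
    int work_item;  /* subgroup local ID that supplies the element */
    int lane;       /* element slot within the 32-bit value */
} a_location;

/* elem_bits is 8 (i8/u8) or 16 (bf16/f16). ASSUMPTION: component c holds the
 * subgroup's c-th row, and work item w holds the 32/elem_bits consecutive
 * columns starting at w * (32/elem_bits). */
a_location locate_a_element(int M, int elem_bits, int row, int k)
{
    assert(M == 2 || M == 4 || M == 8);
    assert(elem_bits == 8 || elem_bits == 16);
    assert(0 <= row && row < M);
    assert(0 <= k && k < 256 / elem_bits);   /* K = 32 or K = 16 */
    int per_item = 32 / elem_bits;           /* 4 or 2 elements per work item */
    int half = M / 2;                        /* rows supplied by each subgroup */
    a_location loc;
    loc.subgroup  = (row < half) ? 0 : 1;    /* first M/2 rows: first subgroup */
    loc.component = row % half;
    loc.work_item = k / per_item;
    loc.lane      = k % per_item;
    return loc;
}
```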
The b operand is the second source matrix operand and has K rows and N columns.
- Each work item contributes one column of this matrix. Therefore, the number of columns N is equal to the subgroup size.
- As above, the value of K is described by the function name. In this case, K is 32, so the b matrix operand has K equal to 32 rows.
- As above, the matrix component data type is described by the function name. In this case, the b matrix component data type is i8, indicating that the elements of the b matrix operand are signed 8-bit integers.
- Since K is 32 and the elements of the b matrix are 8-bit integers, each work item must contribute 256 bits of source data to contribute K values. The 256 bits of source data are packed and passed as the int8 argument b.
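The C sketch below packs and unpacks one work item's column of b for the 8-bit variants. It assumes that element k of the column occupies byte k % 4 of component k / 4, least-significant byte first. This section does not specify the byte order; it is an assumption made for illustration. b_components shows that K = 32 8-bit values and K = 16 16-bit values both fill exactly 256 bits, that is, eight 32-bit components.

```c
#include <assert.h>
#include <stdint.h>

/* 32-bit components needed for one column of b: 32*8/32 = 16*16/32 = 8. */
int b_components(int K, int elem_bits)
{
    return K * elem_bits / 32;
}

/* ASSUMPTION: element k sits in byte (k % 4) of component k / 4,
 * least-significant byte first. */
void pack_b_column_i8(const int8_t col[32], int32_t b[8])
{
    for (int c = 0; c < 8; ++c) {
        uint32_t word = 0;
        for (int j = 0; j < 4; ++j)
            word |= (uint32_t)(uint8_t)col[4 * c + j] << (8 * j);
        b[c] = (int32_t)word;   /* same bits on two's-complement targets */
    }
}

void unpack_b_column_i8(const int32_t b[8], int8_t col[32])
{
    for (int k = 0; k < 32; ++k) {
        uint32_t word = (uint32_t)b[k / 4];
        col[k] = (int8_t)(uint8_t)(word >> (8 * (k % 4)));
    }
}
```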
The acc operand specifies the accumulation value and has M rows and N columns.
- As above, the value of M is determined by the number of components in the source operand acc. In the example above, acc is an int2 argument, so the accumulation value operand has M equal to 2 rows.
- Both a and acc specify operands with M rows, and the value of M is determined by the number of components in each source operand. Since each subgroup provides half of the a matrix data, the a operand has half as many components as the acc source operand.
- As above, each work item contributes one column of accumulation values. Therefore, the number of columns N is equal to the subgroup size.
- The acc operand is a "full precision" accumulation value. In the example above, the matrices contain integer data, so the acc operand is a vector of int data.
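The component-count rules for a, acc, and the result fit in a few lines of C. widths_for_m is a hypothetical helper. Its values follow the prototype list in the New OpenCL C Functions section: a has M/2 components, b always has 8, and acc and the return value have M.

```c
#include <assert.h>

/* Vector component counts of each argument for a given M (hypothetical). */
typedef struct {
    int a;       /* M / 2: each subgroup supplies half of the rows */
    int b;       /* always 8: 256 bits per column                   */
    int acc;     /* M                                               */
    int result;  /* M                                               */
} split_mad_widths;

split_mad_widths widths_for_m(int M)
{
    assert(M == 2 || M == 4 || M == 8);
    split_mad_widths w = { M / 2, 8, M, M };
    return w;
}
```

For example, M = 2 gives a scalar a with an int2 acc, matching the intel_sub_group_u8_i8_split_matrix_mad_k32 example above.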
The result value returned by the function also has M rows and N columns.
- As above, the value of M is determined by the number of components in the return type. In the example above, the return type is int2, so the result value has M equal to 2 rows.
- The result value, a, and acc all specify values with M rows, and the value of M is determined by the number of components in each source operand or return type. Since each subgroup provides half of the a matrix data, the a operand has half as many components as the return type and the acc operand.
- As above, each work item receives one column of result values. Therefore, the number of columns N is equal to the subgroup size.
- Similar to the acc operand, the return value is a "full precision" result value. In the example above, the matrices contain integer data, so the return type is a vector of int data.
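The C sketch below puts these rules together by emulating what one subgroup might receive from intel_sub_group_i8_i8_split_matrix_mad_k32. It rests on several assumptions:
- Each subgroup passes its own b and acc and receives its own M x N result, with a being the only operand shared by the two subgroups.
- K elements are packed least-significant byte first.
- Work item w supplies columns 4w to 4w+3 of each row of a.
- Component c of a subgroup's a holds that subgroup's c-th row.
The emulation function and its array layout are hypothetical and are not part of the extension.

```c
#include <assert.h>
#include <stdint.h>

#define SG_SIZE 8   /* N: the only supported subgroup size */
#define K_I8 32     /* K for the 8-bit integer variants */

/* Signed byte j (0 = least significant) of a packed 32-bit value. */
static int32_t byte_i8(int32_t packed, int j)
{
    return (int8_t)(uint8_t)((uint32_t)packed >> (8 * j));
}

/* Hypothetical emulation for one subgroup, M = 2, 4 or 8:
 *   a_first[w][c] / a_second[w][c]: component c of a in work item w of the
 *                                   first / second subgroup (read only)
 *   b[w][c], acc[w][m]:             this subgroup's operands (read only)
 *   result[w][m]:                   component m returned to work item w
 * Arrays are sized for M = 8; only M/2 or M entries are used. */
void emulate_split_mad_i8_i8_k32(int M,
                                 int32_t a_first[SG_SIZE][4],
                                 int32_t a_second[SG_SIZE][4],
                                 int32_t b[SG_SIZE][8],
                                 int32_t acc[SG_SIZE][8],
                                 int32_t result[SG_SIZE][8])
{
    assert(M == 2 || M == 4 || M == 8);
    int half = M / 2;
    for (int n = 0; n < SG_SIZE; ++n) {          /* column n is work item n */
        for (int m = 0; m < M; ++m) {
            /* Rows 0..M/2-1 come from the first subgroup, the rest from the second. */
            int32_t (*a)[4] = (m < half) ? a_first : a_second;
            int c = m % half;
            int32_t sum = acc[n][m];
            for (int k = 0; k < K_I8; ++k) {
                int32_t a_mk = byte_i8(a[k / 4][c], k % 4);   /* a[m][k] */
                int32_t b_kn = byte_i8(b[n][k / 4], k % 4);   /* b[k][n] */
                sum += a_mk * b_kn;
            }
            result[n][m] = sum;
        }
    }
}
```

Calling this once per subgroup with the same a_first and a_second but different b and acc models two subgroups computing adjacent output tiles that share the rows of a.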
The full list of supported functions is given in the New OpenCL C Functions section above. For this list of functions:
- M may be equal to 2, 4, or 8.
- N must be equal to 8. In other words, the only supported subgroup size is 8.
- Supported integer matrix types for a and b are any combination of signed or unsigned 8-bit integers. For these integer matrix types, the accumulation value acc and the result value are signed 32-bit integers, and K must be equal to 32.
- The supported floating-point matrix types for a and b are fp16 (half) or bfloat16. For these floating-point matrix types, the accumulation value acc and the result value are 32-bit floating-point values, and K must be equal to 16.
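The following hypothetical C helper checks the supported combinations listed above. Type names follow the spelling used in the built-in function names (i8, u8, bf16, f16). The floating-point case accepts only matching types, as in the prototype list. The helper illustrates the rules and is not an API.

```c
#include <assert.h>
#include <stdbool.h>
#include <string.h>

static bool is_type(const char *t, const char *name)
{
    return strcmp(t, name) == 0;
}

/* True if (a_type, b_type, M, N, K) is one of the listed combinations. */
bool split_mad_is_supported(const char *a_type, const char *b_type,
                            int M, int N, int K)
{
    if (M != 2 && M != 4 && M != 8)
        return false;
    if (N != 8)                        /* the subgroup size must be 8 */
        return false;
    bool a_int = is_type(a_type, "i8") || is_type(a_type, "u8");
    bool b_int = is_type(b_type, "i8") || is_type(b_type, "u8");
    if (a_int && b_int)                /* int32 acc and result */
        return K == 32;
    bool fp = (is_type(a_type, "bf16") || is_type(a_type, "f16"))
              && is_type(a_type, b_type);   /* bf16 x bf16 or f16 x f16 */
    if (fp)                            /* float acc and result */
        return K == 16;
    return false;
}
```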
Issues
- Do we need to talk about which two subgroups cooperate to perform the split matrix multiplication?
  UNRESOLVED: For now, this is left as an implementation detail, outside of the scope of this extension.
- Should the built-in functions in this extension overload the built-ins from cl_intel_subgroup_matrix_multiply_accumulate, or define new functions?
  RESOLVED: Switched to a non-overloaded syntax: intel_sub_group_i8_i8_split_matrix_mad_k32.
- Should this extension use signed or unsigned types to represent fp16 and bf16 data?
  RESOLVED: This extension will use signed types to represent fp16 and bf16 data, even though this is inconsistent with other extensions such as cl_intel_bfloat16_conversions. See the discussion in cl_intel_subgroup_matrix_multiply_accumulate.
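The C sketch below packs two bfloat16 values into the signed 32-bit int that carries them, to illustrate the last resolution. bfloat16 keeps the upper 16 bits of an IEEE-754 single-precision value. The sketch truncates rather than rounding to nearest even, so it is exact only for values already representable in bfloat16. Placing the element with the lower K index in the low 16 bits is an assumption. On two's-complement targets the signed type only changes how the bits are viewed, not the bits themselves.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* bfloat16 by truncation: keep the upper 16 bits of the float's encoding. */
static uint16_t float_to_bf16_trunc(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* ASSUMPTION: the lower-K element goes in the low 16 bits. */
int32_t pack_bf16_pair(float lo, float hi)
{
    uint32_t word = (uint32_t)float_to_bf16_trunc(lo) |
                    ((uint32_t)float_to_bf16_trunc(hi) << 16);
    return (int32_t)word;   /* same bits, viewed as a signed int */
}
```

For example, packing 1.0f and -2.0f yields the bit pattern 0xC0003F80, which is negative when viewed as an int.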