Contributors
Ben Ashbaugh, Intel
Eugene Chereshnev, Intel
Junjie Gu, Intel
Bartosz Koscielak, Intel
Mike MacPherson, Intel
Ritesh Patel, Intel
Lukasz Towarek, Intel
Dependencies
This extension is written against the OpenCL 3.0 C Language specification, V3.0.10.
This extension requires support for subgroups.
This extension depends on cl_intel_required_subgroup_size to query the subgroup sizes supported by a device or to require a subgroup size for a kernel.
Overview
The goal of this extension is to allow programmers to access specialized hardware to compute the product of an M x K matrix with a K x N matrix and then add an M x N matrix accumulation value. This is a commonly used building block to compute the product of two large matrices. When used in an OpenCL kernel, all work items in the subgroup cooperate to perform this operation.
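To make the arithmetic concrete, the operation can be modelled as an ordinary triple loop. The sketch below is plain host-side C rather than OpenCL C, and it ignores how the data is distributed across the subgroup. The function name reference_matrix_mad and the row-major layout are illustrative assumptions, not part of this extension.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical reference model: result = a * b + acc, where
 * a is m x k, b is k x n, and acc/result are m x n (all row-major).
 * int32_t accumulation mirrors the "full precision" integer accumulator. */
static void reference_matrix_mad(int m, int n, int k,
                                 const int32_t *a,
                                 const int32_t *b,
                                 const int32_t *acc,
                                 int32_t *result)
{
    assert(m > 0 && n > 0 && k > 0);
    for (int row = 0; row < m; ++row) {
        for (int col = 0; col < n; ++col) {
            /* Start from the accumulation value, then add a K-term dot product. */
            int32_t sum = acc[row * n + col];
            for (int i = 0; i < k; ++i)
                sum += a[row * k + i] * b[i * n + col];
            result[row * n + col] = sum;
        }
    }
}
```

Each result element is a K-term dot product of a row of a with a column of b, plus the corresponding element of acc. The built-in functions compute the same quantity cooperatively, with each work item holding one column of b, acc, and the result.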
This is a low-level extension for expert programmers seeking to access this functionality directly in custom kernels. Most users will access this functionality via high-level libraries or frameworks.
New OpenCL C Functions
// These functions are available to devices where the minimum subgroup
// size is 8. For these devices, the subgroup size must be 8 (the
// minimum supported subgroup size). Calling these functions on other
// devices or from kernels with a different subgroup size is undefined
// behavior:
// 8-bit matrices:
int intel_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc); // M = 1
int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc); // M = 2
int4 intel_sub_group_i8_i8_matrix_mad_k32(int4 a, int8 b, int4 acc); // M = 4
int8 intel_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc); // M = 8
int intel_sub_group_i8_u8_matrix_mad_k32(int a, uint8 b, int acc); // ...
int2 intel_sub_group_i8_u8_matrix_mad_k32(int2 a, uint8 b, int2 acc);
int4 intel_sub_group_i8_u8_matrix_mad_k32(int4 a, uint8 b, int4 acc);
int8 intel_sub_group_i8_u8_matrix_mad_k32(int8 a, uint8 b, int8 acc);
int intel_sub_group_u8_i8_matrix_mad_k32(uint a, int8 b, int acc);
int2 intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc);
int4 intel_sub_group_u8_i8_matrix_mad_k32(uint4 a, int8 b, int4 acc);
int8 intel_sub_group_u8_i8_matrix_mad_k32(uint8 a, int8 b, int8 acc);
int intel_sub_group_u8_u8_matrix_mad_k32(uint a, uint8 b, int acc);
int2 intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc);
// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(int a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(int2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(int4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(int8 a, int8 b, int8 acc);
int intel_sub_group_i4_u4_matrix_mad_k64(int a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(int2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(int4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(int8 a, uint8 b, int8 acc);
int intel_sub_group_u4_i4_matrix_mad_k64(uint a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(uint2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(uint4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(uint8 a, int8 b, int8 acc);
int intel_sub_group_u4_u4_matrix_mad_k64(uint a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(uint8 a, uint8 b, int8 acc);
// bfloat16 matrices:
float intel_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc);
float4 intel_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc);
float8 intel_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc);
// fp16 matrices:
float intel_sub_group_f16_f16_matrix_mad_k16(int a, int8 b, float acc);
float2 intel_sub_group_f16_f16_matrix_mad_k16(int2 a, int8 b, float2 acc);
float4 intel_sub_group_f16_f16_matrix_mad_k16(int4 a, int8 b, float4 acc);
float8 intel_sub_group_f16_f16_matrix_mad_k16(int8 a, int8 b, float8 acc);
// These functions are available to devices where the minimum subgroup
// size is 16. For these devices, the subgroup size must be 16 (the
// minimum supported subgroup size). Calling these functions on other
// devices or from kernels with a different subgroup size is undefined
// behavior:
// 8-bit matrices:
int intel_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc); // M = 1
int2 intel_sub_group_i8_i8_matrix_mad_k32(short2 a, int8 b, int2 acc); // M = 2
int4 intel_sub_group_i8_i8_matrix_mad_k32(short4 a, int8 b, int4 acc); // M = 4
int8 intel_sub_group_i8_i8_matrix_mad_k32(short8 a, int8 b, int8 acc); // M = 8
int intel_sub_group_i8_u8_matrix_mad_k32(short a, uint8 b, int acc); // ...
int2 intel_sub_group_i8_u8_matrix_mad_k32(short2 a, uint8 b, int2 acc);
int4 intel_sub_group_i8_u8_matrix_mad_k32(short4 a, uint8 b, int4 acc);
int8 intel_sub_group_i8_u8_matrix_mad_k32(short8 a, uint8 b, int8 acc);
int intel_sub_group_u8_i8_matrix_mad_k32(ushort a, int8 b, int acc);
int2 intel_sub_group_u8_i8_matrix_mad_k32(ushort2 a, int8 b, int2 acc);
int4 intel_sub_group_u8_i8_matrix_mad_k32(ushort4 a, int8 b, int4 acc);
int8 intel_sub_group_u8_i8_matrix_mad_k32(ushort8 a, int8 b, int8 acc);
int intel_sub_group_u8_u8_matrix_mad_k32(ushort a, uint8 b, int acc);
int2 intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc);
// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(short a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(short2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(short4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(short8 a, int8 b, int8 acc);
int intel_sub_group_i4_u4_matrix_mad_k64(short a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(short2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(short4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(short8 a, uint8 b, int8 acc);
int intel_sub_group_u4_i4_matrix_mad_k64(ushort a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(ushort2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(ushort4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(ushort8 a, int8 b, int8 acc);
int intel_sub_group_u4_u4_matrix_mad_k64(ushort a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(ushort8 a, uint8 b, int8 acc);
// bfloat16 matrices with float accumulator:
float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);
// fp16 matrices with float accumulator:
float intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc);
// bfloat16 matrices with bfloat16 accumulator:
short intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, short acc);
short2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, short2 acc);
short4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, short4 acc);
short8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, short8 acc);
// fp16 matrices with fp16 accumulator:
half intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, half acc);
half2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, half2 acc);
half4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, half4 acc);
half8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, half8 acc);
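The k64 functions operate on packed 4-bit elements: a 32-bit int or uint holds eight of them, and a 16-bit short or ushort holds four. The host-side C sketch below shows one way to unpack and sign-extend such values, and to form the eight-term partial dot product contributed by one work item's 32 bits of a row of a when the subgroup size is 8. The nibble order (element 0 in the least significant bits) is an assumption made by analogy with the byte order that as_char4() yields in the coding sample later in this document; the helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: signed 4-bit element 'index' (0..7) of a packed value.
 * ASSUMPTION: element 0 occupies the least significant nibble. */
static int unpack_i4(uint32_t packed, int index)
{
    assert(index >= 0 && index < 8);
    int nibble = (int)((packed >> (4 * index)) & 0xFu);
    return nibble >= 8 ? nibble - 16 : nibble; /* two's-complement sign extension */
}

/* Hypothetical helper: unsigned 4-bit element 'index' (0..7) of a packed value. */
static int unpack_u4(uint32_t packed, int index)
{
    assert(index >= 0 && index < 8);
    return (int)((packed >> (4 * index)) & 0xFu);
}

/* Eight i4 values from 'a' times eight u4 values from 'b', added to 'acc':
 * one step of an i4_u4 k64 row-times-column product for a subgroup of 8. */
static int32_t dot8_i4_u4(uint32_t a, uint32_t b, int32_t acc)
{
    for (int i = 0; i < 8; ++i)
        acc += unpack_i4(a, i) * unpack_u4(b, i);
    return acc;
}
```

With a subgroup of 8 and K = 64, eight such steps (one per component of the int8 or uint8 b operand) cover a full row of a.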
Modifications to the OpenCL C Specification
Add a new Section 6.13.X - Subgroup Matrix Multiply Accumulate Instructions
This section describes a family of built-in functions that multiply two matrix sources a and b and then add a matrix accumulation value to produce a matrix result value.
a is the first matrix operand and has M rows and K columns.
b is the second matrix operand and has K rows and N columns.
acc is the matrix accumulation value and has M rows and N columns.
The result value also has M rows and N columns.
All work items in the subgroup cooperate to perform this operation.
These functions must be encountered by all work items in the subgroup executing the kernel.
The dimensions of the two source matrices and the elements of each source matrix are described by the built-in function name and its arguments.
As an example, given the function:
int2 intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc);
- a is the first source matrix operand and has M rows and K columns.
  - The value of M is determined by the number of vector components in the source operand a. In the example above, a is a uint2 argument, therefore the matrix a operand has M equal to 2 rows.
  - The value of K is described by the function name. In this case, the value of K is 32, therefore the matrix a operand has K equal to 32 columns.
  - The matrix component data type is also described by the function name. In this case, the matrix a component data type is u8, indicating that the elements of the matrix a operand are unsigned 8-bit integers.
  - Each work item contributes part of this matrix. In this case, since the elements of the matrix a are 8-bit integers, and since each work item contributes 32 bits (the size of a uint) of data per row of this matrix, each work item contributes four 8-bit integer values per row.
  - Since K is 32, and each work item contributes four 8-bit values per row, the number of work items in the subgroup must be equal to 8.
- b is the second source matrix operand and has K rows and N columns.
  - Each work item contributes one column of this matrix. Therefore, the number of columns N is equal to the subgroup size.
  - As above, the value of K is described by the function name. In this case, the value of K is 32, therefore the matrix b operand has K equal to 32 rows.
  - As above, the matrix component data type is described by the function name. In this case, the matrix b component data type is i8, indicating that the elements of the matrix b operand are signed 8-bit integers.
  - Since K is 32 and the elements of the matrix b are 8-bit integers, each work item must contribute 256 bits (32 x 8) of source data to contribute K values. The 256 bits of source data are packed and passed as the int8 argument b.
- acc specifies the accumulation value and has M rows and N columns.
  - As above, the value of M is determined by the number of components in the source operand acc. In the example above, acc is an int2 argument, therefore the accumulation value operand has M equal to 2 rows.
  - Since both a and acc specify operands with M rows, and since the value of M is determined by the number of components in the source operand, the a and acc operands will both be vector operands with the same number of components.
  - As above, each work item contributes one column of accumulation values. Therefore, the number of columns N is equal to the subgroup size.
  - The acc operand is a "full precision" accumulation value. In the example above, the matrices contain integer data, therefore the acc operand is a vector of int data.
- The result value returned by the function also has M rows and N columns.
  - As above, the value of M is determined by the number of components in the return type. In the example above, the return type is int2, therefore the result value has M equal to 2 rows.
  - Since the result value, a, and acc all specify values with M rows, and since the value of M is determined by the number of components in the source operand or return type, the return type, a, and acc will all be vectors with the same number of components.
  - As above, each work item receives one column of result values. Therefore, the number of columns N is equal to the subgroup size.
  - Similar to the acc operand, the return value is a "full precision" result value. In the example above, the matrices contain integer data, therefore the return type is a vector of int data.
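The walkthrough above can be checked with a small host-side emulation in plain C. The sketch below models intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc) for a subgroup of 8 by storing each work item's arguments in an array indexed by work item; all helper names are hypothetical. It assumes, as the coding sample later in this document does, that work item w supplies elements 4w to 4w + 3 of each row of a, and that within each 32-bit value the least significant byte holds the lowest-numbered K element (what as_char4() yields on a little-endian device).

```c
#include <assert.h>
#include <stdint.h>

#define SG_SIZE 8  /* N: the subgroup size */
#define K_DIM 32   /* K: from the _k32 suffix */
#define M_DIM 2    /* M: components of uint2 a and int2 acc */

/* Byte i (0 = least significant) of a packed 32-bit value. */
static unsigned byte_of(uint32_t v, int i)
{
    return (v >> (8 * i)) & 0xFFu;
}

/* Interpret a byte as a two's-complement signed 8-bit value. */
static int as_signed8(unsigned byte)
{
    return byte >= 128 ? (int)byte - 256 : (int)byte;
}

/* a[w][r]   : work item w's uint for row r (four u8 elements of that row)
 * b[w][j]   : component j of work item w's int8 (column w of b), as bits
 * acc[w][r] : work item w's accumulation value for row r
 * res[w][r] : work item w's result for row r                             */
static void emulate_u8_i8_mad_k32_m2(uint32_t a[SG_SIZE][M_DIM],
                                     uint32_t b[SG_SIZE][8],
                                     int32_t acc[SG_SIZE][M_DIM],
                                     int32_t res[SG_SIZE][M_DIM])
{
    for (int w = 0; w < SG_SIZE; ++w) {     /* column n is owned by work item w */
        for (int r = 0; r < M_DIM; ++r) {   /* row m */
            int32_t sum = acc[w][r];
            for (int k = 0; k < K_DIM; ++k) {
                /* Row r of a: element k lives in work item k / 4, byte k % 4. */
                int a_elem = (int)byte_of(a[k / 4][r], k % 4);        /* u8: zero-extended */
                /* Column w of b: element k lives in component k / 4, byte k % 4. */
                int b_elem = as_signed8(byte_of(b[w][k / 4], k % 4)); /* i8: sign-extended */
                sum += a_elem * b_elem;
            }
            res[w][r] = sum;
        }
    }
}
```

The u8 elements of a are zero-extended while the i8 elements of b are sign-extended before multiplication; that is the only difference between this variant and the other signedness combinations.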
The full list of supported functions is given in the New OpenCL C Functions section, above. For this list of functions:
- M may be equal to 1, 2, 4, or 8.
- N must be equal to 8 for some devices or 16 for other devices. In other words, the only supported subgroup sizes are 8 and 16.
- Supported integer matrix types for a and b are any combination of signed or unsigned 8-bit integers, or any combination of signed or unsigned 4-bit integers. For 8-bit matrices, K must be equal to 32. For 4-bit matrices, K must be equal to 64. For these integer matrix types, the accumulation value acc and result value are signed 32-bit integers.
- The supported floating-point matrix types for a and b are fp16 (half) or bfloat16. For these floating-point matrices, K must be equal to 16. The accumulation value acc and result value are 32-bit floating-point values. For devices with N equal to 16, the accumulation value acc and result value may also be fp16 (half) for fp16 matrices, or bfloat16 (passed as short) for bfloat16 matrices.
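These rules also determine the operand types in the declarations: each work item supplies K / N elements of every row of a, and a full column of K elements of b. The small C sketch below (hypothetical helper names) computes those widths. For every supported combination, a works out to 32 bits (int or uint) when N is 8 and 16 bits (short or ushort) when N is 16, while b is always 256 bits, matching the int8 or uint8 argument.

```c
#include <assert.h>

/* Bits of matrix a supplied by each work item per row:
 * K elements per row, shared evenly across N work items. */
static int a_bits_per_work_item(int k, int element_bits, int subgroup_size)
{
    assert(subgroup_size > 0 && (k * element_bits) % subgroup_size == 0);
    return k * element_bits / subgroup_size;
}

/* Bits of matrix b supplied by each work item: one full column of K elements. */
static int b_bits_per_work_item(int k, int element_bits)
{
    return k * element_bits;
}
```

For example, 8-bit data with K = 32 gives 32 bits of a per work item on a subgroup of 8 (an int) and 16 bits on a subgroup of 16 (a short).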
Coding Sample
// The code below shows a functional implementation of one of the
// built-in functions added by this extension. For this built-in
// function:
// * M = 2, since the result value, a operand, and acc operand
// are all vectors with two components.
// * N = 8, and is equal to the subgroup size.
// * K = 32, as described by the function name.
// * The elements of both matrix a and matrix b are signed 8-bit
// integers.
// This is a helper function that performs the dot product of
// two vectors of four components of 8-bit integer data, and then
// adds a 32-bit integer accumulation value.
static int __intel_dot_product_accumulate( char4 a, char4 b, int acc )
{
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w + acc;
}
// This is a helper function that computes the product of a
// 1 x 32 row vector value shared across the subgroup and a 32 x 1
// column vector, and then adds a full precision accumulation
// value.
static int __intel_vector_matrix_multiply_accumulate_k32( int v, int8 b, int acc )
{
// Note: 8 is the size of the subgroup.
// As K is 32, and the size of the subgroup is 8, each
// work item contributes 4 elements of the 1 x K vector.
// as_char4() is used to reinterpret 32-bits of data
// as four components of 8-bit data.
int result = acc;
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 0 ) ), as_char4( b.s0 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 1 ) ), as_char4( b.s1 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 2 ) ), as_char4( b.s2 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 3 ) ), as_char4( b.s3 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 4 ) ), as_char4( b.s4 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 5 ) ), as_char4( b.s5 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 6 ) ), as_char4( b.s6 ), result );
result = __intel_dot_product_accumulate(
as_char4( sub_group_broadcast( v, 7 ) ), as_char4( b.s7 ), result );
return result;
}
int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc)
{
int2 result;
result.x = __intel_vector_matrix_multiply_accumulate_k32( a.x, b, acc.x );
result.y = __intel_vector_matrix_multiply_accumulate_k32( a.y, b, acc.y );
return result;
}
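On devices whose subgroup size is 16, the a operand is a short per row, so each work item supplies only two 8-bit elements of the 1 x 32 row. The plain C sketch below adapts the sample's __intel_vector_matrix_multiply_accumulate_k32 helper to that case, modelling sub_group_broadcast by indexing an array of per-work-item values. It assumes that work item l supplies row elements 2l and 2l + 1 and that bytes are ordered least significant first; the function and parameter names are hypothetical, and this is not the device implementation.

```c
#include <assert.h>
#include <stdint.h>

#define SG16 16 /* subgroup size */

/* Byte 'byte_index' of v, interpreted as a two's-complement i8 value. */
static int as_i8(uint32_t v, int byte_index)
{
    int byte = (int)((v >> (8 * byte_index)) & 0xFFu);
    return byte >= 128 ? byte - 256 : byte;
}

/* lanes[l]: the 16-bit value v held by work item l (two i8 row elements).
 * b[j]    : component j of the calling work item's int8 column, as bits.
 * Returns the 1 x 32 row times 32 x 1 column product plus acc.          */
static int32_t vmma_k32_sg16(const uint16_t lanes[SG16], const uint32_t b[8], int32_t acc)
{
    for (int j = 0; j < 8; ++j) {
        /* b.s<j> holds elements 4j..4j+3, which pair with lanes 2j and 2j+1. */
        uint32_t a4 = (uint32_t)lanes[2 * j] | ((uint32_t)lanes[2 * j + 1] << 16);
        for (int i = 0; i < 4; ++i)
            acc += as_i8(a4, i) * as_i8(b[j], i);
    }
    return acc;
}
```

Stitching two 16-bit lanes back into one 32-bit word keeps the inner loop identical to the four-term dot product used by the subgroup-size-8 sample.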
Modifications to the OpenCL SPIR-V Environment Specification
Note: SPIR-V support was added in extension version 1.1.0.
Add a new section 5.2.X - cl_intel_subgroup_matrix_multiply_accumulate
If the OpenCL environment supports the extension cl_intel_subgroup_matrix_multiply_accumulate then the environment must accept modules that declare use of the extension SPV_INTEL_subgroup_matrix_multiply_accumulate and that declare the SPIR-V capability SubgroupMatrixMultiplyAccumulateINTEL.
For devices where the minimum subgroup size is 8, the following matrix dimensions and types are supported. For these devices, the subgroup size must be 8 (the minimum subgroup size). Behavior is undefined if these functions are called on other devices or from kernels with a different subgroup size:
| M Dimension | N Dimension | K Dimension | Result Type | Matrix A Type | Matrix B Type | Matrix C Type |
|---|---|---|---|---|---|---|
| 8-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator: | | | | | | |
| 1, 2, 4, 8 | 8 | 32 | | | | |
| 1, 2, 4, 8 | 8 | 32 | | | | |
| 1, 2, 4, 8 | 8 | 32 | | | | |
| 1, 2, 4, 8 | 8 | 32 | | | | |
| 4-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator: | | | | | | |
| 1, 2, 4, 8 | 8 | 64 | | | | |
| 1, 2, 4, 8 | 8 | 64 | | | | |
| 1, 2, 4, 8 | 8 | 64 | | | | |
| 1, 2, 4, 8 | 8 | 64 | | | | |
| fp16 matrix sources, fp32 accumulator: | | | | | | |
| 1, 2, 4, 8 | 8 | 16 | | | | |
| bf16 matrix sources, fp32 accumulator: | | | | | | |
| 1, 2, 4, 8 | 8 | 16 | | | | |
For devices where the minimum subgroup size is 16, the following matrix dimensions and types are supported. For these devices, the subgroup size must be 16 (the minimum subgroup size). Behavior is undefined if these functions are called on other devices or from kernels with a different subgroup size:
| M Dimension | N Dimension | K Dimension | Result Type | Matrix A Type | Matrix B Type | Matrix C Type |
|---|---|---|---|---|---|---|
| 8-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 32 | | | | |
| 1, 2, 4, 8 | 16 | 32 | | | | |
| 1, 2, 4, 8 | 16 | 32 | | | | |
| 1, 2, 4, 8 | 16 | 32 | | | | |
| 4-bit integer matrix sources (signed and unsigned), 32-bit integer accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 64 | | | | |
| 1, 2, 4, 8 | 16 | 64 | | | | |
| 1, 2, 4, 8 | 16 | 64 | | | | |
| 1, 2, 4, 8 | 16 | 64 | | | | |
| fp16 matrix sources, fp32 accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 16 | | | | |
| bf16 matrix sources, fp32 accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 16 | | | | |
| fp16 matrix sources, fp16 accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 16 | | | | |
| bf16 matrix sources, bf16 accumulator: | | | | | | |
| 1, 2, 4, 8 | 16 | 16 | | | | |
Issues
- Should this extension use signed or unsigned types to represent fp16 and bf16 data?

  RESOLVED: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions, such as the cl_intel_bfloat16_conversions extension. This inconsistency may be addressed in a future extension or in a future version of this extension. Applications are encouraged to use as_type to reinterpret unsigned data as signed data as needed to use the functions added by this extension.
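On the host, the as_type reinterpretation mentioned in the resolution corresponds to a bitwise copy between uint16_t and int16_t. The plain C sketch below shows that copy, decodes bf16 bits held in a signed 16-bit value to float (bf16 is the upper half of an IEEE-754 binary32 value), and uses the decoder to compute a single result element of a K = 16 bf16 product with a float accumulator. The helper names are hypothetical, and a device may accumulate in a different order or with different intermediate precision, so this is a reference for exactly representable inputs rather than a bit-exact model.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Host analogue of as_short(): same 16 bits, now typed as signed. */
static int16_t bf16_bits_as_signed(uint16_t bits)
{
    int16_t s;
    memcpy(&s, &bits, sizeof s);
    return s;
}

/* Decode bf16 bits stored in a signed short to float by placing them in
 * the upper 16 bits of a binary32 pattern. */
static float bf16_to_float(int16_t s)
{
    uint16_t bits;
    memcpy(&bits, &s, sizeof bits);
    uint32_t widened = (uint32_t)bits << 16;
    float f;
    memcpy(&f, &widened, sizeof f);
    return f;
}

/* One result element: a 16-element bf16 row times a 16-element bf16
 * column, added to a float accumulator. */
static float bf16_dot16(const int16_t row[16], const int16_t col[16], float acc)
{
    for (int k = 0; k < 16; ++k)
        acc += bf16_to_float(row[k]) * bf16_to_float(col[k]);
    return acc;
}
```

Because the copy only changes the declared type, unsigned bf16 data produced by other extensions can be passed to these functions without altering any bits.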