@@ -2058,7 +2058,7 @@ class joint_matrix {
20582058 const size_t num_elements;
20592059};
20602060
2061- // / Collectively loads 1 8x8 b16 (128 bytes) matrix from private memory to local
2061+ // / Collectively loads 1 8x8 b16 (128 bytes) matrix from local memory to private
20622062// / memory per sub-group. Requires the sub-group size of kernel calling this
20632063// / function to be 32.
20642064// / 'mat' specifies the matrix index to be loaded. The first '(mat + 1) * 8'
@@ -2135,7 +2135,7 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21352135 }
21362136}
21372137
2138- // / Collectively loads 2 8x8 b16 (256 bytes) matrix from private memory to local
2138+ // / Collectively loads 2 8x8 b16 (256 bytes) matrix from local memory to private
21392139// / memory per sub-group. Requires the sub-group size of kernel calling this
21402140// / function to be 32.
21412141// / The first 16 work items of sub-group contain the starting address of their
@@ -2172,7 +2172,7 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21722172 ldmatrix (addr, m2, trans, 1 );
21732173}
21742174
2175- // / Collectively loads 4 8x8 b16 (512 bytes) matrix from private memory to local
2175+ // / Collectively loads 4 8x8 b16 (512 bytes) matrix from local memory to private
21762176// / memory per sub-group. Requires the sub-group size of kernel calling this
21772177// / function to be 32.
21782178// / Each work item of sub-group contains the starting address of their
@@ -2218,6 +2218,166 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218 ldmatrix (addr, m4, trans, 3 );
22192219}
22202220
2221+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2222+ // / local memory per sub-group.
2223+ // / Requires the sub-group size of kernel calling this function to be 32.
2224+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2225+ // / work items of sub-group contain the starting address of their respective
2226+ // / matrix row in 'addr'.
2227+ // / After distributing addresses to other work items, each of the 32 work items
2228+ // / store 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2229+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2230+ // / item like below
2231+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2232+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2233+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2234+ // / ...
2235+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2236+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2237+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2238+ // / row-0: wi0 wi4 wi8 ... wi28
2239+ // / row-1: wi0 wi4 wi8 ... wi28
2240+ // / ...
2241+ // / row-6: wi3 wi7 wi11 ... wi31
2242+ // / row-7: wi3 wi7 wi11 ... wi31
2243+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2244+ // / \param [in] addr The starting address of corresponding matrix row for a work
2245+ // / item in local memory
2246+ // / \param [in] m The local memory to store the matrix. It points to 2 b16
2247+ // / type elements.
2248+ // / \param [in] trans Indicates whether the matrix to be stored transposed
2249+ // / \param [in] mat The matrix index to be stored
2250+ template <typename T>
2251+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2252+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2253+ int lane = sg.get_local_linear_id ();
2254+
2255+ int lane_group8_row = lane / 8 ;
2256+ int lane_group8_col = lane % 8 ;
2257+
2258+ if (!trans) {
2259+ // calculate the source lane
2260+ int src_lane = 2 * lane_group8_row;
2261+ if (lane_group8_col >= 4 )
2262+ src_lane += 1 ;
2263+
2264+ // Broadcast the address from the source lane
2265+ auto recv_addr_uintp =
2266+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane);
2267+
2268+ // Cast the received address from uintptr_t to the type of 'm'
2269+ auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
2270+
2271+ // Non-transposed store
2272+ recv_addr[lane_group8_col % 4 ] = m;
2273+ } else {
2274+ // calculate the source lane
2275+ int src_lane = (lane % 4 ) * 2 ;
2276+
2277+ // Broadcast the address from the source lane
2278+ auto recv_addr_uintp_1 =
2279+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane);
2280+ auto recv_addr_uintp_2 =
2281+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane + 1 );
2282+
2283+ // Cast the received address from uintptr_t to 'half *'
2284+ auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
2285+ auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
2286+
2287+ // Split the 32-bit value of 'm' into two 16-bits
2288+ sycl::half *val = reinterpret_cast <sycl::half *>(&m);
2289+
2290+ // Transposed store
2291+ int index = lane / 4 ;
2292+ recv_addr_1[index] = val[0 ];
2293+ recv_addr_2[index] = val[1 ];
2294+ }
2295+ }
2296+
2297+ // / Collectively stores 2 8x8 b16 (256 bytes) matrix from private memory to
2298+ // / local memory per sub-group.
2299+ // / Requires the sub-group size of kernel calling this function to be 32.
2300+ // / The first 16 work items of sub-group contain the starting address of their
2301+ // / respective matrix row in 'addr'.
2302+ // / After distributing addresses to other work items, each of the 32 work items
2303+ // / store 64-bits (32-bits per matrix) into 'm1' & 'm2' for a total of 256
2304+ // / bytes.
2305+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2306+ // / item like below
2307+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2308+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2309+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2310+ // / ...
2311+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2312+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2313+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2314+ // / row-0: wi0 wi4 wi8 ... wi28
2315+ // / row-1: wi0 wi4 wi8 ... wi28
2316+ // / ...
2317+ // / row-6: wi3 wi7 wi11 ... wi31
2318+ // / row-7: wi3 wi7 wi11 ... wi31
2319+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2320+ // / \param [in] addr The starting address of corresponding matrix row for a work
2321+ // / item in local memory
2322+ // / \param [in] m1 The local memory to store the data of 1st matrix. It points
2323+ // / to 2 b16 type elements.
2324+ // / \param [in] m2 The local memory to store the data of 2nd matrix. It points
2325+ // / to 2 b16 type elements.
2326+ // / \param [in] trans Indicates whether the matrix to be stored transposed
2327+ template <typename T>
2328+ void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
2329+ // Store 1st matrix
2330+ stmatrix (addr, m1, trans, 0 );
2331+ // Store 2nd matrix
2332+ stmatrix (addr, m2, trans, 1 );
2333+ }
2334+
2335+ // / Collectively stores 4 8x8 b16 (512 bytes) matrix from private memory to
2336+ // / local memory per sub-group.
2337+ // / Requires the sub-group size of kernel calling this function to be 32.
2338+ // / Each work item of sub-group contains the starting address of their
2339+ // / respective matrix row in 'addr'.
2340+ // / After distributing addresses to other work items, each of the 32 work items
2341+ // / store 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2342+ // / of 512 bytes.
2343+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2344+ // / item like below
2345+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2346+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2347+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2348+ // / ...
2349+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2350+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2351+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2352+ // / row-0: wi0 wi4 wi8 ... wi28
2353+ // / row-1: wi0 wi4 wi8 ... wi28
2354+ // / ...
2355+ // / row-6: wi3 wi7 wi11 ... wi31
2356+ // / row-7: wi3 wi7 wi11 ... wi31
2357+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2358+ // / \param [in] addr The starting address of corresponding matrix row for a work
2359+ // / item in local memory
2360+ // / \param [in] m1 The local memory to store the data of 1st matrix. It points
2361+ // / to 2 b16 type elements.
2362+ // / \param [in] m2 The local memory to store the data of 2nd matrix. It points
2363+ // / to 2 b16 type elements.
2364+ // / \param [in] m3 The local memory to store the data of 3rd matrix. It points
2365+ // / to 2 b16 type elements.
2366+ // / \param [in] m4 The local memory to store the data of 4th matrix. It points
2367+ // / to 2 b16 type elements.
2368+ // / \param [in] trans Indicates whether the matrix to be stored transposed
2369+ template <typename T>
2370+ void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
2371+ // Store 1st matrix
2372+ stmatrix (addr, m1, trans, 0 );
2373+ // Store 2nd matrix
2374+ stmatrix (addr, m2, trans, 1 );
2375+ // Store 3rd matrix
2376+ stmatrix (addr, m3, trans, 2 );
2377+ // Store 4th matrix
2378+ stmatrix (addr, m4, trans, 3 );
2379+ }
2380+
22212381// / A helper struct that defines the pack type for the input matrix fragments
22222382// / of mma() function based on the type of input matrix fragments.
22232383// / The MMAType struct is specialized for different types of input matrices.
0 commit comments