HIP: Heterogeneous-computing Interface for Portability
amd_hip_cooperative_groups.h
/*
Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wc++98-compat"
#pragma clang diagnostic ignored "-Wsign-conversion"
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wreserved-macro-identifier"
#pragma clang diagnostic ignored "-Wpadded"
#endif

#if __cplusplus
#if !defined(__HIPCC_RTC__)
// Helper declarations used below (internal::* helpers, __CG_QUALIFIER__, lane_mask, ...)
#include <hip/amd_detail/hip_cooperative_groups_helper.h>
#endif

#define __hip_abort() \
  { asm("trap;"); }
#if defined(NDEBUG)
#define __hip_assert(COND)
#else
#define __hip_assert(COND) \
  {                        \
    if (!(COND)) {         \
      __hip_abort();       \
    }                      \
  }
#endif

namespace cooperative_groups {

class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // Lane mask for coalesced and tiled partitioned group types;
                   // LSB represents lane 0, and MSB represents lane 63

  // Construct a thread group and set the thread group type and other essential
  // thread group properties. This generic thread group is directly constructed
  // only when the group is supposed to contain only the calling thread
  // (through the API `this_thread()`); in all other cases, this thread
  // group object is a sub-object of some other derived thread group object.
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group; this also serves all derived
  // cooperative group types, since their `size` is saved directly at construction.
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is this cooperative group type valid?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group
  __CG_QUALIFIER__ void sync() const;
};

class multi_grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of invocations participating in this multi-grid group. In other
  // words, the number of GPUs
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this invocation. In other words, an ID number within the range
  // [0, num_grids()) of the GPU this kernel is running on
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};

__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}

class grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};

__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }

class thread_block : public thread_group {
  // Only these friend functions are allowed to construct an object of this
  // class and access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Invalid tile size, assert
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};

__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}

class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const {
    internal::tiled_group::sync();
  }
};

class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If a tiled group is passed to be partitioned further into a coalesced_group,
    // prepare a mask for further partitioning it so that it stays coalesced.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    }
    // Here the parent coalesced_group is not partitioned.
    else {
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        // Use a lane_mask-wide shift so lanes 32-63 are reachable on 64-wide wavefronts
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Make sure the lane is active
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            // Prepare a member_mask that is appropriate for a tile
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Constructor
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;                     // Which threads are active
    coalesced_info.size = __popcll(coalesced_info.member_mask);   // How many threads are active
    coalesced_info.tiled_info.is_tiled = false;                   // Not a partitioned group
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const {
    return coalesced_info.size;
  }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const {
    internal::coalesced_group::sync();
  }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }

  template <class T>
  __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
               : (__AMDGCN_WAVEFRONT_SIZE == 64)
                   ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
                   : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }
    else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }
    else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};

__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}

__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}

__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return false;
    }
  }
}

__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
    }
  }
}

template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }

template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  _CG_STATIC_CONST_DECL_ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};

template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() {
    internal::tiled_group::sync();
  }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};

template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Returns the linear rank of the group within the set of tiles partitioned
  // from a parent group (bounded by meta_group_size)
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Returns the number of groups created when the parent group was partitioned.
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};

template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }
};

// Partial template specialization
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank, unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
  // end of cooperative group
};


__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  }
  else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  }
  else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}

// Thread block type overload
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// If a coalesced group is passed to be partitioned, it should remain coalesced
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl

template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};


template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 protected:
 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};

template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};

}  // namespace impl

template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size. Currently not supported ");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
}  // namespace cooperative_groups

#if defined(__clang__)
#pragma clang diagnostic pop
#endif

#endif  // __cplusplus
#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
Device-side implementation of the cooperative groups feature.
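As a quick orientation, here is a minimal usage sketch (not part of the header) showing how the types declared above are typically combined in a kernel. The kernel names, buffer layout, shared-memory size, and the tile width of 64 are illustrative assumptions; 64 matches a 64-wide wavefront (gfx9-class GPUs), while RDNA GPUs use a wavefront size of 32.

#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>

namespace cg = cooperative_groups;

// Hypothetical kernel: per-block sum. Each thread contributes one value; the sum is
// first reduced within wavefront-sized tiles via shfl_down, then combined by thread 0.
// Assumes blockDim.x is a multiple of 64 and at most 1024 (so at most 16 tiles).
__global__ void block_sum(const float* in, float* block_sums) {
  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<64> tile = cg::tiled_partition<64>(block);  // one wavefront per tile

  float v = in[block.group_index().x * block.size() + block.thread_rank()];

  // Tree reduction inside the tile; all lanes of the tile stay converged.
  for (unsigned int offset = tile.size() / 2; offset > 0; offset /= 2) {
    v += tile.shfl_down(v, offset);
  }

  __shared__ float partial[16];  // one slot per tile (assumption: <= 16 tiles per block)
  if (tile.thread_rank() == 0) partial[tile.meta_group_rank()] = v;
  block.sync();

  if (block.thread_rank() == 0) {
    float s = 0.f;
    for (unsigned int i = 0; i < tile.meta_group_size(); ++i) s += partial[i];
    block_sums[block.group_index().x] = s;
  }
}

// Hypothetical kernel illustrating coalesced_threads(): threads that took the same
// divergent branch are regrouped, so one atomic per group replaces one per thread.
__global__ void count_positives(const int* data, int* counter) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (data[tid] > 0) {
    cg::coalesced_group active = cg::coalesced_threads();
    if (active.thread_rank() == 0) atomicAdd(counter, static_cast<int>(active.size()));
  }
}

In this sketch, tiled_partition<64>(block) resolves to the templated overload at the end of the header, and the resulting thread_block_tile<64, thread_block> converts to thread_block_tile<64, void> (the default ParentCGTy), which is why meta_group_rank() and meta_group_size() remain available after the conversion.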