89 #ifndef viskores_cont_internal_ParallelRadixSort_h
90 #define viskores_cont_internal_ParallelRadixSort_h
122 inline size_t GetMaxThreads(
size_t num_bytes,
size_t available_cores)
124 const double CORES_PER_BYTE =
125 double(available_cores - 1) / double(BYTES_FOR_MAX_PARALLELISM - MIN_BYTES_FOR_PARALLEL);
126 const double Y_INTERCEPT = 1.0 - CORES_PER_BYTE * MIN_BYTES_FOR_PARALLEL;
128 const size_t num_cores = (size_t)(CORES_PER_BYTE *
double(num_bytes) + Y_INTERCEPT);
133 if (num_cores > available_cores)
135 return available_cores;
144 const size_t kOutBufferSize = 32;
147 template <
typename PlainType,
148 typename UnsignedType,
149 typename CompareType,
150 typename ValueManager,
152 struct ParallelRadixCompareInternal
154 inline static void reverse(UnsignedType& t) { (void)t; }
158 template <
typename PlainType,
typename Un
signedType,
typename ValueManager,
unsigned int Base>
159 struct ParallelRadixCompareInternal<PlainType,
161 std::greater<PlainType>,
165 inline static void reverse(UnsignedType& t) { t = ((1 << Base) - 1) - t; }
169 template <
typename PlainType,
170 typename CompareType,
171 typename UnsignedType,
173 typename ValueManager,
174 typename ThreaderType,
176 class ParallelRadixSortInternal
179 using CompareInternal =
180 ParallelRadixCompareInternal<PlainType, UnsignedType, CompareType, ValueManager, Base>;
182 ParallelRadixSortInternal();
183 ~ParallelRadixSortInternal();
185 void Init(PlainType* data,
size_t num_elems,
const ThreaderType& threader);
187 PlainType* Sort(PlainType* data, ValueManager* value_manager);
189 static void InitAndSort(PlainType* data,
191 const ThreaderType& threader,
192 ValueManager* value_manager);
195 CompareInternal compare_internal_;
201 UnsignedType*** out_buf_;
204 size_t *pos_bgn_, *pos_end_;
205 ValueManager* value_manager_;
206 ThreaderType threader_;
210 UnsignedType* SortInternal(UnsignedType* data, ValueManager* value_manager);
213 void ComputeRanges();
217 void ComputeHistogram(
unsigned int b, UnsignedType* src);
221 void Scatter(
unsigned int b, UnsignedType* src, UnsignedType* dst);
224 template <
typename PlainType,
225 typename CompareType,
226 typename UnsignedType,
228 typename ValueManager,
229 typename ThreaderType,
231 ParallelRadixSortInternal<PlainType,
237 Base>::ParallelRadixSortInternal()
247 assert(
sizeof(PlainType) ==
sizeof(UnsignedType));
250 template <
typename PlainType,
251 typename CompareType,
252 typename UnsignedType,
254 typename ValueManager,
255 typename ThreaderType,
257 ParallelRadixSortInternal<PlainType,
263 Base>::~ParallelRadixSortInternal()
268 template <
typename PlainType,
269 typename CompareType,
270 typename UnsignedType,
272 typename ValueManager,
273 typename ThreaderType,
275 void ParallelRadixSortInternal<PlainType,
286 for (
size_t i = 0; i < num_threads_; ++i)
291 for (
size_t i = 0; i < num_threads_; ++i)
293 for (
size_t j = 0; j < 1 << Base; ++j)
295 delete[] out_buf_[i][j];
297 delete[] out_buf_n_[i];
298 delete[] out_buf_[i];
307 pos_bgn_ = pos_end_ = NULL;
313 template <
typename PlainType,
314 typename CompareType,
315 typename UnsignedType,
317 typename ValueManager,
318 typename ThreaderType,
320 void ParallelRadixSortInternal<PlainType,
326 Base>::Init(PlainType* data,
328 const ThreaderType& threader)
333 threader_ = threader;
335 num_elems_ = num_elems;
338 utility::GetMaxThreads(num_elems_ *
sizeof(PlainType), threader_.GetAvailableCores());
340 tmp_ =
new UnsignedType[num_elems_];
341 histo_ =
new size_t*[num_threads_];
342 for (
size_t i = 0; i < num_threads_; ++i)
344 histo_[i] =
new size_t[1 << Base];
347 out_buf_ =
new UnsignedType**[num_threads_];
348 out_buf_n_ =
new size_t*[num_threads_];
349 for (
size_t i = 0; i < num_threads_; ++i)
351 out_buf_[i] =
new UnsignedType*[1 << Base];
352 out_buf_n_[i] =
new size_t[1 << Base];
353 for (
size_t j = 0; j < 1 << Base; ++j)
355 out_buf_[i][j] =
new UnsignedType[kOutBufferSize];
359 pos_bgn_ =
new size_t[num_threads_];
360 pos_end_ =
new size_t[num_threads_];
363 template <
typename PlainType,
364 typename CompareType,
365 typename UnsignedType,
367 typename ValueManager,
368 typename ThreaderType,
370 PlainType* ParallelRadixSortInternal<PlainType,
376 Base>::Sort(PlainType* data, ValueManager* value_manager)
378 UnsignedType* src =
reinterpret_cast<UnsignedType*
>(data);
379 UnsignedType* res = SortInternal(src, value_manager);
380 return reinterpret_cast<PlainType*
>(res);
383 template <
typename PlainType,
384 typename CompareType,
385 typename UnsignedType,
387 typename ValueManager,
388 typename ThreaderType,
390 void ParallelRadixSortInternal<PlainType,
396 Base>::InitAndSort(PlainType* data,
398 const ThreaderType& threader,
399 ValueManager* value_manager)
401 ParallelRadixSortInternal prs;
402 prs.Init(data, num_elems, threader);
403 const PlainType* res = prs.Sort(data, value_manager);
406 for (
size_t i = 0; i < num_elems; ++i)
411 template <
typename PlainType,
412 typename CompareType,
413 typename UnsignedType,
415 typename ValueManager,
416 typename ThreaderType,
418 UnsignedType* ParallelRadixSortInternal<PlainType,
424 Base>::SortInternal(UnsignedType* data,
425 ValueManager* value_manager)
428 value_manager_ = value_manager;
434 const size_t bits = CHAR_BIT *
sizeof(UnsignedType);
435 UnsignedType *src = data, *dst = tmp_;
436 for (
unsigned int b = 0; b < bits; b += Base)
438 ComputeHistogram(b, src);
439 Scatter(b, src, dst);
442 value_manager->Next();
448 template <
typename PlainType,
449 typename CompareType,
450 typename UnsignedType,
452 typename ValueManager,
453 typename ThreaderType,
455 void ParallelRadixSortInternal<PlainType,
461 Base>::ComputeRanges()
464 for (
size_t i = 0; i < num_threads_ - 1; ++i)
466 const size_t t = (num_elems_ - pos_bgn_[i]) / (num_threads_ - i);
467 pos_bgn_[i + 1] = pos_end_[i] = pos_bgn_[i] + t;
469 pos_end_[num_threads_ - 1] = num_elems_;
472 template <
typename PlainType,
473 typename UnsignedType,
477 typename ThreaderType>
480 RunTask(
size_t binary_tree_height,
481 size_t binary_tree_position,
485 const ThreaderType& threader)
486 : binary_tree_height_(binary_tree_height)
487 , binary_tree_position_(binary_tree_position)
489 , num_elems_(num_elems)
490 , num_threads_(num_threads)
491 , threader_(threader)
495 template <
typename ThreaderData =
void*>
496 void operator()(ThreaderData tData =
nullptr)
const
498 size_t num_nodes_at_current_height = (size_t)pow(2, (
double)binary_tree_height_);
499 if (num_threads_ <= num_nodes_at_current_height)
501 const size_t my_id = binary_tree_position_ - num_nodes_at_current_height;
502 if (my_id < num_threads_)
509 RunTask left(binary_tree_height_ + 1,
510 2 * binary_tree_position_,
515 RunTask right(binary_tree_height_ + 1,
516 2 * binary_tree_position_ + 1,
521 threader_.RunChildTasks(tData, left, right);
525 size_t binary_tree_height_;
526 size_t binary_tree_position_;
530 ThreaderType threader_;
533 template <
typename PlainType,
534 typename CompareType,
535 typename UnsignedType,
537 typename ValueManager,
538 typename ThreaderType,
540 void ParallelRadixSortInternal<PlainType,
546 Base>::ComputeHistogram(
unsigned int b, UnsignedType* src)
550 auto lambda = [=](
const size_t my_id)
552 const size_t my_bgn = pos_bgn_[my_id];
553 const size_t my_end = pos_end_[my_id];
554 size_t* my_histo = histo_[my_id];
556 memset(my_histo, 0,
sizeof(
size_t) * (1 << Base));
557 for (
size_t i = my_bgn; i < my_end; ++i)
559 const UnsignedType s = Encoder::encode(src[i]);
560 UnsignedType t = (s >> b) & ((1 << Base) - 1);
561 compare_internal_.reverse(t);
567 RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(
size_t)>, ThreaderType>;
569 RunTaskType root(0, 1, lambda, num_elems_, num_threads_, threader_);
570 this->threader_.RunParentTask(root);
574 for (
size_t i = 0; i < 1 << Base; ++i)
576 for (
size_t j = 0; j < num_threads_; ++j)
578 const size_t t = s + histo_[j][i];
585 template <
typename PlainType,
586 typename CompareType,
587 typename UnsignedType,
589 typename ValueManager,
590 typename ThreaderType,
592 void ParallelRadixSortInternal<PlainType,
598 Base>::Scatter(
unsigned int b, UnsignedType* src, UnsignedType* dst)
601 auto lambda = [=](
const size_t my_id)
603 const size_t my_bgn = pos_bgn_[my_id];
604 const size_t my_end = pos_end_[my_id];
605 size_t* my_histo = histo_[my_id];
606 UnsignedType** my_buf = out_buf_[my_id];
607 size_t* my_buf_n = out_buf_n_[my_id];
609 memset(my_buf_n, 0,
sizeof(
size_t) * (1 << Base));
610 for (
size_t i = my_bgn; i < my_end; ++i)
612 const UnsignedType s = Encoder::encode(src[i]);
613 UnsignedType t = (s >> b) & ((1 << Base) - 1);
614 compare_internal_.reverse(t);
615 my_buf[t][my_buf_n[t]] = src[i];
616 value_manager_->Push(my_id, t, my_buf_n[t], i);
619 if (my_buf_n[t] == kOutBufferSize)
621 size_t p = my_histo[t];
622 for (
size_t j = 0; j < kOutBufferSize; ++j)
625 dst[tp] = my_buf[t][j];
627 value_manager_->Flush(my_id, t, kOutBufferSize, my_histo[t]);
629 my_histo[t] += kOutBufferSize;
635 for (
size_t i = 0; i < 1 << Base; ++i)
637 size_t p = my_histo[i];
638 for (
size_t j = 0; j < my_buf_n[i]; ++j)
641 dst[tp] = my_buf[i][j];
643 value_manager_->Flush(my_id, i, my_buf_n[i], my_histo[i]);
648 RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(
size_t)>, ThreaderType>;
649 RunTaskType root(0, 1, lambda, num_elems_, num_threads_, threader_);
650 this->threader_.RunParentTask(root);
662 class EncoderUnsigned
665 template <
typename Un
signedType>
666 inline static UnsignedType encode(UnsignedType x)
675 template <
typename Un
signedType>
676 inline static UnsignedType encode(UnsignedType x)
678 return x ^ (UnsignedType(1) << (CHAR_BIT *
sizeof(UnsignedType) - 1));
685 template <
typename Un
signedType>
686 inline static UnsignedType encode(UnsignedType x)
688 static const size_t bits = CHAR_BIT *
sizeof(UnsignedType);
689 const UnsignedType a = x >> (bits - 1);
690 const UnsignedType b = (-
static_cast<int>(a)) | (UnsignedType(1) << (bits - 1));
698 namespace value_manager
700 class DummyValueManager
703 inline void Push(
int thread,
size_t bucket,
size_t num,
size_t from_pos)
711 inline void Flush(
int thread,
size_t bucket,
size_t num,
size_t to_pos)
722 template <
typename PlainType,
typename ValueType,
int Base>
723 class PairValueManager
738 ~PairValueManager() { DeleteAll(); }
740 void Init(
size_t max_elems,
size_t available_threads);
742 void Start(ValueType* original,
size_t num_elems)
744 assert(num_elems <= max_elems_);
745 src_ = original_ = original;
749 inline void Push(
int thread,
size_t bucket,
size_t num,
size_t from_pos)
751 out_buf_[thread][bucket][num] = src_[from_pos];
754 inline void Flush(
int thread,
size_t bucket,
size_t num,
size_t to_pos)
756 for (
size_t i = 0; i < num; ++i)
758 dst_[to_pos++] = out_buf_[thread][bucket][i];
762 void Next() { std::swap(src_, dst_); }
764 ValueType* GetResult() {
return src_; }
770 static constexpr
size_t kOutBufferSize = internal::kOutBufferSize;
771 ValueType *original_, *tmp_;
772 ValueType *src_, *dst_;
773 ValueType*** out_buf_;
779 template <
typename PlainType,
typename ValueType,
int Base>
780 void PairValueManager<PlainType, ValueType, Base>::Init(
size_t max_elems,
size_t available_cores)
784 max_elems_ = max_elems;
785 max_threads_ = utility::GetMaxThreads(max_elems_ *
sizeof(PlainType), available_cores);
788 tmp_size = max_elems *
sizeof(ValueType);
790 "Allocating working memory for radix sort-by-key: %s.",
792 tmp_ =
new ValueType[max_elems];
795 out_buf_ =
new ValueType**[max_threads_];
796 for (
int i = 0; i < max_threads_; ++i)
798 out_buf_[i] =
new ValueType*[1 << Base];
799 for (
size_t j = 0; j < 1 << Base; ++j)
801 out_buf_[i][j] =
new ValueType[kOutBufferSize];
806 template <
typename PlainType,
typename ValueType,
int Base>
807 void PairValueManager<PlainType, ValueType, Base>::DeleteAll()
811 "Freeing working memory for radix sort-by-key: %s.",
818 for (
int i = 0; i < max_threads_; ++i)
820 for (
size_t j = 0; j < 1 << Base; ++j)
822 delete[] out_buf_[i][j];
824 delete[] out_buf_[i];
835 template <
typename ThreaderType,
837 typename CompareType,
838 typename UnsignedType = PlainType,
839 typename Encoder = encoder::EncoderDummy,
840 unsigned int Base = 8>
843 using DummyValueManager = value_manager::DummyValueManager;
844 using Internal = internal::ParallelRadixSortInternal<PlainType,
853 void InitAndSort(PlainType* data,
855 const ThreaderType& threader,
856 const CompareType& comp)
859 DummyValueManager dvm;
860 Internal::InitAndSort(data, num_elems, threader, &dvm);
865 template <
typename ThreaderType,
868 typename CompareType,
869 typename UnsignedType = PlainType,
870 typename Encoder = encoder::EncoderDummy,
874 using ValueManager = value_manager::PairValueManager<PlainType, ValueType, Base>;
875 using Internal = internal::ParallelRadixSortInternal<PlainType,
884 void InitAndSort(PlainType* keys,
887 const ThreaderType& threader,
888 const CompareType& comp)
892 vm.Init(num_elems, threader.GetAvailableCores());
893 vm.Start(vals, num_elems);
894 Internal::InitAndSort(keys, num_elems, threader, &vm);
895 ValueType* res_vals = vm.GetResult();
896 if (res_vals != vals)
898 for (
size_t i = 0; i < num_elems; ++i)
900 vals[i] = res_vals[i];
908 #define KEY_SORT_CASE(plain_type, compare_type, unsigned_type, encoder_type) \
909 template <typename ThreaderType> \
910 class KeySort<ThreaderType, plain_type, compare_type> \
911 : public KeySort<ThreaderType, \
915 encoder::Encoder##encoder_type> \
918 template <typename V, typename ThreaderType> \
919 class PairSort<ThreaderType, plain_type, V, compare_type> \
920 : public PairSort<ThreaderType, \
925 encoder::Encoder##encoder_type> \
930 KEY_SORT_CASE(
unsigned int, std::less<unsigned int>,
unsigned int, Unsigned);
931 KEY_SORT_CASE(
unsigned int, std::greater<unsigned int>,
unsigned int, Unsigned);
932 KEY_SORT_CASE(
unsigned short int, std::less<unsigned short int>,
unsigned short int, Unsigned);
933 KEY_SORT_CASE(
unsigned short int, std::greater<unsigned short int>,
unsigned short int, Unsigned);
934 KEY_SORT_CASE(
unsigned long int, std::less<unsigned long int>,
unsigned long int, Unsigned);
935 KEY_SORT_CASE(
unsigned long int, std::greater<unsigned long int>,
unsigned long int, Unsigned);
937 std::less<unsigned long long int>,
938 unsigned long long int,
941 std::greater<unsigned long long int>,
942 unsigned long long int,
946 KEY_SORT_CASE(
unsigned char, std::less<unsigned char>,
unsigned char, Unsigned);
947 KEY_SORT_CASE(
unsigned char, std::greater<unsigned char>,
unsigned char, Unsigned);
948 KEY_SORT_CASE(char16_t, std::less<char16_t>, uint16_t, Unsigned);
949 KEY_SORT_CASE(char16_t, std::greater<char16_t>, uint16_t, Unsigned);
950 KEY_SORT_CASE(char32_t, std::less<char32_t>, uint32_t, Unsigned);
951 KEY_SORT_CASE(char32_t, std::greater<char32_t>, uint32_t, Unsigned);
952 KEY_SORT_CASE(
wchar_t, std::less<wchar_t>, uint32_t, Unsigned);
953 KEY_SORT_CASE(
wchar_t, std::greater<wchar_t>, uint32_t, Unsigned);
957 KEY_SORT_CASE(
char, std::greater<char>,
unsigned char, Signed);
958 KEY_SORT_CASE(
short, std::less<short>,
unsigned short, Signed);
959 KEY_SORT_CASE(
short, std::greater<short>,
unsigned short, Signed);
963 KEY_SORT_CASE(
long, std::greater<long>,
unsigned long, Signed);
964 KEY_SORT_CASE(
long long, std::less<long long>,
unsigned long long, Signed);
965 KEY_SORT_CASE(
long long, std::greater<long long>,
unsigned long long, Signed);
968 KEY_SORT_CASE(
signed char, std::less<signed char>,
unsigned char, Signed);
969 KEY_SORT_CASE(
signed char, std::greater<signed char>,
unsigned char, Signed);
973 KEY_SORT_CASE(
float, std::greater<float>, uint32_t, Decimal);
975 KEY_SORT_CASE(
double, std::greater<double>, uint64_t, Decimal);
979 template <
typename T,
typename CompareType>
980 struct run_kx_radix_sort_keys
982 static void run(T* data,
size_t num_elems,
const CompareType& comp)
984 std::sort(data, data + num_elems, comp);
988 #define KX_SORT_KEYS(key_type) \
990 struct run_kx_radix_sort_keys<key_type, std::less<key_type>> \
992 static void run(key_type* data, size_t num_elems, const std::less<key_type>& comp) \
995 kx::radix_sort(data, data + num_elems); \
1010 template <
typename T,
typename CompareType>
1011 bool use_serial_sort_keys(T* data,
size_t num_elems,
const CompareType& comp)
1013 size_t total_bytes = (num_elems) *
sizeof(T);
1014 if (total_bytes < MIN_BYTES_FOR_PARALLEL)
1016 run_kx_radix_sort_keys<T, CompareType>::run(data, num_elems, comp);
1023 #define VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(threader_type, key_type) \
1024 VISKORES_CONT_EXPORT void parallel_radix_sort_key_values( \
1025 key_type* keys, viskores::Id* vals, size_t num_elems, const std::greater<key_type>& comp) \
1027 using namespace viskores::cont::internal::radix; \
1028 PairSort<threader_type, key_type, viskores::Id, std::greater<key_type>> ps; \
1029 ps.InitAndSort(keys, vals, num_elems, threader_type(), comp); \
1031 VISKORES_CONT_EXPORT void parallel_radix_sort_key_values( \
1032 key_type* keys, viskores::Id* vals, size_t num_elems, const std::less<key_type>& comp) \
1034 using namespace viskores::cont::internal::radix; \
1035 PairSort<threader_type, key_type, viskores::Id, std::less<key_type>> ps; \
1036 ps.InitAndSort(keys, vals, num_elems, threader_type(), comp); \
1038 VISKORES_CONT_EXPORT void parallel_radix_sort( \
1039 key_type* data, size_t num_elems, const std::greater<key_type>& comp) \
1041 using namespace viskores::cont::internal::radix; \
1042 if (!use_serial_sort_keys(data, num_elems, comp)) \
1044 KeySort<threader_type, key_type, std::greater<key_type>> ks; \
1045 ks.InitAndSort(data, num_elems, threader_type(), comp); \
1048 VISKORES_CONT_EXPORT void parallel_radix_sort( \
1049 key_type* data, size_t num_elems, const std::less<key_type>& comp) \
1051 using namespace viskores::cont::internal::radix; \
1052 if (!use_serial_sort_keys(data, num_elems, comp)) \
1054 KeySort<threader_type, key_type, std::less<key_type>> ks; \
1055 ks.InitAndSort(data, num_elems, threader_type(), comp); \
1059 #define VISKORES_INSTANTIATE_RADIX_SORT_FOR_THREADER(ThreaderType) \
1060 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, short int) \
1061 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned short int) \
1062 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, int) \
1063 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned int) \
1064 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, long int) \
1065 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned long int) \
1066 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, long long int) \
1067 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned long long int) \
1068 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned char) \
1069 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, signed char) \
1070 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char) \
1071 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char16_t) \
1072 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char32_t) \
1073 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, wchar_t) \
1074 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, float) \
1075 VISKORES_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, double)
1083 #endif // viskores_cont_internal_ParallelRadixSort_h