18 #ifndef viskores_cont_tbb_internal_FunctorsTBB_h
19 #define viskores_cont_tbb_internal_FunctorsTBB_h
31 #include <type_traits>
35 #if defined(VISKORES_MSVC)
// Keep TBB's headers from requesting implicit library linkage while they are
// included below. push_macro has to name the very macro that is defined on the
// next line; with a different spelling nothing is saved, the matching
// pop_macro restores nothing, and the definition leaks into every includer.
40 #pragma push_macro("__TBB_NO_IMPLICIT_LINKAGE")
41 #define __TBB_NO_IMPLICIT_LINKAGE 1
43 #endif // defined(VISKORES_MSVC)
51 #include <tbb/blocked_range.h>
52 #include <tbb/blocked_range3d.h>
53 #include <tbb/parallel_for.h>
54 #include <tbb/parallel_reduce.h>
55 #include <tbb/parallel_scan.h>
56 #include <tbb/parallel_sort.h>
57 #include <tbb/partitioner.h>
58 #include <tbb/tick_count.h>
60 #if defined(VISKORES_MSVC)
// Put the macro back into the state the includer had before this header.
61 #pragma pop_macro("__TBB_NO_IMPLICIT_LINKAGE")
// Scope-local alias for viskores::cont::internal::WrappedBinaryOperator, so
// code in this scope can name the wrapper without its full qualification.
75 template <
typename ResultType,
typename Function>
76 using WrappedBinaryOperator = viskores::cont::internal::WrappedBinaryOperator<ResultType, Function>;
// Class template over the input and output portal types. Only parts of it are
// visible in this excerpt: one entry of a parameter list (presumably the
// constructor's), the two DoCopy overloads and the range operator(). The class
// head and member declarations are on lines that are not shown.
83 template <
typename InputPortalType,
typename OutputPortalType>
92 const OutputPortalType& outPortal,
// Converting copy, selected by a std::false_type tag: while src has not
// reached srcEnd, an element is cast to InputType and then to OutputType
// before being assigned through dst.
107 template <
typename InIter,
typename OutIter>
108 void DoCopy(InIter src, InIter srcEnd, OutIter dst, std::false_type)
const
110 using InputType =
typename InputPortalType::ValueType;
111 using OutputType =
typename OutputPortalType::ValueType;
112 while (src != srcEnd)
118 *dst =
static_cast<OutputType
>(
static_cast<InputType
>(*src));
// Same-type copy, selected by a std::true_type tag: forwards to std::copy.
125 template <
typename InIter,
typename OutIter>
126 void DoCopy(InIter src, InIter srcEnd, OutIter dst, std::true_type)
const
128 std::copy(src, srcEnd, dst);
// Copies the elements covered by `range`. The source iterators are shifted by
// InputOffset and the destination iterator by OutputOffset; std::is_same on
// the two value types yields the tag that picks the DoCopy overload.
133 void operator()(const ::tbb::blocked_range<viskores::Id>& range)
const
143 using InputType =
typename InputPortalType::ValueType;
144 using OutputType =
typename OutputPortalType::ValueType;
146 this->
DoCopy(inIter + this->InputOffset + range.begin(),
147 inIter + this->InputOffset + range.end(),
148 outIter + this->OutputOffset + range.begin(),
149 std::is_same<InputType, OutputType>());
// Launcher for a copy kernel: builds a Kernel from the two portals and the two
// offsets, then has ::tbb::parallel_for apply it over [0, numValues), with
// TBB_GRAIN_SIZE passed as the range's grain-size argument. The function name
// and its leading parameters are on lines missing from this excerpt.
153 template <
typename InputPortalType,
typename OutputPortalType>
155 const OutputPortalType& outPortal,
161 Kernel kernel(inPortal, outPortal, inOffset, outOffset);
162 ::tbb::blocked_range<viskores::Id> range(0, numValues, TBB_GRAIN_SIZE);
163 ::tbb::parallel_for(range, kernel);
// Body for a stencil-driven conditional copy, templated on the input, stencil
// and output portal types and on the predicate type. Only fragments are
// visible in this excerpt: the tail of a range-consistency check (the output
// span may not be longer than the input span), part of a parameter list
// (presumably the constructor's), the per-chunk operator(), and the end of a
// merge step (presumably the parallel_reduce join -- its signature is not
// shown).
166 template <
typename InputPortalType,
167 typename StencilPortalType,
168 typename OutputPortalType,
169 typename UnaryPredicateType>
212 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
227 const StencilPortalType& stencilPortal,
228 const OutputPortalType& outputPortal,
229 UnaryPredicateType unaryPredicate)
// Handles one chunk: records range.end() as the end of the consumed input and
// stores, at writePos, each input value whose stencil entry passes the
// predicate.
248 void operator()(const ::tbb::blocked_range<viskores::Id>& range)
265 this->Ranges.
InputEnd = range.end();
272 InputIteratorsType inputIters(this->InputPortal);
273 StencilIteratorsType stencilIters(this->StencilPortal);
274 OutputIteratorsType outputIters(this->OutputPortal);
276 using InputIteratorType =
typename InputIteratorsType::IteratorType;
277 using StencilIteratorType =
typename StencilIteratorsType::IteratorType;
278 using OutputIteratorType =
typename OutputIteratorsType::IteratorType;
280 InputIteratorType inIter = inputIters.GetBegin();
281 StencilIteratorType stencilIter = stencilIters.GetBegin();
282 OutputIteratorType outIter = outputIters.GetBegin();
// NOTE(review): the condition guarding this assignment is not visible; it
// presumably applies only to a body's first chunk -- confirm.
294 writePos = range.begin();
// Local copy of the predicate, used for every element of the chunk.
307 UnaryPredicateType predicate(this->UnaryPredicate);
308 for (; readPos < readEnd; ++readPos)
310 if (predicate(stencilIter[readPos]))
312 outIter[writePos] = inIter[readPos];
// Merge step: works directly on the output portal.
325 using OutputIteratorType =
typename OutputIteratorsType::IteratorType;
327 OutputIteratorsType outputIters(this->OutputPortal);
328 OutputIteratorType outIter = outputIters.GetBegin();
// Move the run [srcBegin, srcEnd) to dstBegin unless it is empty or already
// in place, then extend this body's output range by the run's length.
340 if (srcBegin != dstBegin && srcBegin != srcEnd)
344 std::copy(outIter + srcBegin, outIter + srcEnd, outIter + dstBegin);
348 this->Ranges.
OutputEnd += srcEnd - srcBegin;
// Launcher for the stencil-driven conditional copy. It takes the input length
// from inputPortal, special-cases an empty input (that branch's body is not
// shown), builds the reduction body from the three portals and the predicate,
// and runs ::tbb::parallel_reduce over [0, inputLength) with TBB_GRAIN_SIZE as
// grain size. The value returned is body.Ranges.OutputEnd. The function name
// and first parameter are on lines missing from this excerpt.
354 template <
typename InputPortalType,
355 typename StencilPortalType,
356 typename OutputPortalType,
357 typename UnaryPredicateType>
359 StencilPortalType stencilPortal,
360 OutputPortalType outputPortal,
361 UnaryPredicateType unaryPredicate)
363 const viskores::Id inputLength = inputPortal.GetNumberOfValues();
366 if (inputLength == 0)
372 inputPortal, stencilPortal, outputPortal, unaryPredicate);
373 ::tbb::blocked_range<viskores::Id> range(0, inputLength, TBB_GRAIN_SIZE);
375 ::tbb::parallel_reduce(range, body);
// Tail of a post-condition check: output starts at 0 and is no longer than
// the input.
379 body.
Ranges.OutputBegin == 0 && body.
Ranges.OutputEnd <= inputLength);
381 return body.
Ranges.OutputEnd;
// Reduction body templated on the input portal type, a value type T and the
// binary operation type (presumably the ReduceBody instantiated by
// ReducePortals). Only the last entry of a parameter list and parts of
// operator() are visible in this excerpt.
384 template <
class InputPortalType,
class T,
class BinaryOperationType>
396 BinaryOperationType binaryOperation)
// Processes one chunk, starting from an input iterator positioned at
// range.begin().
417 void operator()(const ::tbb::blocked_range<viskores::Id>& range)
420 InputIteratorsType inputIterators(this->InputPortal);
423 typename InputIteratorsType::IteratorType inIter =
424 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
// NOTE(review): the loop starts at range.begin() + 2, so the first two
// elements are presumably combined on lines not shown here -- confirm that a
// chunk always holds at least two elements.
429 for (
viskores::Id index = range.begin() + 2; index != range.end(); ++index, ++inIter)
// Special case for the chunk that starts at index 0; the guarded statement is
// not visible (presumably where the initial value is folded in -- TODO
// confirm).
435 if (range.begin() == 0)
// Record that this body has now processed a chunk.
452 this->FirstCall =
false;
// Reduces the values of inputPortal with binaryOperation, seeded with
// initialValue.
//
// The return type is whatever binaryOperation(initialValue,
// inputPortal.Get(0)) yields. The operation is wrapped in
// WrappedBinaryOperator<ResultType, ...> and handed to a ReduceBody, which
// ::tbb::parallel_reduce runs over [0, arrayLength) with TBB_GRAIN_SIZE as
// grain size.
//
// Short inputs bypass the parallel result: a one-element portal returns
// binaryOperation(initialValue, inputPortal.Get(0)) directly, and the last
// return hands back initialValue converted to ResultType (presumably the
// empty-portal case). The condition of the first branch and the return that
// follows the parallel reduction are on lines missing from this excerpt.
465 template <
class InputPortalType,
typename T,
class BinaryOperationType>
466 VISKORES_CONT static auto ReducePortals(InputPortalType inputPortal,
468 BinaryOperationType binaryOperation)
469 -> decltype(binaryOperation(initialValue, inputPortal.Get(0)))
471 using ResultType = decltype(binaryOperation(initialValue, inputPortal.Get(0)));
472 using WrappedBinaryOp = internal::WrappedBinaryOperator<ResultType, BinaryOperationType>;
474 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
475 ReduceBody<InputPortalType, ResultType, WrappedBinaryOp> body(
476 inputPortal, initialValue, wrappedBinaryOp);
477 viskores::Id arrayLength = inputPortal.GetNumberOfValues();
481 ::tbb::blocked_range<viskores::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
482 ::tbb::parallel_reduce(range, body);
485 else if (arrayLength == 1)
488 return binaryOperation(initialValue, inputPortal.Get(0));
493 return static_cast<ResultType
>(initialValue);
// Body for a key-wise reduction: a run of equal consecutive keys is collapsed
// into one output key whose value is the fold of the matching input values.
// KeyType and ValueType come from the two input portals. Only fragments are
// visible in this excerpt: the tail of a range-consistency check, part of a
// parameter list (presumably the constructor's), the per-chunk operator() and
// a merge step (presumably the parallel_reduce join -- its signature is not
// shown). Code under VISKORES_DEBUG_TBB_RBK collects and prints timings.
501 template <
typename KeysInPortalType,
502 typename ValuesInPortalType,
503 typename KeysOutPortalType,
504 typename ValuesOutPortalType,
505 class BinaryOperationType>
508 using KeyType =
typename KeysInPortalType::ValueType;
509 using ValueType =
typename ValuesInPortalType::ValueType;
548 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
561 #ifdef VISKORES_DEBUG_TBB_RBK
568 const ValuesInPortalType& valuesInPortal,
569 const KeysOutPortalType& keysOutPortal,
570 const ValuesOutPortalType& valuesOutPortal,
571 BinaryOperationType binaryOperation)
577 #ifdef VISKORES_DEBUG_TBB_RBK
591 #ifdef VISKORES_DEBUG_TBB_RBK
// Reduces one chunk of key/value pairs and records range.end() as the end of
// the consumed input.
600 void operator()(const ::tbb::blocked_range<viskores::Id>& range)
602 #ifdef VISKORES_DEBUG_TBB_RBK
603 ::tbb::tick_count startTime = ::tbb::tick_count::now();
604 #endif // VISKORES_DEBUG_TBB_RBK
620 this->Ranges.
InputEnd = range.end();
628 KeysInIteratorsType keysInIters(this->KeysInPortal);
629 ValuesInIteratorsType valuesInIters(this->ValuesInPortal);
630 KeysOutIteratorsType keysOutIters(this->KeysOutPortal);
631 ValuesOutIteratorsType valuesOutIters(this->ValuesOutPortal);
633 using KeysInIteratorType =
typename KeysInIteratorsType::IteratorType;
634 using ValuesInIteratorType =
typename ValuesInIteratorsType::IteratorType;
635 using KeysOutIteratorType =
typename KeysOutIteratorsType::IteratorType;
636 using ValuesOutIteratorType =
typename ValuesOutIteratorsType::IteratorType;
638 KeysInIteratorType keysIn = keysInIters.GetBegin();
639 ValuesInIteratorType valuesIn = valuesInIters.GetBegin();
640 KeysOutIteratorType keysOut = keysOutIters.GetBegin();
641 ValuesOutIteratorType valuesOut = valuesOutIters.GetBegin();
// NOTE(review): the condition guarding this assignment is not visible; it
// presumably applies only to a body's first chunk -- confirm.
653 writePos = range.begin();
// Start the accumulator from the first pair of the chunk.
667 BinaryOperationType functor(this->BinaryOperation);
668 KeyType currentKey = keysIn[readPos];
669 ValueType currentValue = valuesIn[readPos];
// A chunk that is not the first run may continue the key stored last by the
// previous one; the stored partial value is then folded into the accumulator.
// NOTE(review): the key is read at writePos - 1 but the value at writePos, so
// writePos is presumably decremented on a line missing here -- confirm.
675 if (!firstRun && keysOut[writePos - 1] == currentKey)
681 currentValue = functor(valuesOut[writePos], currentValue);
// Input exhausted: store the pending key/value pair.
685 if (readPos >= readEnd)
687 keysOut[writePos] = currentKey;
688 valuesOut[writePos] = currentValue;
// Fold the following values that share currentKey into the accumulator, then
// store the finished pair.
697 while (readPos < readEnd && currentKey == keysIn[readPos])
699 currentValue = functor(currentValue, valuesIn[readPos]);
704 keysOut[writePos] = currentKey;
705 valuesOut[writePos] = currentValue;
// More input left: restart the accumulator with the next pair.
708 if (readPos < readEnd)
710 currentKey = keysIn[readPos];
711 currentValue = valuesIn[readPos];
721 #ifdef VISKORES_DEBUG_TBB_RBK
722 ::tbb::tick_count endTime = ::tbb::tick_count::now();
723 double time = (endTime - startTime).seconds();
724 this->ReduceTime += time;
725 std::ostringstream out;
726 out <<
"Reduced " << range.size() <<
" key/value pairs in " << time <<
"s. "
729 std::cerr << out.str();
// Merge step: operates on the two output portals only.
739 using KeysIteratorType =
typename KeysIteratorsType::IteratorType;
740 using ValuesIteratorType =
typename ValuesIteratorsType::IteratorType;
742 #ifdef VISKORES_DEBUG_TBB_RBK
743 ::tbb::tick_count startTime = ::tbb::tick_count::now();
751 KeysIteratorsType keysIters(this->KeysOutPortal);
752 ValuesIteratorsType valuesIters(this->ValuesOutPortal);
753 KeysIteratorType keys = keysIters.GetBegin();
754 ValuesIteratorType values = valuesIters.GetBegin();
// If the run being merged starts with the key stored at lastDstIdx, combine
// the two values in place at lastDstIdx.
763 if (keys[srcBegin] == keys[lastDstIdx])
765 values[lastDstIdx] = this->
BinaryOperation(values[lastDstIdx], values[srcBegin]);
// Move the run [srcBegin, srcEnd) of keys and values to dstBegin unless it is
// empty or already in place, then extend the output range by its length.
770 if (srcBegin != dstBegin && srcBegin != srcEnd)
774 std::copy(keys + srcBegin, keys + srcEnd, keys + dstBegin);
775 std::copy(values + srcBegin, values + srcEnd, values + dstBegin);
779 this->Ranges.
OutputEnd += srcEnd - srcBegin;
782 #ifdef VISKORES_DEBUG_TBB_RBK
783 ::tbb::tick_count endTime = ::tbb::tick_count::now();
784 double time = (endTime - startTime).seconds();
785 this->JoinTime += rhs.JoinTime + time;
786 std::ostringstream out;
787 out <<
"Joined " << (srcEnd - srcBegin) <<
" rhs values into body in " << time <<
"s. "
788 <<
"InRange: " << this->Ranges.
InputBegin <<
" " << this->Ranges.InputEnd <<
" "
789 <<
"OutRange: " << this->Ranges.OutputBegin <<
" " << this->Ranges.OutputEnd <<
"\n";
790 std::cerr << out.str();
// Launcher for the key-wise reduction. It takes the input length from
// keysInPortal, special-cases an empty input (that branch's body is not
// shown), wraps binaryOperation in WrappedBinaryOperator<ValueType, ...>,
// builds the body from the four portals and the wrapped operation, and runs
// ::tbb::parallel_reduce over [0, inputLength) with TBB_GRAIN_SIZE as grain
// size. The value returned is body.Ranges.OutputEnd. With
// VISKORES_DEBUG_TBB_RBK defined, the accumulated reduce and join times are
// printed to std::cerr. The function name and first parameter are on lines
// missing from this excerpt.
796 template <
typename KeysInPortalType,
797 typename ValuesInPortalType,
798 typename KeysOutPortalType,
799 typename ValuesOutPortalType,
800 typename BinaryOperationType>
802 ValuesInPortalType valuesInPortal,
803 KeysOutPortalType keysOutPortal,
804 ValuesOutPortalType valuesOutPortal,
805 BinaryOperationType binaryOperation)
807 const viskores::Id inputLength = keysInPortal.GetNumberOfValues();
810 if (inputLength == 0)
815 using ValueType =
typename ValuesInPortalType::ValueType;
816 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
817 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
824 body(keysInPortal, valuesInPortal, keysOutPortal, valuesOutPortal, wrappedBinaryOp);
825 ::tbb::blocked_range<viskores::Id> range(0, inputLength, TBB_GRAIN_SIZE);
827 #ifdef VISKORES_DEBUG_TBB_RBK
828 std::cerr <<
"\n\nTBB ReduceByKey:\n";
831 ::tbb::parallel_reduce(range, body);
833 #ifdef VISKORES_DEBUG_TBB_RBK
834 std::cerr <<
"Total reduce time: " << body.ReduceTime <<
"s\n";
835 std::cerr <<
"Total join time: " << body.JoinTime <<
"s\n";
836 std::cerr <<
"\nend\n";
// Post-conditions: the whole input [0, inputLength) was consumed, and the
// output starts at 0 and is no longer than the input.
839 body.Ranges.AssertSane();
840 VISKORES_ASSERT(body.Ranges.InputBegin == 0 && body.Ranges.InputEnd == inputLength &&
841 body.Ranges.OutputBegin == 0 && body.Ranges.OutputEnd <= inputLength);
843 return body.Ranges.OutputEnd;
846 #ifdef VISKORES_DEBUG_TBB_RBK
847 #undef VISKORES_DEBUG_TBB_RBK
// Scan body with the two call operators that ::tbb::parallel_scan
// distinguishes by tag (presumably the ScanInclusiveBody instantiated by
// ScanInclusivePortals). ValueType is the output portal's value type with any
// reference removed. Only part of a parameter list and fragments of the two
// operators are visible in this excerpt; the statements that accumulate and
// store values are not shown.
850 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
853 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
862 const OutputPortalType& outputPortal,
863 BinaryOperationType binaryOperation)
// Pre-scan pass: reads the input only, starting at range.begin().
884 void operator()(const ::tbb::blocked_range<viskores::Id>& range, ::tbb::pre_scan_tag)
887 InputIteratorsType inputIterators(this->InputPortal);
890 typename InputIteratorsType::IteratorType inIter =
891 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
893 this->FirstCall =
false;
// NOTE(review): the loop starts at range.begin() + 1; the first element is
// presumably consumed on a line not shown here -- confirm.
894 for (
viskores::Id index = range.begin() + 1; index != range.end(); ++index, ++inIter)
// Final-scan pass: walks input and output side by side from range.begin().
903 void operator()(const ::tbb::blocked_range<viskores::Id>& range, ::tbb::final_scan_tag)
908 InputIteratorsType inputIterators(this->InputPortal);
909 OutputIteratorsType outputIterators(this->OutputPortal);
912 typename InputIteratorsType::IteratorType inIter =
913 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
914 typename OutputIteratorsType::IteratorType outIter =
915 outputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
917 this->FirstCall =
false;
919 for (
viskores::Id index = range.begin() + 1; index != range.end(); ++index, ++inIter, ++outIter)
// Scan body with the two call operators that ::tbb::parallel_scan
// distinguishes by tag (presumably the ScanExclusiveBody instantiated by
// ScanExclusivePortals). ValueType is the output portal's value type with any
// reference removed. Only part of a parameter list, fragments of the two
// operators and one condition from a later member are visible in this
// excerpt.
938 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
941 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
951 const OutputPortalType& outputPortal,
952 BinaryOperationType binaryOperation,
// Pre-scan pass: reads the input only, starting at range.begin().
974 void operator()(const ::tbb::blocked_range<viskores::Id>& range, ::tbb::pre_scan_tag)
977 InputIteratorsType inputIterators(this->InputPortal);
980 typename InputIteratorsType::IteratorType iter =
981 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
// The guarded statement (not shown) is skipped only when this body has not
// been called yet and the chunk does not start at index 0.
985 if (!(this->FirstCall && range.begin() > 0))
989 for (
viskores::Id index = range.begin() + 1; index != range.end(); ++index, ++iter)
994 this->FirstCall =
false;
// Final-scan pass: unlike the pre-scan loop, this one visits every index of
// the chunk, starting at range.begin() itself.
999 void operator()(const ::tbb::blocked_range<viskores::Id>& range, ::tbb::final_scan_tag)
1004 InputIteratorsType inputIterators(this->InputPortal);
1005 OutputIteratorsType outputIterators(this->OutputPortal);
1008 typename InputIteratorsType::IteratorType inIter =
1009 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
1010 typename OutputIteratorsType::IteratorType outIter =
1011 outputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
1014 for (
viskores::Id index = range.begin(); index != range.end(); ++index, ++inIter, ++outIter)
1023 this->FirstCall =
false;
// Condition from a member that receives another body named `left`: the
// guarded statement (not shown) runs only when both bodies have already
// processed a chunk.
1034 if (!left.
FirstCall && !this->FirstCall)
// Scans inputPortal into outputPortal (an inclusive scan, going by the names)
// with ::tbb::parallel_scan.
//
// binaryOperation is wrapped in WrappedBinaryOperator<ValueType, ...>, where
// ValueType is the output portal's value type with any reference removed. A
// ScanInclusiveBody built from the two portals and the wrapped operation is
// run over [0, arrayLength) with TBB_GRAIN_SIZE as grain size.
//
// The declared return type is that same ValueType; the return statement is on
// a line missing from this excerpt.
1046 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
1047 VISKORES_CONT static typename std::remove_reference<typename OutputPortalType::ValueType>::type
1048 ScanInclusivePortals(InputPortalType inputPortal,
1049 OutputPortalType outputPortal,
1050 BinaryOperationType binaryOperation)
1052 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
1054 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
1056 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1057 ScanInclusiveBody<InputPortalType, OutputPortalType, WrappedBinaryOp> body(
1058 inputPortal, outputPortal, wrappedBinaryOp);
1059 viskores::Id arrayLength = inputPortal.GetNumberOfValues();
1061 ::tbb::blocked_range<viskores::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
1062 ::tbb::parallel_scan(range, body);
// Scans inputPortal into outputPortal (an exclusive scan, going by the names)
// with ::tbb::parallel_scan, seeded with initialValue.
//
// binaryOperation is wrapped in WrappedBinaryOperator<ValueType, ...>, where
// ValueType is the output portal's value type with any reference removed. A
// ScanExclusiveBody built from the two portals, the wrapped operation and
// initialValue is run over [0, arrayLength) with TBB_GRAIN_SIZE as grain size.
//
// The declared return type is that same ValueType; the return statement is on
// a line missing from this excerpt.
1067 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
1068 VISKORES_CONT static typename std::remove_reference<typename OutputPortalType::ValueType>::type
1069 ScanExclusivePortals(
1070 InputPortalType inputPortal,
1071 OutputPortalType outputPortal,
1072 BinaryOperationType binaryOperation,
1073 typename std::remove_reference<typename OutputPortalType::ValueType>::type initialValue)
1075 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
1077 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
1079 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1080 ScanExclusiveBody<InputPortalType, OutputPortalType, WrappedBinaryOp> body(
1081 inputPortal, outputPortal, wrappedBinaryOp, initialValue);
1082 viskores::Id arrayLength = inputPortal.GetNumberOfValues();
1084 ::tbb::blocked_range<viskores::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
1085 ::tbb::parallel_scan(range, body);
// Kernel templated on the input, index and output portal types. Visible in
// this excerpt: the end of a parameter list (presumably the constructor's), a
// const operator() that loops over every index i in
// [range.begin(), range.end()), and an error report raised through
// ErrorMessage. The loop body is on lines that are not shown.
1092 template <
typename InputPortalType,
typename IndexPortalType,
typename OutputPortalType>
1097 IndexPortalType indexPortal,
1098 OutputPortalType outputPortal)
1106 void operator()(const ::tbb::blocked_range<viskores::Id>& range)
const
1117 for (
viskores::Id i = range.begin(); i < range.end(); i++)
// NOTE(review): the context of this call is not visible; it presumably sits
// in a handler around the loop -- confirm.
1129 this->
ErrorMessage.RaiseError(
"Unexpected error in execution environment.");
// Runs a scatter kernel built from inputPortal, indexPortal and outputPortal
// with ::tbb::parallel_for over [0, size), where size is the number of values
// in inputPortal and TBB_GRAIN_SIZE is the grain size. The declaration of the
// `scatter` object is partly on a line missing from this excerpt.
1141 template <
typename InputPortalType,
typename IndexPortalType,
typename OutputPortalType>
1142 VISKORES_CONT static void ScatterPortal(InputPortalType inputPortal,
1143 IndexPortalType indexPortal,
1144 OutputPortalType outputPortal)
1146 const viskores::Id size = inputPortal.GetNumberOfValues();
1150 inputPortal, indexPortal, outputPortal);
1152 ::tbb::blocked_range<viskores::Id> range(0, size, TBB_GRAIN_SIZE);
1153 ::tbb::parallel_for(range, scatter);
// Body that compacts a single portal in place: entries that the binary
// operation reports as matching the current value are skipped, and one
// representative is written back at writePos. Only fragments are visible in
// this excerpt: tails of two range-consistency checks, the constructor
// signature, parts of the per-chunk pass and a merge step (presumably the
// parallel_reduce join -- its signature is not shown).
1156 template <
typename PortalType,
typename BinaryOperationType>
1199 this->OutputEnd <= this->
InputEnd);
1201 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
1214 UniqueBody(
const PortalType& portal, BinaryOperationType binaryOperation)
// Per-chunk pass: records range.end() as the end of the consumed input.
1246 this->Ranges.
InputEnd = range.end();
1250 using IteratorType =
typename IteratorsType::IteratorType;
1252 IteratorsType iters(this->Portal);
1253 IteratorType data = iters.GetBegin();
// NOTE(review): the condition guarding this assignment is not visible; it
// presumably applies only to a body's first chunk -- confirm.
1265 writePos = range.begin();
1279 BinaryOperationType functor(this->BinaryOperation);
// A chunk that is not the first run may start with a value matching the one
// stored last; the stored value then becomes `current`.
// NOTE(review): the test reads data[writePos - 1] but the reload reads
// data[writePos], so writePos is presumably decremented on a line missing
// here -- confirm.
1287 if (!firstRun && functor(data[writePos - 1], current))
1295 current = data[writePos];
// Input exhausted: store the pending value.
1299 if (readPos >= readEnd)
1301 data[writePos] = current;
// Skip the following entries that match `current`, then store it.
1311 while (readPos < readEnd && functor(current, data[readPos]))
1318 data[writePos] = current;
// More input left: continue with the next value.
1322 if (readPos < readEnd)
1324 current = data[readPos];
// Merge step: operates on the same portal.
1341 using IteratorType =
typename IteratorsType::IteratorType;
1349 IteratorsType iters(this->Portal);
1350 IteratorType data = iters.GetBegin();
1351 BinaryOperationType functor(this->BinaryOperation);
// Checks whether the first value of the run being merged matches the value
// stored at lastDstIdx; the guarded statement is not shown.
1360 if (functor(data[srcBegin], data[lastDstIdx]))
// Move the run [srcBegin, srcEnd) to dstBegin unless it is empty or already
// in place, then extend the output range by its length.
1366 if (srcBegin != dstBegin && srcBegin != srcEnd)
1370 std::copy(data + srcBegin, data + srcEnd, data + dstBegin);
1374 this->Ranges.
OutputEnd += srcEnd - srcBegin;
// Launcher for the in-place compaction of `portal`. It special-cases an empty
// portal (that branch's body is not shown), wraps binaryOperation in
// WrappedBinaryOperator<bool, ...>, and runs ::tbb::parallel_reduce over
// [0, inputLength) with TBB_GRAIN_SIZE as grain size. The value returned is
// body.Ranges.OutputEnd. The function signature and the construction of
// `body` are on lines missing from this excerpt.
1380 template <
typename PortalType,
typename BinaryOperationType>
1383 const viskores::Id inputLength = portal.GetNumberOfValues();
1384 if (inputLength == 0)
1389 using WrappedBinaryOp = internal::WrappedBinaryOperator<bool, BinaryOperationType>;
1390 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1393 ::tbb::blocked_range<viskores::Id> range(0, inputLength, TBB_GRAIN_SIZE);
1395 ::tbb::parallel_reduce(range, body);
// Post-conditions: the ranges are self-consistent, and the output starts at 0
// and is no longer than the input.
1397 body.
Ranges.AssertSane();
1399 body.
Ranges.OutputBegin == 0 && body.
Ranges.OutputEnd <= inputLength);
1401 return body.
Ranges.OutputEnd;
1406 #endif //viskores_cont_tbb_internal_FunctorsTBB_h