19 #ifndef viskores_cont_internal_DeviceAdapterAlgorithmGeneral_h
20 #define viskores_cont_internal_DeviceAdapterAlgorithmGeneral_h
41 #include <type_traits>
108 template <
class DerivedAlgorithm,
class DeviceAdapterTag>
109 struct DeviceAdapterAlgorithmGeneral
118 template <
typename T,
class CIn>
130 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
131 inputPortal, outputPortal, index);
133 DerivedAlgorithm::Schedule(kernel, 1);
142 template <
typename IndicesStorage>
154 auto indicesPortal = indices.
PrepareForOutput(numBits, DeviceAdapterTag{}, token);
156 std::atomic<viskores::UInt64> popCount;
157 popCount.store(0, std::memory_order_seq_cst);
159 using Functor = BitFieldToUnorderedSetFunctor<decltype(bitsPortal), decltype(indicesPortal)>;
160 Functor functor{ bitsPortal, indicesPortal, popCount };
162 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
163 DerivedAlgorithm::Synchronize();
167 numBits =
static_cast<viskores::Id>(popCount.load(std::memory_order_seq_cst));
175 template <
typename T,
typename U,
class CIn,
class COut>
185 auto outputPortal = output.
PrepareForOutput(inSize, DeviceAdapterTag(), token);
187 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(inputPortal, outputPortal);
188 DerivedAlgorithm::Schedule(kernel, inSize);
193 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut,
class UnaryPredicate>
197 UnaryPredicate unary_predicate)
204 using IndexArrayType =
206 IndexArrayType indices;
211 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
212 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
214 StencilToIndexFlagKernel<decltype(stencilPortal), decltype(indexPortal), UnaryPredicate>
215 indexKernel(stencilPortal, indexPortal, unary_predicate);
217 DerivedAlgorithm::Schedule(indexKernel, arrayLength);
220 viskores::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices);
226 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
227 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
228 auto outputPortal = output.
PrepareForOutput(outArrayLength, DeviceAdapterTag(), token);
230 CopyIfKernel<decltype(inputPortal),
231 decltype(stencilPortal),
232 decltype(indexPortal),
233 decltype(outputPortal),
235 copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate);
236 DerivedAlgorithm::Schedule(copyKernel, arrayLength);
240 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut>
248 DerivedAlgorithm::CopyIf(input, stencil, output, unary_predicate);
253 template <
typename T,
typename U,
class CIn,
class COut>
265 if (input == output &&
266 ((outputIndex >= inputStartIndex &&
267 outputIndex < inputStartIndex + numberOfElementsToCopy) ||
268 (inputStartIndex >= outputIndex &&
269 inputStartIndex < outputIndex + numberOfElementsToCopy)))
274 if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 ||
275 inputStartIndex >= inSize)
281 if (inSize < (inputStartIndex + numberOfElementsToCopy))
283 numberOfElementsToCopy = (inSize - inputStartIndex);
287 const viskores::Id copyOutEnd = outputIndex + numberOfElementsToCopy;
288 if (outSize < copyOutEnd)
299 DerivedAlgorithm::CopySubRange(output, 0, outSize, temp);
309 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
310 inputPortal, outputPortal, inputStartIndex, outputIndex);
311 DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy);
325 std::atomic<viskores::UInt64> popCount;
326 popCount.store(0, std::memory_order_relaxed);
328 using Functor = CountSetBitsFunctor<decltype(bitsPortal)>;
329 Functor functor{ bitsPortal, popCount };
331 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
332 DerivedAlgorithm::Synchronize();
334 return static_cast<viskores::Id>(popCount.load(std::memory_order_seq_cst));
353 using WordType =
typename viskores::cont::BitField::template ExecutionTypes<
354 DeviceAdapterTag>::WordTypePreferred;
356 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
357 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
359 const viskores::Id numWords = portal.template GetNumberOfWords<WordType>();
360 DerivedAlgorithm::Schedule(functor, numWords);
379 using WordType =
typename viskores::cont::BitField::template ExecutionTypes<
380 DeviceAdapterTag>::WordTypePreferred;
382 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
383 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
385 const viskores::Id numWords = portal.template GetNumberOfWords<WordType>();
386 DerivedAlgorithm::Schedule(functor, numWords);
391 template <
typename WordType>
397 "Invalid word type.");
414 auto repWord = RepeatTo32BitsIfNeeded(word);
415 using RepWordType = decltype(repWord);
417 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
418 Functor functor{ portal, repWord };
420 const viskores::Id numWords = portal.template GetNumberOfWords<RepWordType>();
421 DerivedAlgorithm::Schedule(functor, numWords);
426 template <
typename WordType>
430 "Invalid word type.");
446 auto repWord = RepeatTo32BitsIfNeeded(word);
447 using RepWordType = decltype(repWord);
449 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
450 Functor functor{ portal, repWord };
452 const viskores::Id numWords = portal.template GetNumberOfWords<RepWordType>();
453 DerivedAlgorithm::Schedule(functor, numWords);
458 template <
typename T,
typename S>
471 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
472 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
473 DerivedAlgorithm::Schedule(functor, numValues);
478 template <
typename T,
typename S>
492 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
493 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
494 DerivedAlgorithm::Schedule(functor, numValues);
499 template <
typename T,
class CIn,
class CVal,
class COut>
512 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
514 LowerBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
515 inputPortal, valuesPortal, outputPortal);
517 DerivedAlgorithm::Schedule(kernel, arraySize);
520 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
524 BinaryCompare binary_compare)
534 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
536 LowerBoundsComparisonKernel<decltype(inputPortal),
537 decltype(valuesPortal),
538 decltype(outputPortal),
540 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
542 DerivedAlgorithm::Schedule(kernel, arraySize);
545 template <
class CIn,
class COut>
552 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::LowerBounds(
553 input, values_output, values_output);
558 #ifndef VISKORES_CUDA
562 template <
typename T,
typename BinaryFunctor>
563 class ReduceDecoratorImpl
569 ReduceDecoratorImpl(
const T& initialValue,
const BinaryFunctor& binaryFunctor)
570 : InitialValue(initialValue)
571 , ReduceOperator(binaryFunctor)
575 template <
typename Portal>
576 VISKORES_CONT ReduceKernel<Portal, T, BinaryFunctor> CreateFunctor(
const Portal& portal)
const
578 return ReduceKernel<Portal, T, BinaryFunctor>(
579 portal, this->InitialValue, this->ReduceOperator);
584 BinaryFunctor ReduceOperator;
588 template <
typename T,
typename U,
class CIn>
593 return DerivedAlgorithm::Reduce(input, initialValue,
viskores::Add());
596 template <
typename T,
typename U,
class CIn,
class BinaryFunctor>
599 BinaryFunctor binary_functor)
612 length, ReduceDecoratorImpl<U, BinaryFunctor>(initialValue, binary_functor), input);
616 DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor);
622 template <
typename T,
633 BinaryFunctor binary_functor)
642 if (numberOfKeys <= 1)
644 DerivedAlgorithm::Copy(keys, keys_output);
645 DerivedAlgorithm::Copy(values, values_output);
657 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
658 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
659 inputPortal, keyStatePortal);
660 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
676 DerivedAlgorithm::ScanInclusive(
677 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
685 DerivedAlgorithm::CopyIf(reducedValues, stencil, values_output, ReduceByKeyUnaryStencilOp());
694 DerivedAlgorithm::Copy(keys, keys_output);
695 DerivedAlgorithm::Unique(keys_output);
701 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
704 BinaryFunctor binaryFunctor,
705 const T& initialValue)
717 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
721 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
722 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
724 InclusiveToExclusiveKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
725 inclusiveToExclusive(inputPortal, outputPortal, binaryFunctor, initialValue);
727 DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues);
729 return binaryFunctor(initialValue, result);
732 template <
typename T,
class CIn,
class COut>
738 return DerivedAlgorithm::ScanExclusive(
744 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
747 BinaryFunctor binaryFunctor,
748 const T& initialValue)
761 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
765 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
766 auto outputPortal = output.
PrepareForOutput(numValues + 1, DeviceAdapterTag(), token);
768 InclusiveToExtendedKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
769 inclusiveToExtended(inputPortal,
773 binaryFunctor(initialValue, result));
775 DerivedAlgorithm::Schedule(inclusiveToExtended, numValues + 1);
778 template <
typename T,
class CIn,
class COut>
784 DerivedAlgorithm::ScanExtended(
790 template <
typename KeyT,
800 const ValueT& initialValue,
801 BinaryFunctor binaryFunctor)
810 if (numberOfKeys == 0)
814 else if (numberOfKeys == 1)
830 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
831 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
832 inputPortal, keyStatePortal);
833 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
841 auto keyStatePortal = keystate.
PrepareForInput(DeviceAdapterTag(), token);
842 auto tempPortal = temp.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
844 ShiftCopyAndInit<ValueT,
845 decltype(inputPortal),
846 decltype(keyStatePortal),
847 decltype(tempPortal)>
848 kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
849 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
852 DerivedAlgorithm::ScanInclusiveByKey(keys, temp, output, binaryFunctor);
855 template <
typename KeyT,
typename ValueT,
class KIn,
typename VIn,
typename VOut>
863 DerivedAlgorithm::ScanExclusiveByKey(
869 template <
typename T,
class CIn,
class COut>
875 return DerivedAlgorithm::ScanInclusive(input, output,
viskores::Add());
879 template <
typename T1,
typename S1,
typename T2,
typename S2>
886 template <
typename T,
typename S>
894 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
897 BinaryFunctor binary_functor)
901 if (!ArrayHandlesAreSame(input, output))
903 DerivedAlgorithm::Copy(input, output);
916 using ScanKernelType = ScanKernel<decltype(portal), BinaryFunctor>;
920 for (stride = 2; stride - 1 < numValues; stride *= 2)
922 ScanKernelType kernel(portal, binary_functor, stride, stride / 2 - 1);
923 DerivedAlgorithm::Schedule(kernel, numValues / stride);
927 for (stride /= 2; stride > 1; stride /= 2)
929 ScanKernelType kernel(portal, binary_functor, stride, stride - 1);
930 DerivedAlgorithm::Schedule(kernel, numValues / stride);
934 return GetExecutionValue(output, numValues - 1);
937 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut>
945 return DerivedAlgorithm::ScanInclusiveByKey(keys, values, values_output,
viskores::Add());
948 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut,
class BinaryFunctor>
953 BinaryFunctor binary_functor)
960 if (numberOfKeys <= 1)
962 DerivedAlgorithm::Copy(values, values_output);
974 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
975 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
976 inputPortal, keyStatePortal);
977 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
992 DerivedAlgorithm::ScanInclusive(
993 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
996 DerivedAlgorithm::Copy(reducedValues, values_output);
1002 template <
typename T,
class Storage,
class BinaryCompare>
1004 BinaryCompare binary_compare)
1014 while (numThreads < numValues)
1023 using MergeKernel = BitonicSortMergeKernel<decltype(portal), BinaryCompare>;
1024 using CrossoverKernel = BitonicSortCrossoverKernel<decltype(portal), BinaryCompare>;
1026 for (
viskores::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2)
1028 DerivedAlgorithm::Schedule(CrossoverKernel(portal, binary_compare, crossoverSize),
1030 for (
viskores::Id mergeSize = crossoverSize / 2; mergeSize > 0; mergeSize /= 2)
1032 DerivedAlgorithm::Schedule(MergeKernel(portal, binary_compare, mergeSize), numThreads);
1037 template <
typename T,
class Storage>
1042 DerivedAlgorithm::Sort(values, DefaultCompareFunctor());
1047 template <
typename T,
typename U,
class StorageT,
class StorageU>
1057 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>());
1060 template <
typename T,
typename U,
class StorageT,
class StorageU,
class BinaryCompare>
1063 BinaryCompare binary_compare)
1072 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
1075 template <
typename T,
1081 typename BinaryFunctor>
1085 BinaryFunctor binaryFunctor)
1097 auto input1Portal = input1.
PrepareForInput(DeviceAdapterTag(), token);
1098 auto input2Portal = input2.
PrepareForInput(DeviceAdapterTag(), token);
1099 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
1101 BinaryTransformKernel<decltype(input1Portal),
1102 decltype(input2Portal),
1103 decltype(outputPortal),
1105 binaryKernel(input1Portal, input2Portal, outputPortal, binaryFunctor);
1106 DerivedAlgorithm::Schedule(binaryKernel, numValues);
1112 template <
typename T,
class Storage>
1120 template <
typename T,
class Storage,
class BinaryCompare>
1122 BinaryCompare binary_compare)
1129 using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>;
1130 WrappedBOpType wrappedCompare(binary_compare);
1134 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1135 auto stencilPortal = stencilArray.
PrepareForOutput(inputSize, DeviceAdapterTag(), token);
1136 ClassifyUniqueComparisonKernel<decltype(valuesPortal),
1137 decltype(stencilPortal),
1139 classifyKernel(valuesPortal, stencilPortal, wrappedCompare);
1141 DerivedAlgorithm::Schedule(classifyKernel, inputSize);
1146 DerivedAlgorithm::CopyIf(values, stencilArray, outputArray);
1149 DerivedAlgorithm::Copy(outputArray, values);
1154 template <
typename T,
class CIn,
class CVal,
class COut>
1166 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1167 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1169 UpperBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
1170 inputPortal, valuesPortal, outputPortal);
1171 DerivedAlgorithm::Schedule(kernel, arraySize);
1174 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
1178 BinaryCompare binary_compare)
1187 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1188 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1190 UpperBoundsKernelComparisonKernel<decltype(inputPortal),
1191 decltype(valuesPortal),
1192 decltype(outputPortal),
1194 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
1196 DerivedAlgorithm::Schedule(kernel, arraySize);
1199 template <
class CIn,
class COut>
1206 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::UpperBounds(
1207 input, values_output, values_output);
1223 template <
typename DeviceTag>
1224 class DeviceTaskTypes
1227 template <
typename WorkletType,
typename InvocationType>
1228 static viskores::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1229 WorkletType& worklet,
1230 InvocationType& invocation,
1234 using Task = viskores::exec::internal::TaskSingular<WorkletType, InvocationType>;
1235 return Task(worklet, invocation, globalIndexOffset);
1238 template <
typename WorkletType,
typename InvocationType>
1239 static viskores::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1240 WorkletType& worklet,
1241 InvocationType& invocation,
1245 using Task = viskores::exec::internal::TaskSingular<WorkletType, InvocationType>;
1246 return Task(worklet, invocation, globalIndexOffset);
1252 #endif //viskores_cont_internal_DeviceAdapterAlgorithmGeneral_h