46 #ifndef MUELU_FILTEREDAFACTORY_KOKKOS_DEF_HPP 47 #define MUELU_FILTEREDAFACTORY_KOKKOS_DEF_HPP 49 #ifdef HAVE_MUELU_KOKKOS_REFACTOR 53 #include <Xpetra_Matrix.hpp> 54 #include <Xpetra_MatrixFactory.hpp> 56 #include "MueLu_FactoryManager.hpp" 63 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
64 RCP<const ParameterList> FilteredAFactory_kokkos<Scalar, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
65 GetValidParameterList()
const {
// GetValidParameterList: builds and returns a fresh ParameterList that
// declares everything this factory accepts.
//   - Parameters copied from MasterList through SET_VALID_ENTRY.
//     NOTE(review): the SET_VALID_ENTRY(...) invocations (source lines 69-71)
//     are not part of this listing. Build() reads "filtered matrix: use
//     lumping", "filtered matrix: reuse graph" and "filtered matrix: reuse
//     eigenvalue", so presumably those are the entries registered there --
//     confirm.
//   - "A", "Graph", "Filtering": generating factories for the inputs
//     requested in DeclareInput(); each defaults to Teuchos::null.
66 RCP<ParameterList> validParamList =
rcp(
new ParameterList());
68 #define SET_VALID_ENTRY(name) validParamList->setEntry(name, MasterList::getEntry(name)) 72 #undef SET_VALID_ENTRY 74 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A used for filtering");
75 validParamList->set< RCP<const FactoryBase> >(
"Graph", Teuchos::null,
"Generating factory for coalesced filtered graph");
76 validParamList->set< RCP<const FactoryBase> >(
"Filtering", Teuchos::null,
"Generating factory for filtering boolean");
78 return validParamList;
// DeclareInput: registers on currentLevel that this factory needs the
// "A", "Filtering" and "Graph" inputs; Build() later retrieves all three
// with Get<>().
81 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
82 void FilteredAFactory_kokkos<Scalar, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
83 DeclareInput(Level& currentLevel)
const {
84 Input(currentLevel,
"A");
85 Input(currentLevel,
"Filtering");
86 Input(currentLevel,
"Graph");
// Build: produces the filtered operator and stores it on currentLevel as "A".
//   - If the "Filtering" input is false, the incoming A is stored unchanged.
//   - Otherwise A is filtered against the "Graph" input (an LWGraph_kokkos),
//     optionally lumping dropped entries ("filtered matrix: use lumping").
//   - "filtered matrix: reuse graph" chooses between reusing A's CrsGraph
//     (BuildReuse) and building a new matrix on A's row/column maps
//     (BuildNew).
//   - The fixed block size is copied from A. With "filtered matrix: reuse
//     eigenvalue", A's max eigenvalue estimate is copied as well.
89 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
90 void FilteredAFactory_kokkos<Scalar, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
91 Build(Level& currentLevel)
const {
92 FactoryMonitor m(*
this,
"Matrix filtering", currentLevel);
94 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel,
"A");
95 if (Get<bool>(currentLevel,
"Filtering") ==
false) {
96 GetOStream(
Runtime0) <<
"Filtered matrix is not being constructed as no filtering is being done" << std::endl;
97 Set(currentLevel,
"A", A);
// NOTE(review): the early return that should end this branch sits on elided
// source lines 98-99. Confirm it is present; otherwise A is filtered anyway
// and "A" is overwritten at the end of this function.
101 const ParameterList& pL = GetParameterList();
102 bool lumping = pL.get<
bool>(
"filtered matrix: use lumping");
// NOTE(review): source line 103 is elided; presumably it is `if (lumping)`,
// so the message below is printed only when lumping is enabled -- confirm.
104 GetOStream(
Runtime0) <<
"Lumping dropped entries" << std::endl;
106 RCP<LWGraph_kokkos> graph = Get< RCP<LWGraph_kokkos> >(currentLevel,
"Graph");
// Parameters passed to both fillComplete calls below. NOTE(review): see the
// Tpetra/Xpetra fillComplete docs for the exact effect of
// "No Nonlocal Changes".
108 RCP<ParameterList> fillCompleteParams(
new ParameterList);
109 fillCompleteParams->set(
"No Nonlocal Changes",
true);
111 RCP<Matrix> filteredA;
// Reuse path: filteredA is built on A's CrsGraph and fill-completed up
// front; BuildReuse then only rewrites its values.
112 if (pL.get<
bool>(
"filtered matrix: reuse graph")) {
113 filteredA = MatrixFactory::Build(A->getCrsGraph());
114 filteredA->fillComplete(fillCompleteParams);
116 BuildReuse(*A, *graph, lumping, *filteredA);
// New-matrix path (presumably the else branch; source lines 117-118 are
// elided): same row/column maps as A, with room for A's longest local row.
// BuildNew fills it, then it is fill-completed with A's domain and range maps.
119 filteredA = MatrixFactory::Build(A->getRowMap(), A->getColMap(), A->getNodeMaxNumRowEntries(), Xpetra::StaticProfile);
121 BuildNew(*A, *graph, lumping, *filteredA);
123 filteredA->fillComplete(A->getDomainMap(), A->getRangeMap(), fillCompleteParams);
126 filteredA->SetFixedBlockSize(A->GetFixedBlockSize());
// Optionally carry over A's max eigenvalue estimate instead of leaving the
// filtered matrix without one (source lines 129-132 are elided).
128 if (pL.get<
bool>(
"filtered matrix: reuse eigenvalue")) {
133 filteredA->SetMaxEigenvalueEstimate(A->GetMaxEigenvalueEstimate());
136 Set(currentLevel,
"A", filteredA);
// BuildReuseFunctor: per-graph-row kernel used by BuildReuse.
// localFA is expected to have the same sparsity pattern and entry order as
// localA: entries are matched by position j in each row. BuildReuse is only
// reached after filteredA was built from A's CrsGraph.
// For graph row i it processes matrix rows i*blkSize + k (k < blkSize):
//   1. copies A's values into FA;
//   2. without lumping, zeroes every entry whose column is not marked in
//      `filter` by a graph neighbour of i;
//   3. with lumping, accumulates dropped values into diagExtra and adds them
//      to the diagonal entry (parts of this branch are on elided lines).
// NOTE(review): assumes graph neighbour indsG(j) corresponds to A's local
// columns indsG(j)*blkSize + k -- confirm.
// NOTE(review): `filter` is one View shared by every operator() call. Each
// call sets its neighbours' marks to 1 on entry and back to 0 on exit. If
// the functor is dispatched with Kokkos::parallel_for on a concurrent
// execution space (the dispatch is not visible here), rows with overlapping
// neighbour sets race: one call can clear marks another is still reading,
// so kept entries may be dropped. Consider a thread-private marker or a
// serial execution policy.
146 template<
class MatrixType,
class GraphType,
class FilterType>
147 class BuildReuseFunctor {
149 MatrixType localA, localFA;
155 typedef typename MatrixType::ordinal_type LO;
156 typedef typename MatrixType::value_type SC;
157 typedef Kokkos::ArithTraits<SC> ATS;
160 BuildReuseFunctor(MatrixType localA_, MatrixType localFA_,
size_t blkSize_, GraphType graph_,
bool lumping_, FilterType filter_) :
169 KOKKOS_INLINE_FUNCTION
170 void operator()(
const size_t i)
const {
// Mark every matrix column belonging to a graph neighbour of i.
172 typename GraphType::row_type indsG = graph.getNeighborVertices(i);
173 for (
size_t j = 0; j < indsG.size(); j++)
174 for (
size_t k = 0; k < blkSize; k++)
175 filter(indsG(j)*blkSize + k) = 1;
177 SC zero = ATS::zero();
178 for (
size_t k = 0; k < blkSize; k++) {
179 LO row = i*blkSize + k;
181 auto rowA = localA.row (row);
182 auto nnz = rowA.length;
187 auto rowFA = localFA.row (row);
// Start from an exact copy of A's row, then remove what the graph drops.
189 for (decltype(nnz) j = 0; j < nnz; j++)
190 rowFA.value(j) = rowA.value(j);
// No lumping: dropped entries are simply zeroed.
192 if (lumping ==
false) {
193 for (decltype(nnz) j = 0; j < nnz; j++)
194 if (!filter(rowA.colidx(j)))
195 rowFA.value(j) = zero;
// Lumping branch: its opening and the diagIndex/diagExtra declarations are
// presumably on elided source lines 196-200.
201 for (decltype(nnz) j = 0; j < nnz; j++) {
202 if (filter(rowA.colidx(j))) {
// NOTE(review): compares a local column index with a local row index. This
// finds the diagonal only if A's column map begins with its row map --
// confirm.
203 if (rowA.colidx(j) == row) {
210 diagExtra += rowFA.value(j);
212 rowFA.value(j) = zero;
// NOTE(review): diagIndex is set on elided lines. Confirm that a guard for
// rows with no stored diagonal precedes this write (source lines 213-219).
220 rowFA.value(diagIndex) += diagExtra;
// Clear this row's marks so `filter` is all-zero again afterwards.
225 for (
size_t j = 0; j < indsG.size(); j++)
226 for (
size_t k = 0; k < blkSize; k++)
227 filter(indsG(j)*blkSize + k) = 0;
// BuildReuse: overwrites the values of filteredA, which Build() created on
// A's CrsGraph, with A's values filtered (or lumped) according to `graph`.
// It does this through BuildReuseFunctor, presumably over the graph's
// numGRows vertices.
// NOTE(review): the dispatch of `functor` (source lines 246-248) is not
// part of this listing.
231 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
232 void FilteredAFactory_kokkos<Scalar, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
233 BuildReuse(
const Matrix& A,
const LWGraph_kokkos& graph,
const bool lumping, Matrix& filteredA)
const {
236 size_t blkSize = A.GetFixedBlockSize();
// Local matrices; presumably shallow handles, so the functor's writes to
// localFA land in filteredA -- confirm.
238 auto localA = A .getLocalMatrix();
239 auto localFA = filteredA.getLocalMatrix();
// Column marker with blkSize slots per entry of the graph's import map.
// NOTE(review): View constructor arguments after the label are extents, so
// the trailing 0 is a second extent, not a fill value. Views are already
// zero-initialized by default. Newer Kokkos argument checks may reject it --
// verify and drop it if so.
241 Kokkos::View<char*> filter(
"filter", blkSize * graph.GetImportMap()->getNodeNumElements(), 0);
243 size_t numGRows = graph.GetNodeNumVertices();
245 BuildReuseFunctor<decltype(localA), LWGraph_kokkos, decltype(filter)> functor(localA, localFA, blkSize, graph, lumping, filter);
// BuildNew: called from Build() when "filtered matrix: reuse graph" is
// false. filteredA is a fresh matrix on A's row and column maps that has not
// yet been fill-completed; Build() fill-completes it afterwards with A's
// domain and range maps. It is expected to be filled from A filtered by
// `graph`, with lumping when `lumping` is true.
// NOTE(review): the body (source lines 252-256) is not part of this listing;
// confirm the behavior there.
249 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
250 void FilteredAFactory_kokkos<Scalar, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
251 BuildNew(
const Matrix& A,
const LWGraph_kokkos& graph,
const bool lumping, Matrix& filteredA)
const {
257 #endif // HAVE_MUELU_KOKKOS_REFACTOR 258 #endif // MUELU_FILTEREDAFACTORY_KOKKOS_DEF_HPP #define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
One-liner description of what is happening.
Namespace for MueLu classes and methods.
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
void parallel_for(const ExecPolicy &policy, const FunctorType &functor, const std::string &str="", typename Impl::enable_if< ! Impl::is_integral< ExecPolicy >::value >::type *=0)
#define SET_VALID_ENTRY(name)