Ginkgo Generated from develop branch based on develop. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
ic.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
6#define GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
7
8
9#include <memory>
10#include <type_traits>
11
12
13#include <ginkgo/core/base/abstract_factory.hpp>
14#include <ginkgo/core/base/composition.hpp>
15#include <ginkgo/core/base/exception.hpp>
16#include <ginkgo/core/base/exception_helpers.hpp>
17#include <ginkgo/core/base/lin_op.hpp>
18#include <ginkgo/core/base/precision_dispatch.hpp>
19#include <ginkgo/core/base/std_extensions.hpp>
20#include <ginkgo/core/factorization/par_ic.hpp>
21#include <ginkgo/core/matrix/dense.hpp>
22#include <ginkgo/core/solver/solver_traits.hpp>
23#include <ginkgo/core/solver/triangular.hpp>
24#include <ginkgo/core/stop/combined.hpp>
25#include <ginkgo/core/stop/iteration.hpp>
26#include <ginkgo/core/stop/residual_norm.hpp>
27
28
29namespace gko {
30namespace preconditioner {
31
32
78template <typename LSolverType = solver::LowerTrs<>, typename IndexType = int32>
79class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
80 friend class EnableLinOp<Ic>;
81 friend class EnablePolymorphicObject<Ic, LinOp>;
82
83public:
84 static_assert(
85 std::is_same<typename LSolverType::transposed_type::transposed_type,
86 LSolverType>::value,
87 "LSolverType::transposed_type must be symmetric");
88 using value_type = typename LSolverType::value_type;
89 using l_solver_type = LSolverType;
90 using lh_solver_type = typename LSolverType::transposed_type;
91 using index_type = IndexType;
93
94 class Factory;
95
97 : public enable_parameters_type<parameters_type, Factory> {
101 std::shared_ptr<const typename l_solver_type::Factory>
103
107 std::shared_ptr<const LinOpFactory> factorization_factory{};
108
109 GKO_DEPRECATED("use with_l_solver instead")
110 parameters_type& with_l_solver_factory(
112 solver)
113 {
114 return with_l_solver(std::move(solver));
115 }
116
117 parameters_type& with_l_solver(
119 solver)
120 {
121 this->l_solver_generator = std::move(solver);
122 this->deferred_factories["l_solver"] = [](const auto& exec,
123 auto& params) {
124 if (!params.l_solver_generator.is_empty()) {
125 params.l_solver_factory =
126 params.l_solver_generator.on(exec);
127 }
128 };
129 return *this;
130 }
131
132 GKO_DEPRECATED("use with_factorization instead")
133 parameters_type& with_factorization_factory(
134 deferred_factory_parameter<const LinOpFactory> factorization)
135 {
136 return with_factorization(std::move(factorization));
137 }
138
139 parameters_type& with_factorization(
141 {
142 this->factorization_generator = std::move(factorization);
143 this->deferred_factories["factorization"] = [](const auto& exec,
144 auto& params) {
145 if (!params.factorization_generator.is_empty()) {
146 params.factorization_factory =
147 params.factorization_generator.on(exec);
148 }
149 };
150 return *this;
151 }
152
153 private:
154 deferred_factory_parameter<const typename l_solver_type::Factory>
155 l_solver_generator;
156
158 };
159
162
168 std::shared_ptr<const l_solver_type> get_l_solver() const
169 {
170 return l_solver_;
171 }
172
178 std::shared_ptr<const lh_solver_type> get_lh_solver() const
179 {
180 return lh_solver_;
181 }
182
183 std::unique_ptr<LinOp> transpose() const override
184 {
185 std::unique_ptr<transposed_type> transposed{
186 new transposed_type{this->get_executor()}};
187 transposed->set_size(gko::transpose(this->get_size()));
188 transposed->l_solver_ =
190 this->get_lh_solver()->transpose()));
191 transposed->lh_solver_ =
193 this->get_l_solver()->transpose()));
194
195 return std::move(transposed);
196 }
197
198 std::unique_ptr<LinOp> conj_transpose() const override
199 {
200 std::unique_ptr<transposed_type> transposed{
201 new transposed_type{this->get_executor()}};
202 transposed->set_size(gko::transpose(this->get_size()));
203 transposed->l_solver_ =
205 this->get_lh_solver()->conj_transpose()));
206 transposed->lh_solver_ =
208 this->get_l_solver()->conj_transpose()));
209
210 return std::move(transposed);
211 }
212
219 {
220 if (&other != this) {
222 auto exec = this->get_executor();
223 l_solver_ = other.l_solver_;
224 lh_solver_ = other.lh_solver_;
225 parameters_ = other.parameters_;
226 if (other.get_executor() != exec) {
227 l_solver_ = gko::clone(exec, l_solver_);
228 lh_solver_ = gko::clone(exec, lh_solver_);
229 }
230 }
231 return *this;
232 }
233
241 {
242 if (&other != this) {
244 auto exec = this->get_executor();
245 l_solver_ = std::move(other.l_solver_);
246 lh_solver_ = std::move(other.lh_solver_);
247 parameters_ = std::exchange(other.parameters_, parameters_type{});
248 if (other.get_executor() != exec) {
249 l_solver_ = gko::clone(exec, l_solver_);
250 lh_solver_ = gko::clone(exec, lh_solver_);
251 }
252 }
253 return *this;
254 }
255
260 Ic(const Ic& other) : Ic{other.get_executor()} { *this = other; }
261
267 Ic(Ic&& other) : Ic{other.get_executor()} { *this = std::move(other); }
268
269protected:
270 void apply_impl(const LinOp* b, LinOp* x) const override
271 {
272 // take care of real-to-complex apply
274 [&](auto dense_b, auto dense_x) {
275 this->set_cache_to(dense_b);
276 l_solver_->apply(dense_b, cache_.intermediate);
277 if (lh_solver_->apply_uses_initial_guess()) {
278 dense_x->copy_from(cache_.intermediate);
279 }
280 lh_solver_->apply(cache_.intermediate, dense_x);
281 },
282 b, x);
283 }
284
285 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
286 LinOp* x) const override
287 {
289 [&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
290 this->set_cache_to(dense_b);
291 l_solver_->apply(dense_b, cache_.intermediate);
292 lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
293 dense_x);
294 },
295 alpha, b, beta, x);
296 }
297
298 explicit Ic(std::shared_ptr<const Executor> exec)
299 : EnableLinOp<Ic>(std::move(exec))
300 {}
301
302 explicit Ic(const Factory* factory, std::shared_ptr<const LinOp> lin_op)
303 : EnableLinOp<Ic>(factory->get_executor(), lin_op->get_size()),
304 parameters_{factory->get_parameters()}
305 {
306 auto comp =
307 std::dynamic_pointer_cast<const Composition<value_type>>(lin_op);
308 std::shared_ptr<const LinOp> l_factor;
309
310 // build factorization if we weren't passed a composition
311 if (!comp) {
312 auto exec = lin_op->get_executor();
313 if (!parameters_.factorization_factory) {
314 parameters_.factorization_factory =
315 factorization::ParIc<value_type, index_type>::build()
316 .with_both_factors(false)
317 .on(exec);
318 }
319 auto fact = std::shared_ptr<const LinOp>(
320 parameters_.factorization_factory->generate(lin_op));
321 // ensure that the result is a composition
322 comp =
323 std::dynamic_pointer_cast<const Composition<value_type>>(fact);
324 if (!comp) {
325 GKO_NOT_SUPPORTED(comp);
326 }
327 }
328 // comp must contain one or two factors
329 if (comp->get_operators().size() > 2 || comp->get_operators().empty()) {
330 GKO_NOT_SUPPORTED(comp);
331 }
332 l_factor = comp->get_operators()[0];
333 GKO_ASSERT_IS_SQUARE_MATRIX(l_factor);
334
335 auto exec = this->get_executor();
336
337 // If no factories are provided, generate default ones
338 if (!parameters_.l_solver_factory) {
340 // If comp contains both factors: use the transposed factor to avoid
341 // transposing twice
342 if (comp->get_operators().size() == 2) {
343 auto lh_factor = comp->get_operators()[1];
344 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, lh_factor);
345 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
346 } else {
347 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
348 }
349 } else {
350 l_solver_ = parameters_.l_solver_factory->generate(l_factor);
351 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
352 }
353 }
354
362 void set_cache_to(const LinOp* b) const
363 {
364 if (cache_.intermediate == nullptr) {
365 cache_.intermediate =
367 }
368 // Use b as the initial guess for the first triangular solve
369 cache_.intermediate->copy_from(b);
370 }
371
372
380 template <typename SolverType>
381 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
382 std::unique_ptr<SolverType>>
383 generate_default_solver(const std::shared_ptr<const Executor>& exec,
384 const std::shared_ptr<const LinOp>& mtx)
385 {
387 const unsigned int default_max_iters{
388 static_cast<unsigned int>(mtx->get_size()[0])};
389
390 return SolverType::build()
391 .with_criteria(
392 gko::stop::Iteration::build().with_max_iters(default_max_iters),
394 .with_reduction_factor(default_reduce_residual))
395 .on(exec)
396 ->generate(mtx);
397 }
398
402 template <typename SolverType>
403 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
404 std::unique_ptr<SolverType>>
405 generate_default_solver(const std::shared_ptr<const Executor>& exec,
406 const std::shared_ptr<const LinOp>& mtx)
407 {
408 return SolverType::build().on(exec)->generate(mtx);
409 }
410
411private:
412 std::shared_ptr<const l_solver_type> l_solver_{};
413 std::shared_ptr<const lh_solver_type> lh_solver_{};
424 mutable struct cache_struct {
425 cache_struct() = default;
426 ~cache_struct() = default;
427 cache_struct(const cache_struct&) {}
428 cache_struct(cache_struct&&) {}
429 cache_struct& operator=(const cache_struct&) { return *this; }
430 cache_struct& operator=(cache_struct&&) { return *this; }
431 std::unique_ptr<LinOp> intermediate{};
432 } cache_;
433};
434
435
436} // namespace preconditioner
437} // namespace gko
438
439
440#endif // GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition lin_op.hpp:880
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:663
Definition lin_op.hpp:118
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:235
Linear operators which support transposition should implement the Transposable interface.
Definition lin_op.hpp:434
Represents a factory parameter of factory type that can either be initialized by a pre-existing factory ...
Definition abstract_factory.hpp:309
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:211
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Creates an uninitialized Dense matrix of the specified size.
The Incomplete Cholesky (IC) preconditioner solves the equation $LL^H x = b$ for a given lower triangular matrix ...
Definition ic.hpp:79
std::shared_ptr< const lh_solver_type > get_lh_solver() const
Returns the solver which is used for the L^H matrix.
Definition ic.hpp:178
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
Definition ic.hpp:183
Ic(const Ic &other)
Copy-constructs an IC preconditioner.
Definition ic.hpp:260
Ic & operator=(Ic &&other)
Move-assigns an IC preconditioner.
Definition ic.hpp:240
Ic(Ic &&other)
Move-constructs an IC preconditioner.
Definition ic.hpp:267
std::shared_ptr< const l_solver_type > get_l_solver() const
Returns the solver which is used for the provided L matrix.
Definition ic.hpp:168
Ic & operator=(const Ic &other)
Copy-assigns an IC preconditioner.
Definition ic.hpp:218
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
Definition ic.hpp:198
The ResidualNorm class is a stopping criterion which stops the iteration process when the actual resi...
Definition residual_norm.hpp:110
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Defines a build method for the factory, simplifying its construction by removing the repetitive typin...
Definition abstract_factory.hpp:394
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
This macro will generate a default implementation of a LinOpFactory for the LinOp subclass it is defi...
Definition lin_op.hpp:1018
@ factory
LinOpFactory events.
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:775
typename detail::remove_complex_s< T >::type remove_complex
Obtains the type resulting from removing the complex part of a complex/scalar type, or the template parameter of class b...
Definition math.hpp:326
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:175
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:120
detail::shared_type< OwningPointer > share(OwningPointer &&p)
Marks the object pointed to by p as shared.
Definition utils_helper.hpp:226
std::shared_ptr< const typename l_solver_type::Factory > l_solver_factory
Factory for the L solver.
Definition ic.hpp:102
std::shared_ptr< const LinOpFactory > factorization_factory
Factory for the factorization.
Definition ic.hpp:107