Ginkgo Generated from develop branch based on develop. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
triangular.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
7
8
9#include <memory>
10#include <utility>
11
12
13#include <ginkgo/core/base/abstract_factory.hpp>
14#include <ginkgo/core/base/array.hpp>
15#include <ginkgo/core/base/dim.hpp>
16#include <ginkgo/core/base/exception_helpers.hpp>
17#include <ginkgo/core/base/lin_op.hpp>
18#include <ginkgo/core/base/polymorphic_object.hpp>
19#include <ginkgo/core/base/types.hpp>
20#include <ginkgo/core/base/utils.hpp>
21#include <ginkgo/core/log/logger.hpp>
22#include <ginkgo/core/matrix/csr.hpp>
23#include <ginkgo/core/matrix/identity.hpp>
24#include <ginkgo/core/solver/solver_base.hpp>
25
26
27namespace gko {
28namespace solver {
29
30
31struct SolveStruct;
32
33
39enum class trisolve_algorithm { sparselib, syncfree };
40
41
42template <typename ValueType, typename IndexType>
43class UpperTrs;
44
45
63template <typename ValueType = default_precision, typename IndexType = int32>
64class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
65 public EnableSolverBase<LowerTrs<ValueType, IndexType>,
66 matrix::Csr<ValueType, IndexType>>,
67 public Transposable {
68 friend class EnableLinOp<LowerTrs>;
70 friend class UpperTrs<ValueType, IndexType>;
71
72public:
73 using value_type = ValueType;
74 using index_type = IndexType;
76
77 std::unique_ptr<LinOp> transpose() const override;
78
79 std::unique_ptr<LinOp> conj_transpose() const override;
80
82 {
90
95 bool GKO_FACTORY_PARAMETER_SCALAR(unit_diagonal, false);
96
104 algorithm, trisolve_algorithm::sparselib);
105 };
108
115
123
130
137
138protected:
140
141 void apply_impl(const LinOp* b, LinOp* x) const override;
142
143 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
144 LinOp* x) const override;
145
150 void generate();
151
152 explicit LowerTrs(std::shared_ptr<const Executor> exec)
153 : EnableLinOp<LowerTrs>(std::move(exec))
154 {}
155
156 explicit LowerTrs(const Factory* factory,
157 std::shared_ptr<const LinOp> system_matrix)
158 : EnableLinOp<LowerTrs>(factory->get_executor(),
159 gko::transpose(system_matrix->get_size())),
160 EnableSolverBase<LowerTrs<ValueType, IndexType>, CsrMatrix>{
161 copy_and_convert_to<CsrMatrix>(factory->get_executor(),
162 system_matrix)},
163 parameters_{factory->get_parameters()}
164 {
165 this->generate();
166 }
167
168private:
169 std::shared_ptr<solver::SolveStruct> solve_struct_;
170};
171
172
173template <typename ValueType, typename IndexType>
174struct workspace_traits<LowerTrs<ValueType, IndexType>> {
176 // number of vectors used by this workspace
177 static int num_vectors(const Solver&);
178 // number of arrays used by this workspace
179 static int num_arrays(const Solver&);
180 // array containing the num_vectors names for the workspace vectors
181 static std::vector<std::string> op_names(const Solver&);
182 // array containing the num_arrays names for the workspace vectors
183 static std::vector<std::string> array_names(const Solver&);
184 // array containing all varying scalar vectors (independent of problem size)
185 static std::vector<int> scalars(const Solver&);
186 // array containing all varying vectors (dependent on problem size)
187 static std::vector<int> vectors(const Solver&);
188
189 // transposed input vector
190 constexpr static int transposed_b = 0;
191 // transposed output vector
192 constexpr static int transposed_x = 1;
193};
194
195
213template <typename ValueType = default_precision, typename IndexType = int32>
214class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
215 public EnableSolverBase<UpperTrs<ValueType, IndexType>,
216 matrix::Csr<ValueType, IndexType>>,
217 public Transposable {
218 friend class EnableLinOp<UpperTrs>;
220 friend class LowerTrs<ValueType, IndexType>;
221
222public:
223 using value_type = ValueType;
224 using index_type = IndexType;
226
227 std::unique_ptr<LinOp> transpose() const override;
228
229 std::unique_ptr<LinOp> conj_transpose() const override;
230
232 {
240
245 bool GKO_FACTORY_PARAMETER_SCALAR(unit_diagonal, false);
246
254 algorithm, trisolve_algorithm::sparselib);
255 };
258
265
273
280
287
288protected:
290
291 void apply_impl(const LinOp* b, LinOp* x) const override;
292
293 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
294 LinOp* x) const override;
295
300 void generate();
301
302 explicit UpperTrs(std::shared_ptr<const Executor> exec)
303 : EnableLinOp<UpperTrs>(std::move(exec))
304 {}
305
306 explicit UpperTrs(const Factory* factory,
307 std::shared_ptr<const LinOp> system_matrix)
308 : EnableLinOp<UpperTrs>(factory->get_executor(),
309 gko::transpose(system_matrix->get_size())),
310 EnableSolverBase<UpperTrs<ValueType, IndexType>, CsrMatrix>{
311 copy_and_convert_to<CsrMatrix>(factory->get_executor(),
312 system_matrix)},
313 parameters_{factory->get_parameters()}
314 {
315 this->generate();
316 }
317
318private:
319 std::shared_ptr<solver::SolveStruct> solve_struct_;
320};
321
322
323template <typename ValueType, typename IndexType>
324struct workspace_traits<UpperTrs<ValueType, IndexType>> {
326 // number of vectors used by this workspace
327 static int num_vectors(const Solver&);
328 // number of arrays used by this workspace
329 static int num_arrays(const Solver&);
330 // array containing the num_vectors names for the workspace vectors
331 static std::vector<std::string> op_names(const Solver&);
332 // array containing the num_arrays names for the workspace vectors
333 static std::vector<std::string> array_names(const Solver&);
334 // array containing all varying scalar vectors (independent of problem size)
335 static std::vector<int> scalars(const Solver&);
336 // array containing all varying vectors (dependent on problem size)
337 static std::vector<int> vectors(const Solver&);
338
339 // transposed input vector
340 constexpr static int transposed_b = 0;
341 // transposed output vector
342 constexpr static int transposed_x = 1;
343};
344
345
346} // namespace solver
347} // namespace gko
348
349
350#endif // GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition lin_op.hpp:880
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:663
Definition lin_op.hpp:118
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:235
Linear operators which support transposition should implement the Transposable interface.
Definition lin_op.hpp:434
CSR is a matrix format which stores only the nonzero coefficients by compressing each row of the matr...
Definition csr.hpp:117
A LinOp deriving from this CRTP class stores a system matrix.
Definition solver_base.hpp:542
Definition triangular.hpp:106
LowerTrs is the triangular solver which solves the system L x = b, when L is a lower triangular matri...
Definition triangular.hpp:67
LowerTrs & operator=(LowerTrs &&)
Move-assigns a triangular solver.
LowerTrs(const LowerTrs &)
Copy-constructs a triangular solver.
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
LowerTrs & operator=(const LowerTrs &)
Copy-assigns a triangular solver.
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
LowerTrs(LowerTrs &&)
Move-constructs a triangular solver.
Definition triangular.hpp:256
UpperTrs is the triangular solver which solves the system U x = b, when U is an upper triangular matr...
Definition triangular.hpp:217
UpperTrs(UpperTrs &&)
Move-constructs a triangular solver.
UpperTrs(const UpperTrs &)
Copy-constructs a triangular solver.
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
UpperTrs & operator=(UpperTrs &&)
Move-assigns a triangular solver.
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
UpperTrs & operator=(const UpperTrs &)
Copy-assigns a triangular solver.
#define GKO_CREATE_FACTORY_PARAMETERS(_parameters_name, _factory_name)
This Macro will generate a new type containing the parameters for the factory _factory_name.
Definition abstract_factory.hpp:280
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:445
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Defines a build method for the factory, simplifying its construction by removing the repetitive typin...
Definition abstract_factory.hpp:394
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
This macro will generate a default implementation of a LinOpFactory for the LinOp subclass it is defi...
Definition lin_op.hpp:1018
trisolve_algorithm
A helper for algorithm selection in the triangular solvers.
Definition triangular.hpp:39
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:775
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:92
Traits class providing information on the type and location of workspace vectors inside a solver.
Definition solver_base.hpp:239