Loading...
Searching...
No Matches
is_diagonal.cpp
1#include <symengine/basic.h>
2#include <symengine/assumptions.h>
3#include <symengine/visitor.h>
4#include <symengine/test_visitors.h>
5
6namespace SymEngine
7{
8
9class MatrixDiagonalVisitor : public BaseVisitor<MatrixDiagonalVisitor>
10{
11private:
12 tribool is_diagonal_;
13 const Assumptions *assumptions_;
14
15public:
16 MatrixDiagonalVisitor(const Assumptions *assumptions)
17 : assumptions_(assumptions)
18 {
19 }
20
21 void bvisit(const Basic &x){};
22 void bvisit(const MatrixExpr &x)
23 {
24 is_diagonal_ = tribool::indeterminate;
25 }
26
27 void bvisit(const IdentityMatrix &x)
28 {
29 is_diagonal_ = tribool::tritrue;
30 }
31
32 void bvisit(const ZeroMatrix &x)
33 {
34 is_diagonal_ = is_square(x, assumptions_);
35 }
36
37 void bvisit(const DiagonalMatrix &x)
38 {
39 is_diagonal_ = tribool::tritrue;
40 }
41
42 void bvisit(const ImmutableDenseMatrix &x)
43 {
44 if (x.nrows() != x.ncols()) {
45 is_diagonal_ = tribool::trifalse;
46 return;
47 }
48 size_t ncols = x.ncols();
49 size_t offset;
50 ZeroVisitor visitor(assumptions_);
51 is_diagonal_ = tribool::tritrue;
52 for (size_t i = 0; i < ncols; i++) {
53 offset = i * ncols;
54 for (size_t j = 0; j < ncols; j++) {
55 if (j != i) {
56 auto &e = x.get_values()[offset];
57 is_diagonal_ = and_tribool(is_diagonal_, visitor.apply(*e));
58 if (is_false(is_diagonal_)) {
59 return;
60 }
61 }
62 offset++;
63 }
64 }
65 }
66
67 void bvisit(const MatrixAdd &x)
68 {
69 bool found_nondiag = false;
70 for (auto &elt : x.get_terms()) {
71 elt->accept(*this);
72 if (is_indeterminate(is_diagonal_)) {
73 return;
74 } else if (is_false(is_diagonal_)) {
75 if (found_nondiag) {
76 return;
77 } else {
78 found_nondiag = true;
79 }
80 }
81 }
82 if (found_nondiag) {
83 is_diagonal_ = tribool::trifalse;
84 } else {
85 is_diagonal_ = tribool::tritrue;
86 }
87 }
88
89 void bvisit(const HadamardProduct &x)
90 {
91 // diag x (diag | nodiag | indeterminate) x ... = diag
92 // (indet | nodiag) x (indet | nodiag) x ... = indeterminate
93 for (auto &elt : x.get_factors()) {
94 elt->accept(*this);
95 if (is_true(is_diagonal_)) {
96 return;
97 }
98 }
99 is_diagonal_ = tribool::indeterminate;
100 }
101
102 tribool apply(const MatrixExpr &s)
103 {
104 s.accept(*this);
105 return is_diagonal_;
106 }
107};
108
109tribool is_diagonal(const MatrixExpr &m, const Assumptions *assumptions)
110{
111 MatrixDiagonalVisitor visitor(assumptions);
112 return visitor.apply(m);
113}
114
115} // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19