Loading...
Searching...
No Matches
is_symmetric.cpp
1#include <symengine/basic.h>
2#include <symengine/assumptions.h>
3#include <symengine/visitor.h>
4#include <symengine/test_visitors.h>
5
6namespace SymEngine
7{
8
9class MatrixSymmetricVisitor : public BaseVisitor<MatrixSymmetricVisitor>
10{
11private:
12 tribool is_symmetric_;
13 const Assumptions *assumptions_;
14
15 void check_vector(const vec_basic &vec)
16 {
17 bool found_nonsym = false;
18 for (auto &elt : vec) {
19 elt->accept(*this);
20 if (is_indeterminate(is_symmetric_)) {
21 return;
22 } else if (is_false(is_symmetric_)) {
23 if (found_nonsym) {
24 return;
25 } else {
26 found_nonsym = true;
27 }
28 }
29 }
30 if (found_nonsym) {
31 is_symmetric_ = tribool::trifalse;
32 } else {
33 is_symmetric_ = tribool::tritrue;
34 }
35 }
36
37public:
38 MatrixSymmetricVisitor(const Assumptions *assumptions)
39 : assumptions_(assumptions)
40 {
41 }
42
43 void bvisit(const Basic &x){};
44 void bvisit(const MatrixExpr &x)
45 {
46 is_symmetric_ = tribool::indeterminate;
47 }
48
49 void bvisit(const IdentityMatrix &x)
50 {
51 is_symmetric_ = tribool::tritrue;
52 }
53
54 void bvisit(const ZeroMatrix &x)
55 {
56 is_symmetric_ = is_square(x, assumptions_);
57 }
58
59 void bvisit(const DiagonalMatrix &x)
60 {
61 is_symmetric_ = tribool::tritrue;
62 }
63
64 void bvisit(const ImmutableDenseMatrix &x)
65 {
66 size_t nrows = x.nrows();
67 size_t ncols = x.ncols();
68 if (nrows != ncols) {
69 is_symmetric_ = tribool::trifalse;
70 return;
71 }
72 ZeroVisitor visitor(assumptions_);
73 is_symmetric_ = tribool::tritrue;
74 for (size_t i = 0; i < ncols; i++) {
75 for (size_t j = 0; j <= i; j++) {
76 if (j != i) {
77 auto e1 = x.get(i, j);
78 auto e2 = x.get(j, i);
79 is_symmetric_ = and_tribool(is_symmetric_,
80 visitor.apply(*sub(e1, e2)));
81 }
82 if (is_false(is_symmetric_)) {
83 return;
84 }
85 }
86 }
87 }
88
89 void bvisit(const MatrixAdd &x)
90 {
91 check_vector(x.get_terms());
92 }
93
94 void bvisit(const HadamardProduct &x)
95 {
96 check_vector(x.get_factors());
97 }
98
99 tribool apply(const MatrixExpr &s)
100 {
101 s.accept(*this);
102 return is_symmetric_;
103 }
104};
105
106tribool is_symmetric(const MatrixExpr &m, const Assumptions *assumptions)
107{
108 MatrixSymmetricVisitor visitor(assumptions);
109 return visitor.apply(m);
110}
111
112} // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495