is_symmetric.cpp
1 #include <symengine/basic.h>
2 #include <symengine/assumptions.h>
3 #include <symengine/visitor.h>
4 #include <symengine/test_visitors.h>
5 
6 namespace SymEngine
7 {
8 
9 class MatrixSymmetricVisitor : public BaseVisitor<MatrixSymmetricVisitor>
10 {
11 private:
12  tribool is_symmetric_;
13  const Assumptions *assumptions_;
14 
15  void check_vector(const vec_basic &vec)
16  {
17  bool found_nonsym = false;
18  for (auto &elt : vec) {
19  elt->accept(*this);
20  if (is_indeterminate(is_symmetric_)) {
21  return;
22  } else if (is_false(is_symmetric_)) {
23  if (found_nonsym) {
24  return;
25  } else {
26  found_nonsym = true;
27  }
28  }
29  }
30  if (found_nonsym) {
31  is_symmetric_ = tribool::trifalse;
32  } else {
33  is_symmetric_ = tribool::tritrue;
34  }
35  }
36 
37 public:
38  MatrixSymmetricVisitor(const Assumptions *assumptions)
39  : assumptions_(assumptions)
40  {
41  }
42 
43  void bvisit(const Basic &x){};
44  void bvisit(const MatrixExpr &x)
45  {
46  is_symmetric_ = tribool::indeterminate;
47  }
48 
49  void bvisit(const IdentityMatrix &x)
50  {
51  is_symmetric_ = tribool::tritrue;
52  }
53 
54  void bvisit(const ZeroMatrix &x)
55  {
56  is_symmetric_ = is_square(x, assumptions_);
57  }
58 
59  void bvisit(const DiagonalMatrix &x)
60  {
61  is_symmetric_ = tribool::tritrue;
62  }
63 
64  void bvisit(const ImmutableDenseMatrix &x)
65  {
66  size_t nrows = x.nrows();
67  size_t ncols = x.ncols();
68  if (nrows != ncols) {
69  is_symmetric_ = tribool::trifalse;
70  return;
71  }
72  ZeroVisitor visitor(assumptions_);
73  is_symmetric_ = tribool::tritrue;
74  for (size_t i = 0; i < ncols; i++) {
75  for (size_t j = 0; j <= i; j++) {
76  if (j != i) {
77  auto e1 = x.get(i, j);
78  auto e2 = x.get(j, i);
79  is_symmetric_ = and_tribool(is_symmetric_,
80  visitor.apply(*sub(e1, e2)));
81  }
82  if (is_false(is_symmetric_)) {
83  return;
84  }
85  }
86  }
87  }
88 
89  void bvisit(const MatrixAdd &x)
90  {
91  check_vector(x.get_terms());
92  }
93 
94  void bvisit(const HadamardProduct &x)
95  {
96  check_vector(x.get_factors());
97  }
98 
99  tribool apply(const MatrixExpr &s)
100  {
101  s.accept(*this);
102  return is_symmetric_;
103  }
104 };
105 
106 tribool is_symmetric(const MatrixExpr &m, const Assumptions *assumptions)
107 {
108  MatrixSymmetricVisitor visitor(assumptions);
109  return visitor.apply(m);
110 }
111 
112 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Substracts b from a.
Definition: add.cpp:495