is_diagonal.cpp
1 #include <symengine/basic.h>
2 #include <symengine/assumptions.h>
3 #include <symengine/visitor.h>
4 #include <symengine/test_visitors.h>
5 
6 namespace SymEngine
7 {
8 
9 class MatrixDiagonalVisitor : public BaseVisitor<MatrixDiagonalVisitor>
10 {
11 private:
12  tribool is_diagonal_;
13  const Assumptions *assumptions_;
14 
15 public:
16  MatrixDiagonalVisitor(const Assumptions *assumptions)
17  : assumptions_(assumptions)
18  {
19  }
20 
21  void bvisit(const Basic &x){};
22  void bvisit(const MatrixExpr &x)
23  {
24  is_diagonal_ = tribool::indeterminate;
25  }
26 
27  void bvisit(const IdentityMatrix &x)
28  {
29  is_diagonal_ = tribool::tritrue;
30  }
31 
32  void bvisit(const ZeroMatrix &x)
33  {
34  is_diagonal_ = is_square(x, assumptions_);
35  }
36 
37  void bvisit(const DiagonalMatrix &x)
38  {
39  is_diagonal_ = tribool::tritrue;
40  }
41 
42  void bvisit(const ImmutableDenseMatrix &x)
43  {
44  if (x.nrows() != x.ncols()) {
45  is_diagonal_ = tribool::trifalse;
46  return;
47  }
48  size_t ncols = x.ncols();
49  size_t offset;
50  ZeroVisitor visitor(assumptions_);
51  is_diagonal_ = tribool::tritrue;
52  for (size_t i = 0; i < ncols; i++) {
53  offset = i * ncols;
54  for (size_t j = 0; j < ncols; j++) {
55  if (j != i) {
56  auto &e = x.get_values()[offset];
57  is_diagonal_ = and_tribool(is_diagonal_, visitor.apply(*e));
58  if (is_false(is_diagonal_)) {
59  return;
60  }
61  }
62  offset++;
63  }
64  }
65  }
66 
67  void bvisit(const MatrixAdd &x)
68  {
69  bool found_nondiag = false;
70  for (auto &elt : x.get_terms()) {
71  elt->accept(*this);
72  if (is_indeterminate(is_diagonal_)) {
73  return;
74  } else if (is_false(is_diagonal_)) {
75  if (found_nondiag) {
76  return;
77  } else {
78  found_nondiag = true;
79  }
80  }
81  }
82  if (found_nondiag) {
83  is_diagonal_ = tribool::trifalse;
84  } else {
85  is_diagonal_ = tribool::tritrue;
86  }
87  }
88 
89  void bvisit(const HadamardProduct &x)
90  {
91  // diag x (diag | nodiag | indeterminate) x ... = diag
92  // (indet | nodiag) x (indet | nodiag) x ... = indeterminate
93  for (auto &elt : x.get_factors()) {
94  elt->accept(*this);
95  if (is_true(is_diagonal_)) {
96  return;
97  }
98  }
99  is_diagonal_ = tribool::indeterminate;
100  }
101 
102  tribool apply(const MatrixExpr &s)
103  {
104  s.accept(*this);
105  return is_diagonal_;
106  }
107 };
108 
109 tribool is_diagonal(const MatrixExpr &m, const Assumptions *assumptions)
110 {
111  MatrixDiagonalVisitor visitor(assumptions);
112  return visitor.apply(m);
113 }
114 
115 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19