is_toeplitz.cpp
1 #include <symengine/basic.h>
2 #include <symengine/assumptions.h>
3 #include <symengine/visitor.h>
4 #include <symengine/test_visitors.h>
5 
6 namespace SymEngine
7 {
8 
9 class MatrixToeplitzVisitor : public BaseVisitor<MatrixToeplitzVisitor>
10 {
11 private:
12  tribool is_toeplitz_;
13  const Assumptions *assumptions_;
14 
15 public:
16  MatrixToeplitzVisitor(const Assumptions *assumptions)
17  : assumptions_(assumptions)
18  {
19  }
20 
21  void bvisit(const Basic &x){};
22  void bvisit(const MatrixExpr &x)
23  {
24  is_toeplitz_ = tribool::indeterminate;
25  }
26 
27  void bvisit(const IdentityMatrix &x)
28  {
29  is_toeplitz_ = tribool::tritrue;
30  }
31 
32  void bvisit(const ZeroMatrix &x)
33  {
34  is_toeplitz_ = tribool::tritrue;
35  }
36 
37  void bvisit(const DiagonalMatrix &x)
38  {
39  tribool current = tribool::tritrue;
40  auto vec = x.get_container();
41  if (vec.size() == 1) {
42  is_toeplitz_ = tribool::tritrue;
43  return;
44  }
45  auto first = vec[0];
46  for (auto it = vec.begin() + 1; it != vec.end(); ++it) {
47  auto diff = sub(first, *it);
48  tribool next = is_zero(*diff, assumptions_);
49  if (is_false(next)) {
50  is_toeplitz_ = next;
51  return;
52  }
53  current = andwk_tribool(current, next);
54  }
55  is_toeplitz_ = current;
56  }
57 
58  void bvisit(const ImmutableDenseMatrix &x)
59  {
60  size_t i_start, j_start, i, j;
61  ZeroVisitor visitor(assumptions_);
62  is_toeplitz_ = tribool::tritrue;
63  // Loop over all diagonals
64  for (size_t w = 0; w < std::max(x.nrows(), x.ncols()) - 1; w++) {
65  // Loop over diagonals starting from the first row and the first
66  // column
67  for (size_t k = 0; k < 2; k++) {
68  if (k == 0 && w <= x.ncols()) {
69  i_start = 0;
70  j_start = w;
71  } else if (k == 1 && w <= x.nrows() && w != 0) {
72  i_start = w;
73  j_start = 0;
74  } else {
75  continue;
76  }
77  auto first = x.get(i_start, j_start);
78  // Loop along the diagonal
79  for (i = i_start + 1, j = j_start + 1;
80  i < x.nrows() && j < x.ncols(); i++, j++) {
81  is_toeplitz_ = and_tribool(
82  is_toeplitz_, visitor.apply(*sub(first, x.get(i, j))));
83  if (is_false(is_toeplitz_)) {
84  return;
85  }
86  }
87  }
88  }
89  }
90 
91  tribool apply(const MatrixExpr &s)
92  {
93  s.accept(*this);
94  return is_toeplitz_;
95  }
96 };
97 
98 tribool is_toeplitz(const MatrixExpr &m, const Assumptions *assumptions)
99 {
100  MatrixToeplitzVisitor visitor(assumptions);
101  return visitor.apply(m);
102 }
103 
104 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
T max(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495