Loading...
Searching...
No Matches
is_toeplitz.cpp
1#include <symengine/basic.h>
2#include <symengine/assumptions.h>
3#include <symengine/visitor.h>
4#include <symengine/test_visitors.h>
5
6namespace SymEngine
7{
8
9class MatrixToeplitzVisitor : public BaseVisitor<MatrixToeplitzVisitor>
10{
11private:
12 tribool is_toeplitz_;
13 const Assumptions *assumptions_;
14
15public:
16 MatrixToeplitzVisitor(const Assumptions *assumptions)
17 : assumptions_(assumptions)
18 {
19 }
20
21 void bvisit(const Basic &x){};
22 void bvisit(const MatrixExpr &x)
23 {
24 is_toeplitz_ = tribool::indeterminate;
25 }
26
27 void bvisit(const IdentityMatrix &x)
28 {
29 is_toeplitz_ = tribool::tritrue;
30 }
31
32 void bvisit(const ZeroMatrix &x)
33 {
34 is_toeplitz_ = tribool::tritrue;
35 }
36
37 void bvisit(const DiagonalMatrix &x)
38 {
39 tribool current = tribool::tritrue;
40 auto vec = x.get_container();
41 if (vec.size() == 1) {
42 is_toeplitz_ = tribool::tritrue;
43 return;
44 }
45 auto first = vec[0];
46 for (auto it = vec.begin() + 1; it != vec.end(); ++it) {
47 auto diff = sub(first, *it);
48 tribool next = is_zero(*diff, assumptions_);
49 if (is_false(next)) {
50 is_toeplitz_ = next;
51 return;
52 }
53 current = andwk_tribool(current, next);
54 }
55 is_toeplitz_ = current;
56 }
57
58 void bvisit(const ImmutableDenseMatrix &x)
59 {
60 size_t i_start, j_start, i, j;
61 ZeroVisitor visitor(assumptions_);
62 is_toeplitz_ = tribool::tritrue;
63 // Loop over all diagonals
64 for (size_t w = 0; w < std::max(x.nrows(), x.ncols()) - 1; w++) {
65 // Loop over diagonals starting from the first row and the first
66 // column
67 for (size_t k = 0; k < 2; k++) {
68 if (k == 0 && w <= x.ncols()) {
69 i_start = 0;
70 j_start = w;
71 } else if (k == 1 && w <= x.nrows() && w != 0) {
72 i_start = w;
73 j_start = 0;
74 } else {
75 continue;
76 }
77 auto first = x.get(i_start, j_start);
78 // Loop along the diagonal
79 for (i = i_start + 1, j = j_start + 1;
80 i < x.nrows() && j < x.ncols(); i++, j++) {
81 is_toeplitz_ = and_tribool(
82 is_toeplitz_, visitor.apply(*sub(first, x.get(i, j))));
83 if (is_false(is_toeplitz_)) {
84 return;
85 }
86 }
87 }
88 }
89 }
90
91 tribool apply(const MatrixExpr &s)
92 {
93 s.accept(*this);
94 return is_toeplitz_;
95 }
96};
97
98tribool is_toeplitz(const MatrixExpr &m, const Assumptions *assumptions)
99{
100 MatrixToeplitzVisitor visitor(assumptions);
101 return visitor.apply(m);
102}
103
104} // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
T max(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495