Loading...
Searching...
No Matches
is_square.cpp
1#include <symengine/basic.h>
2#include <symengine/assumptions.h>
3#include <symengine/visitor.h>
4#include <symengine/test_visitors.h>
5
6namespace SymEngine
7{
8
9class MatrixSquareVisitor : public BaseVisitor<MatrixSquareVisitor>
10{
11private:
12 tribool is_square_;
13 const Assumptions *assumptions_;
14
15 void check_vector(const vec_basic &vec)
16 {
17 for (auto &elt : vec) {
18 elt->accept(*this);
19 if (not is_indeterminate(is_square_)) {
20 return;
21 }
22 }
23 }
24
25public:
27 : assumptions_(assumptions)
28 {
29 }
30
31 void bvisit(const Basic &x){};
32 void bvisit(const MatrixExpr &x)
33 {
34 is_square_ = tribool::indeterminate;
35 }
36
37 void bvisit(const IdentityMatrix &x)
38 {
39 is_square_ = tribool::tritrue;
40 }
41
42 void bvisit(const ZeroMatrix &x)
43 {
44 auto diff = sub(x.nrows(), x.ncols());
45 is_square_ = is_zero(*diff, assumptions_);
46 }
47
48 void bvisit(const DiagonalMatrix &x)
49 {
50 is_square_ = tribool::tritrue;
51 }
52
53 void bvisit(const ImmutableDenseMatrix &x)
54 {
55 if (x.nrows() == x.ncols()) {
56 is_square_ = tribool::tritrue;
57 } else {
58 is_square_ = tribool::trifalse;
59 }
60 }
61
62 void bvisit(const MatrixAdd &x)
63 {
64 check_vector(x.get_terms());
65 }
66
67 void bvisit(const HadamardProduct &x)
68 {
69 check_vector(x.get_factors());
70 }
71
72 tribool apply(const MatrixExpr &s)
73 {
74 s.accept(*this);
75 return is_square_;
76 }
77};
78
79tribool is_square(const MatrixExpr &m, const Assumptions *assumptions)
80{
82 return visitor.apply(m);
83}
84
85} // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition basic.h:97
Main namespace for SymEngine package.
Definition add.cpp:19
void hash_combine(hash_t &seed, const T &v)
Definition basic-inl.h:95
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition add.cpp:495