is_square.cpp
1 #include <symengine/basic.h>
2 #include <symengine/assumptions.h>
3 #include <symengine/visitor.h>
4 #include <symengine/test_visitors.h>
5 
6 namespace SymEngine
7 {
8 
9 class MatrixSquareVisitor : public BaseVisitor<MatrixSquareVisitor>
10 {
11 private:
12  tribool is_square_;
13  const Assumptions *assumptions_;
14 
15  void check_vector(const vec_basic &vec)
16  {
17  for (auto &elt : vec) {
18  elt->accept(*this);
19  if (not is_indeterminate(is_square_)) {
20  return;
21  }
22  }
23  }
24 
25 public:
26  MatrixSquareVisitor(const Assumptions *assumptions)
27  : assumptions_(assumptions)
28  {
29  }
30 
31  void bvisit(const Basic &x){};
32  void bvisit(const MatrixExpr &x)
33  {
34  is_square_ = tribool::indeterminate;
35  }
36 
37  void bvisit(const IdentityMatrix &x)
38  {
39  is_square_ = tribool::tritrue;
40  }
41 
42  void bvisit(const ZeroMatrix &x)
43  {
44  auto diff = sub(x.nrows(), x.ncols());
45  is_square_ = is_zero(*diff, assumptions_);
46  }
47 
48  void bvisit(const DiagonalMatrix &x)
49  {
50  is_square_ = tribool::tritrue;
51  }
52 
53  void bvisit(const ImmutableDenseMatrix &x)
54  {
55  if (x.nrows() == x.ncols()) {
56  is_square_ = tribool::tritrue;
57  } else {
58  is_square_ = tribool::trifalse;
59  }
60  }
61 
62  void bvisit(const MatrixAdd &x)
63  {
64  check_vector(x.get_terms());
65  }
66 
67  void bvisit(const HadamardProduct &x)
68  {
69  check_vector(x.get_factors());
70  }
71 
72  tribool apply(const MatrixExpr &s)
73  {
74  s.accept(*this);
75  return is_square_;
76  }
77 };
78 
79 tribool is_square(const MatrixExpr &m, const Assumptions *assumptions)
80 {
81  MatrixSquareVisitor visitor(assumptions);
82  return visitor.apply(m);
83 }
84 
85 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Substracts b from a.
Definition: add.cpp:495