size.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/test_visitors.h>
3 #include <symengine/matrices/size.h>
4 
5 namespace SymEngine
6 {
7 
8 class MatrixSizeVisitor : public BaseVisitor<MatrixSizeVisitor>
9 {
10 private:
11  RCP<const Basic> nrows_;
12  RCP<const Basic> ncols_;
13 
14  void all_same_size(const vec_basic &vec)
15  {
16  vec[0]->accept(*this);
17  auto rows = nrows_;
18  auto cols = ncols_;
19  if (!rows.is_null() && !cols.is_null() && is_a<Integer>(*rows)
20  && is_a<Integer>(*cols)) {
21  return;
22  }
23  // Priority order of type of nrows and ncols:
24  // 1. integer
25  // 2. other expressions
26  // 3. nullptr (meaning unknown)
27  // Note that all elements must have known same size or indeterminate
28  // size diff because of canonicalization
29  for (size_t i = 1; i < vec.size(); i++) {
30  vec[i]->accept(*this);
31  if ((!nrows_.is_null() && is_a<Integer>(*nrows_))
32  || (rows.is_null() && !nrows_.is_null())) {
33  rows = nrows_;
34  }
35  if ((!ncols_.is_null() && is_a<Integer>(*ncols_))
36  || (cols.is_null() && !ncols_.is_null())) {
37  cols = ncols_;
38  }
39  if (!rows.is_null() && !cols.is_null() && is_a<Integer>(*rows)
40  && is_a<Integer>(*cols)) {
41  break;
42  }
43  }
44  nrows_ = rows;
45  ncols_ = cols;
46  }
47 
48 public:
50 
51  void bvisit(const Basic &x)
52  {
53  nrows_.reset();
54  ncols_.reset();
55  }
56 
57  void bvisit(const IdentityMatrix &x)
58  {
59  nrows_ = x.size();
60  ncols_ = x.size();
61  }
62 
63  void bvisit(const ZeroMatrix &x)
64  {
65  nrows_ = x.nrows();
66  ncols_ = x.ncols();
67  }
68 
69  void bvisit(const MatrixSymbol &x)
70  {
71  nrows_.reset();
72  ncols_.reset();
73  }
74 
75  void bvisit(const DiagonalMatrix &x)
76  {
77  nrows_ = integer(x.get_container().size());
78  ncols_ = nrows_;
79  }
80 
81  void bvisit(const ImmutableDenseMatrix &x)
82  {
83  nrows_ = integer(x.nrows());
84  ncols_ = integer(x.ncols());
85  }
86 
87  void bvisit(const MatrixAdd &x)
88  {
89  auto vec = x.get_terms();
90  all_same_size(vec);
91  }
92 
93  void bvisit(const HadamardProduct &x)
94  {
95  auto vec = x.get_factors();
96  all_same_size(vec);
97  }
98 
99  void bvisit(const MatrixMul &x)
100  {
101  auto vec = x.get_factors();
102  vec[0]->accept(*this);
103  auto row = nrows_;
104  vec.back()->accept(*this);
105  nrows_ = row;
106  }
107 
108  std::pair<RCP<const Basic>, RCP<const Basic>> apply(const MatrixExpr &s)
109  {
110  s.accept(*this);
111  return std::make_pair(nrows_, ncols_);
112  }
113 };
114 
115 std::pair<RCP<const Basic>, RCP<const Basic>> size(const MatrixExpr &m)
116 {
117  MatrixSizeVisitor visitor;
118  return visitor.apply(m);
119 }
120 
121 } // namespace SymEngine
T back(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T make_pair(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
T size(T... args)