Loading...
Searching...
No Matches
size.cpp
1#include <symengine/visitor.h>
2#include <symengine/test_visitors.h>
3#include <symengine/matrices/size.h>
4
5namespace SymEngine
6{
7
8class MatrixSizeVisitor : public BaseVisitor<MatrixSizeVisitor>
9{
10private:
11 RCP<const Basic> nrows_;
12 RCP<const Basic> ncols_;
13
14 void all_same_size(const vec_basic &vec)
15 {
16 vec[0]->accept(*this);
17 auto rows = nrows_;
18 auto cols = ncols_;
19 if (!rows.is_null() && !cols.is_null() && is_a<Integer>(*rows)
20 && is_a<Integer>(*cols)) {
21 return;
22 }
23 // Priority order of type of nrows and ncols:
24 // 1. integer
25 // 2. other expressions
26 // 3. nullptr (meaning unknown)
27 // Note that all elements must have known same size or indeterminate
28 // size diff because of canonicalization
29 for (size_t i = 1; i < vec.size(); i++) {
30 vec[i]->accept(*this);
31 if ((!nrows_.is_null() && is_a<Integer>(*nrows_))
32 || (rows.is_null() && !nrows_.is_null())) {
33 rows = nrows_;
34 }
35 if ((!ncols_.is_null() && is_a<Integer>(*ncols_))
36 || (cols.is_null() && !ncols_.is_null())) {
37 cols = ncols_;
38 }
39 if (!rows.is_null() && !cols.is_null() && is_a<Integer>(*rows)
40 && is_a<Integer>(*cols)) {
41 break;
42 }
43 }
44 nrows_ = rows;
45 ncols_ = cols;
46 }
47
48public:
50
51 void bvisit(const Basic &x)
52 {
53 nrows_.reset();
54 ncols_.reset();
55 }
56
57 void bvisit(const IdentityMatrix &x)
58 {
59 nrows_ = x.size();
60 ncols_ = x.size();
61 }
62
63 void bvisit(const ZeroMatrix &x)
64 {
65 nrows_ = x.nrows();
66 ncols_ = x.ncols();
67 }
68
69 void bvisit(const MatrixSymbol &x)
70 {
71 nrows_.reset();
72 ncols_.reset();
73 }
74
75 void bvisit(const DiagonalMatrix &x)
76 {
77 nrows_ = integer(x.get_container().size());
78 ncols_ = nrows_;
79 }
80
81 void bvisit(const ImmutableDenseMatrix &x)
82 {
83 nrows_ = integer(x.nrows());
84 ncols_ = integer(x.ncols());
85 }
86
87 void bvisit(const MatrixAdd &x)
88 {
89 auto vec = x.get_terms();
90 all_same_size(vec);
91 }
92
93 void bvisit(const HadamardProduct &x)
94 {
95 auto vec = x.get_factors();
96 all_same_size(vec);
97 }
98
99 void bvisit(const MatrixMul &x)
100 {
101 auto vec = x.get_factors();
102 vec[0]->accept(*this);
103 auto row = nrows_;
104 vec.back()->accept(*this);
105 nrows_ = row;
106 }
107
108 std::pair<RCP<const Basic>, RCP<const Basic>> apply(const MatrixExpr &s)
109 {
110 s.accept(*this);
111 return std::make_pair(nrows_, ncols_);
112 }
113};
114
115std::pair<RCP<const Basic>, RCP<const Basic>> size(const MatrixExpr &m)
116{
117 MatrixSizeVisitor visitor;
118 return visitor.apply(m);
119}
120
121} // namespace SymEngine
T back(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T make_pair(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
std::enable_if<std::is_integral<T>::value, RCP<const Integer>>::type integer(T i)
Definition: integer.h:197
T size(T... args)