Loading...
Searching...
No Matches
matrix_mul.cpp
1#include <symengine/mul.h>
2#include <symengine/add.h>
4#include <symengine/matrices/matrix_mul.h>
5#include <symengine/matrices/zero_matrix.h>
6#include <symengine/matrices/identity_matrix.h>
7#include <symengine/matrices/diagonal_matrix.h>
8#include <symengine/matrices/immutable_dense_matrix.h>
9
10namespace SymEngine
11{
12
13hash_t MatrixMul::__hash__() const
14{
15 hash_t seed = SYMENGINE_MATRIXMUL;
16 hash_combine<Basic>(seed, *scalar_);
17 for (const auto &a : factors_) {
18 hash_combine<Basic>(seed, *a);
19 }
20 return seed;
21}
22
23bool MatrixMul::__eq__(const Basic &o) const
24{
25 if (is_a<MatrixMul>(o)) {
26 const MatrixMul &other = down_cast<const MatrixMul &>(o);
27 if (!eq(*scalar_, *other.scalar_)) {
28 return false;
29 }
30 return unified_eq(factors_, other.factors_);
31 }
32 return false;
33}
34
35int MatrixMul::compare(const Basic &o) const
36{
37 SYMENGINE_ASSERT(is_a<MatrixMul>(o));
38 const MatrixMul &other = down_cast<const MatrixMul &>(o);
39 int cmp_scalar = scalar_->compare(*other.scalar_);
40 if (cmp_scalar != 0) {
41 return cmp_scalar;
42 }
43 return unified_compare(factors_, other.factors_);
44}
45
46bool MatrixMul::is_canonical(const RCP<const Basic> &scalar,
47 const vec_basic &factors) const
48{
49 if (factors.size() == 0 || (factors.size() == 1 && eq(*scalar, *one))) {
50 return false;
51 }
52 size_t num_diag = 0;
53 size_t num_dense = 0;
54 for (auto factor : factors) {
55 if (is_a<ZeroMatrix>(*factor) || is_a<IdentityMatrix>(*factor)
56 || is_a<MatrixMul>(*factor)) {
57 return false;
58 } else if (is_a<DiagonalMatrix>(*factor)) {
59 num_diag++;
60 } else if (is_a<ImmutableDenseMatrix>(*factor)) {
61 num_dense++;
62 } else {
63 if (num_diag > 1 || num_dense > 1) {
64 return false;
65 }
66 if (num_diag == 1 && num_dense == 1) {
67 return false;
68 }
69 num_diag = 0;
70 num_dense = 0;
71 }
72 }
73 if (num_diag > 1 || num_dense > 1) {
74 return false;
75 }
76 if (num_diag == 1 && num_dense == 1) {
77 return false;
78 }
79 return true;
80}
81
82RCP<const DiagonalMatrix> mul_diag_diag(const DiagonalMatrix &A,
83 const DiagonalMatrix &B)
84{
85 auto Avec = A.get_container();
86 auto Bvec = B.get_container();
87 vec_basic product(Avec.size());
88
89 for (size_t i = 0; i < Avec.size(); i++) {
90 product[i] = mul(Avec[i], Bvec[i]);
91 }
92
93 return make_rcp<const DiagonalMatrix>(product);
94}
95
96RCP<const ImmutableDenseMatrix> mul_dense_dense(const ImmutableDenseMatrix &A,
97 const ImmutableDenseMatrix &B)
98{
99 size_t nrows = A.nrows();
100 size_t ncols = B.ncols();
101 auto Avec = A.get_values();
102 auto Bvec = B.get_values();
103 vec_basic product(nrows * ncols);
104
105 for (size_t i = 0; i < nrows; i++) {
106 for (size_t j = 0; j < ncols; j++) {
107 product[i * ncols + j] = zero;
108 for (size_t k = 0; k < A.ncols(); k++) {
109 product[i * ncols + j]
110 = add(product[i * ncols + j],
111 mul(Avec[i * A.ncols() + k], Bvec[k * ncols + j]));
112 }
113 }
114 }
115 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
116}
117
118RCP<const ImmutableDenseMatrix> mul_diag_dense(const DiagonalMatrix &A,
119 const ImmutableDenseMatrix &B)
120{
121 size_t nrows = B.nrows();
122 size_t ncols = B.ncols();
123
124 vec_basic product(B.get_values());
125
126 for (size_t i = 0; i < nrows; i++) {
127 auto value = A.get_container()[i];
128 for (size_t j = 0; j < ncols; j++) {
129 product[i * ncols + j] = mul(product[i * ncols + j], value);
130 }
131 }
132 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
133}
134
135RCP<const ImmutableDenseMatrix> mul_dense_diag(const ImmutableDenseMatrix &A,
136 const DiagonalMatrix &B)
137{
138 size_t nrows = A.nrows();
139 size_t ncols = A.ncols();
140
141 vec_basic product(A.get_values());
142
143 for (size_t j = 0; j < ncols; j++) {
144 auto value = B.get_container()[j];
145 for (size_t i = 0; i < nrows; i++) {
146 product[i * ncols + j] = mul(product[i * ncols + j], value);
147 }
148 }
149 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
150}
151
152void check_matching_mul_sizes(const vec_basic &vec)
153{
154 auto first_size = size(down_cast<const MatrixExpr &>(*vec[0]));
155 for (size_t i = 1; i < vec.size(); i++) {
156 auto second_size = size(down_cast<const MatrixExpr &>(*vec[i]));
157 if (first_size.second.is_null() || second_size.first.is_null()) {
158 first_size = second_size;
159 continue;
160 }
161 auto diff = sub(first_size.second, second_size.first);
162 tribool match = is_zero(*diff);
163 if (is_false(match)) {
164 throw DomainError("Matrix dimension mismatch");
165 }
166 first_size = second_size;
167 }
168}
169
170RCP<const MatrixExpr> matrix_mul(const vec_basic &factors)
171{
172 if (factors.size() == 0) {
173 throw DomainError("Empty product of matrices");
174 }
175 if (factors.size() == 1) {
176 return rcp_static_cast<const MatrixExpr>(factors[0]);
177 }
178
179 // extract nested MatrixMul and scalars
180 vec_basic expanded;
181 RCP<const Basic> scalar = one;
182 for (auto &factor : factors) {
183 if (is_a<const MatrixMul>(*factor)) {
184 auto container
185 = down_cast<const MatrixMul &>(*factor).get_factors();
186 scalar = mul(scalar,
187 down_cast<const MatrixMul &>(*factor).get_scalar());
188 expanded.insert(expanded.end(), container.begin(), container.end());
189 } else if (is_a_MatrixExpr(*factor)) {
190 expanded.push_back(factor);
191 } else {
192 scalar = mul(scalar, factor);
193 }
194 }
195
196 check_matching_mul_sizes(expanded);
197
198 // Handle ZeroMatrix first
199 for (auto &factor : factors) {
200 if (is_a<ZeroMatrix>(*factor)) {
201 return rcp_static_cast<const MatrixExpr>(factor);
202 }
203 }
204
205 vec_basic keep;
206 RCP<const DiagonalMatrix> diag;
207 RCP<const ImmutableDenseMatrix> dense;
208 RCP<const IdentityMatrix> ident;
209 for (auto &factor : expanded) {
210 if (is_a<IdentityMatrix>(*factor)) {
211 ident = rcp_static_cast<const IdentityMatrix>(factor);
212 } else if (is_a<DiagonalMatrix>(*factor)) {
213 if (!diag.is_null()) {
214 diag = mul_diag_diag(
215 *diag, down_cast<const DiagonalMatrix &>(*factor));
216 } else if (!dense.is_null()) {
217 dense = mul_dense_diag(
218 *dense, down_cast<const DiagonalMatrix &>(*factor));
219 } else {
220 diag = rcp_static_cast<const DiagonalMatrix>(factor);
221 }
222 } else if (is_a<ImmutableDenseMatrix>(*factor)) {
223 if (!dense.is_null()) {
224 dense = mul_dense_dense(
225 *dense, down_cast<const ImmutableDenseMatrix &>(*factor));
226 } else if (!diag.is_null()) {
227 dense = mul_diag_dense(
228 *diag, down_cast<const ImmutableDenseMatrix &>(*factor));
229 diag.reset();
230 } else {
231 dense = rcp_static_cast<const ImmutableDenseMatrix>(factor);
232 }
233 } else {
234 if (!diag.is_null()) {
235 keep.push_back(diag);
236 diag.reset();
237 } else if (!dense.is_null()) {
238 keep.push_back(dense);
239 dense.reset();
240 }
241 keep.push_back(factor);
242 }
243 }
244 if (!diag.is_null()) {
245 keep.push_back(diag);
246 } else if (!dense.is_null()) {
247 keep.push_back(dense);
248 }
249 if (keep.size() == 1 && eq(*scalar, *one)) {
250 return rcp_static_cast<const MatrixExpr>(keep[0]);
251 }
252 if (keep.size() == 0 && !ident.is_null()) {
253 return ident;
254 }
255 return make_rcp<const MatrixMul>(scalar, keep);
256}
257
258} // namespace SymEngine
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
Definition: basic.h:97
bool __eq__(const Basic &o) const override
Test equality.
Definition: matrix_mul.cpp:23
hash_t __hash__() const override
Definition: matrix_mul.cpp:13
int compare(const Basic &o) const override
Definition: matrix_mul.cpp:35
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
int factor(const Ptr< RCP< const Integer > > &f, const Integer &n, double B1)
Definition: ntheory.cpp:370
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T size(T... args)