matrix_mul.cpp
1 #include <symengine/mul.h>
2 #include <symengine/add.h>
3 #include <symengine/constants.h>
4 #include <symengine/matrices/matrix_mul.h>
5 #include <symengine/matrices/zero_matrix.h>
6 #include <symengine/matrices/identity_matrix.h>
7 #include <symengine/matrices/diagonal_matrix.h>
8 #include <symengine/matrices/immutable_dense_matrix.h>
9 
10 namespace SymEngine
11 {
12 
13 hash_t MatrixMul::__hash__() const
14 {
15  hash_t seed = SYMENGINE_MATRIXMUL;
16  hash_combine<Basic>(seed, *scalar_);
17  for (const auto &a : factors_) {
18  hash_combine<Basic>(seed, *a);
19  }
20  return seed;
21 }
22 
23 bool MatrixMul::__eq__(const Basic &o) const
24 {
25  if (is_a<MatrixMul>(o)) {
26  const MatrixMul &other = down_cast<const MatrixMul &>(o);
27  if (!eq(*scalar_, *other.scalar_)) {
28  return false;
29  }
30  return unified_eq(factors_, other.factors_);
31  }
32  return false;
33 }
34 
35 int MatrixMul::compare(const Basic &o) const
36 {
37  SYMENGINE_ASSERT(is_a<MatrixMul>(o));
38  const MatrixMul &other = down_cast<const MatrixMul &>(o);
39  int cmp_scalar = scalar_->compare(*other.scalar_);
40  if (cmp_scalar != 0) {
41  return cmp_scalar;
42  }
43  return unified_compare(factors_, other.factors_);
44 }
45 
46 bool MatrixMul::is_canonical(const RCP<const Basic> &scalar,
47  const vec_basic &factors) const
48 {
49  if (factors.size() == 0 || (factors.size() == 1 && eq(*scalar, *one))) {
50  return false;
51  }
52  size_t num_diag = 0;
53  size_t num_dense = 0;
54  for (auto factor : factors) {
55  if (is_a<ZeroMatrix>(*factor) || is_a<IdentityMatrix>(*factor)
56  || is_a<MatrixMul>(*factor)) {
57  return false;
58  } else if (is_a<DiagonalMatrix>(*factor)) {
59  num_diag++;
60  } else if (is_a<ImmutableDenseMatrix>(*factor)) {
61  num_dense++;
62  } else {
63  if (num_diag > 1 || num_dense > 1) {
64  return false;
65  }
66  if (num_diag == 1 && num_dense == 1) {
67  return false;
68  }
69  num_diag = 0;
70  num_dense = 0;
71  }
72  }
73  if (num_diag > 1 || num_dense > 1) {
74  return false;
75  }
76  if (num_diag == 1 && num_dense == 1) {
77  return false;
78  }
79  return true;
80 }
81 
82 RCP<const DiagonalMatrix> mul_diag_diag(const DiagonalMatrix &A,
83  const DiagonalMatrix &B)
84 {
85  auto Avec = A.get_container();
86  auto Bvec = B.get_container();
87  vec_basic product(Avec.size());
88 
89  for (size_t i = 0; i < Avec.size(); i++) {
90  product[i] = mul(Avec[i], Bvec[i]);
91  }
92 
93  return make_rcp<const DiagonalMatrix>(product);
94 }
95 
96 RCP<const ImmutableDenseMatrix> mul_dense_dense(const ImmutableDenseMatrix &A,
97  const ImmutableDenseMatrix &B)
98 {
99  size_t nrows = A.nrows();
100  size_t ncols = B.ncols();
101  auto Avec = A.get_values();
102  auto Bvec = B.get_values();
103  vec_basic product(nrows * ncols);
104 
105  for (size_t i = 0; i < nrows; i++) {
106  for (size_t j = 0; j < ncols; j++) {
107  product[i * ncols + j] = zero;
108  for (size_t k = 0; k < A.ncols(); k++) {
109  product[i * ncols + j]
110  = add(product[i * ncols + j],
111  mul(Avec[i * A.ncols() + k], Bvec[k * ncols + j]));
112  }
113  }
114  }
115  return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
116 }
117 
118 RCP<const ImmutableDenseMatrix> mul_diag_dense(const DiagonalMatrix &A,
119  const ImmutableDenseMatrix &B)
120 {
121  size_t nrows = B.nrows();
122  size_t ncols = B.ncols();
123 
124  vec_basic product(B.get_values());
125 
126  for (size_t i = 0; i < nrows; i++) {
127  auto value = A.get_container()[i];
128  for (size_t j = 0; j < ncols; j++) {
129  product[i * ncols + j] = mul(product[i * ncols + j], value);
130  }
131  }
132  return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
133 }
134 
135 RCP<const ImmutableDenseMatrix> mul_dense_diag(const ImmutableDenseMatrix &A,
136  const DiagonalMatrix &B)
137 {
138  size_t nrows = A.nrows();
139  size_t ncols = A.ncols();
140 
141  vec_basic product(A.get_values());
142 
143  for (size_t j = 0; j < ncols; j++) {
144  auto value = B.get_container()[j];
145  for (size_t i = 0; i < nrows; i++) {
146  product[i * ncols + j] = mul(product[i * ncols + j], value);
147  }
148  }
149  return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
150 }
151 
152 void check_matching_mul_sizes(const vec_basic &vec)
153 {
154  auto first_size = size(down_cast<const MatrixExpr &>(*vec[0]));
155  for (size_t i = 1; i < vec.size(); i++) {
156  auto second_size = size(down_cast<const MatrixExpr &>(*vec[i]));
157  if (first_size.second.is_null() || second_size.first.is_null()) {
158  first_size = second_size;
159  continue;
160  }
161  auto diff = sub(first_size.second, second_size.first);
162  tribool match = is_zero(*diff);
163  if (is_false(match)) {
164  throw DomainError("Matrix dimension mismatch");
165  }
166  first_size = second_size;
167  }
168 }
169 
170 RCP<const MatrixExpr> matrix_mul(const vec_basic &factors)
171 {
172  if (factors.size() == 0) {
173  throw DomainError("Empty product of matrices");
174  }
175  if (factors.size() == 1) {
176  return rcp_static_cast<const MatrixExpr>(factors[0]);
177  }
178 
179  // extract nested MatrixMul and scalars
180  vec_basic expanded;
181  RCP<const Basic> scalar = one;
182  for (auto &factor : factors) {
183  if (is_a<const MatrixMul>(*factor)) {
184  auto container
185  = down_cast<const MatrixMul &>(*factor).get_factors();
186  scalar = mul(scalar,
187  down_cast<const MatrixMul &>(*factor).get_scalar());
188  expanded.insert(expanded.end(), container.begin(), container.end());
189  } else if (is_a_MatrixExpr(*factor)) {
190  expanded.push_back(factor);
191  } else {
192  scalar = mul(scalar, factor);
193  }
194  }
195 
196  check_matching_mul_sizes(expanded);
197 
198  // Handle ZeroMatrix first
199  for (auto &factor : factors) {
200  if (is_a<ZeroMatrix>(*factor)) {
201  return rcp_static_cast<const MatrixExpr>(factor);
202  }
203  }
204 
205  vec_basic keep;
206  RCP<const DiagonalMatrix> diag;
207  RCP<const ImmutableDenseMatrix> dense;
208  RCP<const IdentityMatrix> ident;
209  for (auto &factor : expanded) {
210  if (is_a<IdentityMatrix>(*factor)) {
211  ident = rcp_static_cast<const IdentityMatrix>(factor);
212  } else if (is_a<DiagonalMatrix>(*factor)) {
213  if (!diag.is_null()) {
214  diag = mul_diag_diag(
215  *diag, down_cast<const DiagonalMatrix &>(*factor));
216  } else if (!dense.is_null()) {
217  dense = mul_dense_diag(
218  *dense, down_cast<const DiagonalMatrix &>(*factor));
219  } else {
220  diag = rcp_static_cast<const DiagonalMatrix>(factor);
221  }
222  } else if (is_a<ImmutableDenseMatrix>(*factor)) {
223  if (!dense.is_null()) {
224  dense = mul_dense_dense(
225  *dense, down_cast<const ImmutableDenseMatrix &>(*factor));
226  } else if (!diag.is_null()) {
227  dense = mul_diag_dense(
228  *diag, down_cast<const ImmutableDenseMatrix &>(*factor));
229  diag.reset();
230  } else {
231  dense = rcp_static_cast<const ImmutableDenseMatrix>(factor);
232  }
233  } else {
234  if (!diag.is_null()) {
235  keep.push_back(diag);
236  diag.reset();
237  } else if (!dense.is_null()) {
238  keep.push_back(dense);
239  dense.reset();
240  }
241  keep.push_back(factor);
242  }
243  }
244  if (!diag.is_null()) {
245  keep.push_back(diag);
246  } else if (!dense.is_null()) {
247  keep.push_back(dense);
248  }
249  if (keep.size() == 1 && eq(*scalar, *one)) {
250  return rcp_static_cast<const MatrixExpr>(keep[0]);
251  }
252  if (keep.size() == 0 && !ident.is_null()) {
253  return ident;
254  }
255  return make_rcp<const MatrixMul>(scalar, keep);
256 }
257 
258 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
Definition: basic.h:97
bool __eq__(const Basic &o) const override
Test equality.
Definition: matrix_mul.cpp:23
hash_t __hash__() const override
Definition: matrix_mul.cpp:13
int compare(const Basic &o) const override
Definition: matrix_mul.cpp:35
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
Definition: ntheory.cpp:371
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T size(T... args)