4 #include <symengine/matrices/matrix_mul.h>
5 #include <symengine/matrices/zero_matrix.h>
6 #include <symengine/matrices/identity_matrix.h>
7 #include <symengine/matrices/diagonal_matrix.h>
8 #include <symengine/matrices/immutable_dense_matrix.h>
15 hash_t seed = SYMENGINE_MATRIXMUL;
16 hash_combine<Basic>(seed, *scalar_);
17 for (
const auto &a : factors_) {
18 hash_combine<Basic>(seed, *a);
25 if (is_a<MatrixMul>(o)) {
26 const MatrixMul &other = down_cast<const MatrixMul &>(o);
27 if (!
eq(*scalar_, *other.scalar_)) {
30 return unified_eq(factors_, other.factors_);
37 SYMENGINE_ASSERT(is_a<MatrixMul>(o));
38 const MatrixMul &other = down_cast<const MatrixMul &>(o);
39 int cmp_scalar = scalar_->compare(*other.scalar_);
40 if (cmp_scalar != 0) {
46 bool MatrixMul::is_canonical(
const RCP<const Basic> &scalar,
49 if (factors.
size() == 0 || (factors.
size() == 1 &&
eq(*scalar, *one))) {
54 for (
auto factor : factors) {
55 if (is_a<ZeroMatrix>(*
factor) || is_a<IdentityMatrix>(*
factor)
56 || is_a<MatrixMul>(*
factor)) {
58 }
else if (is_a<DiagonalMatrix>(*
factor)) {
60 }
else if (is_a<ImmutableDenseMatrix>(*
factor)) {
63 if (num_diag > 1 || num_dense > 1) {
66 if (num_diag == 1 && num_dense == 1) {
73 if (num_diag > 1 || num_dense > 1) {
76 if (num_diag == 1 && num_dense == 1) {
82 RCP<const DiagonalMatrix> mul_diag_diag(
const DiagonalMatrix &A,
83 const DiagonalMatrix &B)
85 auto Avec = A.get_container();
86 auto Bvec = B.get_container();
87 vec_basic product(Avec.size());
89 for (
size_t i = 0; i < Avec.size(); i++) {
90 product[i] =
mul(Avec[i], Bvec[i]);
93 return make_rcp<const DiagonalMatrix>(product);
96 RCP<const ImmutableDenseMatrix> mul_dense_dense(
const ImmutableDenseMatrix &A,
97 const ImmutableDenseMatrix &B)
99 size_t nrows = A.nrows();
100 size_t ncols = B.ncols();
101 auto Avec = A.get_values();
102 auto Bvec = B.get_values();
103 vec_basic product(nrows * ncols);
105 for (
size_t i = 0; i < nrows; i++) {
106 for (
size_t j = 0; j < ncols; j++) {
107 product[i * ncols + j] = zero;
108 for (
size_t k = 0; k < A.ncols(); k++) {
109 product[i * ncols + j]
110 =
add(product[i * ncols + j],
111 mul(Avec[i * A.ncols() + k], Bvec[k * ncols + j]));
115 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
118 RCP<const ImmutableDenseMatrix> mul_diag_dense(
const DiagonalMatrix &A,
119 const ImmutableDenseMatrix &B)
121 size_t nrows = B.nrows();
122 size_t ncols = B.ncols();
124 vec_basic product(B.get_values());
126 for (
size_t i = 0; i < nrows; i++) {
127 auto value = A.get_container()[i];
128 for (
size_t j = 0; j < ncols; j++) {
129 product[i * ncols + j] =
mul(product[i * ncols + j], value);
132 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
135 RCP<const ImmutableDenseMatrix> mul_dense_diag(
const ImmutableDenseMatrix &A,
136 const DiagonalMatrix &B)
138 size_t nrows = A.nrows();
139 size_t ncols = A.ncols();
141 vec_basic product(A.get_values());
143 for (
size_t j = 0; j < ncols; j++) {
144 auto value = B.get_container()[j];
145 for (
size_t i = 0; i < nrows; i++) {
146 product[i * ncols + j] =
mul(product[i * ncols + j], value);
149 return make_rcp<const ImmutableDenseMatrix>(nrows, ncols, product);
152 void check_matching_mul_sizes(
const vec_basic &vec)
154 auto first_size = size(down_cast<const MatrixExpr &>(*vec[0]));
155 for (
size_t i = 1; i < vec.size(); i++) {
156 auto second_size = size(down_cast<const MatrixExpr &>(*vec[i]));
157 if (first_size.second.is_null() || second_size.first.is_null()) {
158 first_size = second_size;
161 auto diff =
sub(first_size.second, second_size.first);
162 tribool match = is_zero(*diff);
163 if (is_false(match)) {
164 throw DomainError(
"Matrix dimension mismatch");
166 first_size = second_size;
170 RCP<const MatrixExpr> matrix_mul(
const vec_basic &factors)
172 if (factors.size() == 0) {
173 throw DomainError(
"Empty product of matrices");
175 if (factors.size() == 1) {
176 return rcp_static_cast<const MatrixExpr>(factors[0]);
181 RCP<const Basic> scalar = one;
182 for (
auto &
factor : factors) {
183 if (is_a<const MatrixMul>(*
factor)) {
185 = down_cast<const MatrixMul &>(*factor).get_factors();
187 down_cast<const MatrixMul &>(*factor).get_scalar());
188 expanded.insert(expanded.end(), container.begin(), container.end());
189 }
else if (is_a_MatrixExpr(*
factor)) {
190 expanded.push_back(
factor);
196 check_matching_mul_sizes(expanded);
199 for (
auto &
factor : factors) {
200 if (is_a<ZeroMatrix>(*
factor)) {
201 return rcp_static_cast<const MatrixExpr>(
factor);
206 RCP<const DiagonalMatrix> diag;
207 RCP<const ImmutableDenseMatrix> dense;
208 RCP<const IdentityMatrix> ident;
209 for (
auto &
factor : expanded) {
210 if (is_a<IdentityMatrix>(*
factor)) {
211 ident = rcp_static_cast<const IdentityMatrix>(
factor);
212 }
else if (is_a<DiagonalMatrix>(*
factor)) {
213 if (!diag.is_null()) {
214 diag = mul_diag_diag(
215 *diag, down_cast<const DiagonalMatrix &>(*
factor));
216 }
else if (!dense.is_null()) {
217 dense = mul_dense_diag(
218 *dense, down_cast<const DiagonalMatrix &>(*
factor));
220 diag = rcp_static_cast<const DiagonalMatrix>(
factor);
222 }
else if (is_a<ImmutableDenseMatrix>(*
factor)) {
223 if (!dense.is_null()) {
224 dense = mul_dense_dense(
225 *dense, down_cast<const ImmutableDenseMatrix &>(*
factor));
226 }
else if (!diag.is_null()) {
227 dense = mul_diag_dense(
228 *diag, down_cast<const ImmutableDenseMatrix &>(*
factor));
231 dense = rcp_static_cast<const ImmutableDenseMatrix>(
factor);
234 if (!diag.is_null()) {
235 keep.push_back(diag);
237 }
else if (!dense.is_null()) {
238 keep.push_back(dense);
244 if (!diag.is_null()) {
245 keep.push_back(diag);
246 }
else if (!dense.is_null()) {
247 keep.push_back(dense);
249 if (keep.size() == 1 &&
eq(*scalar, *one)) {
250 return rcp_static_cast<const MatrixExpr>(keep[0]);
252 if (keep.size() == 0 && !ident.is_null()) {
255 return make_rcp<const MatrixMul>(scalar, keep);
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
bool __eq__(const Basic &o) const override
Test equality.
hash_t __hash__() const override
int compare(const Basic &o) const override
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
int unified_compare(const T &a, const T &b)