2 #include <symengine/matrices/matrix_add.h>
3 #include <symengine/matrices/zero_matrix.h>
4 #include <symengine/matrices/diagonal_matrix.h>
5 #include <symengine/matrices/immutable_dense_matrix.h>
12 hash_t seed = SYMENGINE_MATRIXADD;
13 for (
const auto &a : terms_) {
14 hash_combine<Basic>(seed, *a);
21 if (is_a<MatrixAdd>(o)) {
22 const MatrixAdd &other = down_cast<const MatrixAdd &>(o);
23 return unified_eq(terms_, other.terms_);
30 SYMENGINE_ASSERT(is_a<MatrixAdd>(o));
31 const MatrixAdd &other = down_cast<const MatrixAdd &>(o);
35 bool MatrixAdd::is_canonical(
const vec_basic &terms)
const
37 if (terms.
size() < 2) {
42 for (
auto term : terms) {
43 if (is_a<ZeroMatrix>(*term) || is_a<MatrixAdd>(*term)) {
45 }
else if (is_a<DiagonalMatrix>(*term)) {
47 }
else if (is_a<ImmutableDenseMatrix>(*term)) {
51 if (num_diag > 1 || num_dense > 1) {
54 if (num_diag == 1 && num_dense == 1) {
60 void check_matching_sizes(
const vec_basic &vec)
62 for (
size_t i = 0; i < vec.size() - 1; i++) {
63 auto first_size = size(down_cast<const MatrixExpr &>(*vec[i]));
64 if (first_size.first.is_null()) {
67 for (
size_t j = 1; j < vec.size(); j++) {
68 auto second_size = size(down_cast<const MatrixExpr &>(*vec[j]));
69 if (second_size.first.is_null()) {
72 auto rowdiff =
sub(first_size.first, second_size.first);
73 tribool rowmatch = is_zero(*rowdiff);
74 if (is_false(rowmatch)) {
75 throw DomainError(
"Matrix dimension mismatch");
77 auto coldiff =
sub(first_size.second, second_size.second);
78 tribool colmatch = is_zero(*coldiff);
79 if (is_false(colmatch)) {
80 throw DomainError(
"Matrix dimension mismatch");
86 RCP<const MatrixExpr> matrix_add(
const vec_basic &terms)
88 if (terms.size() == 0) {
89 throw DomainError(
"Empty sum of matrices");
91 if (terms.size() == 1) {
92 return rcp_static_cast<const MatrixExpr>(terms[0]);
96 for (
auto &term : terms) {
97 if (is_a<const MatrixAdd>(*term)) {
98 auto container = down_cast<const MatrixAdd &>(*term).get_terms();
99 expanded.
insert(expanded.end(), container.begin(), container.end());
101 expanded.push_back(term);
104 check_matching_sizes(expanded);
106 RCP<const DiagonalMatrix> diag;
107 RCP<const ImmutableDenseMatrix> dense;
108 RCP<const ZeroMatrix> zero;
109 for (
auto &term : expanded) {
110 if (is_a<ZeroMatrix>(*term)) {
111 zero = rcp_static_cast<const ZeroMatrix>(term);
112 }
else if (is_a<DiagonalMatrix>(*term)) {
113 if (diag.is_null()) {
114 diag = rcp_static_cast<const DiagonalMatrix>(term);
117 for (
size_t i = 0; i < diag->get_container().size(); i++) {
119 add(diag->get_container()[i],
120 down_cast<const DiagonalMatrix &>(*term)
121 .get_container()[i]));
123 diag = make_rcp<const DiagonalMatrix>(container);
125 }
else if (is_a<ImmutableDenseMatrix>(*term)) {
126 if (dense.is_null()) {
127 dense = rcp_static_cast<const ImmutableDenseMatrix>(term);
129 const vec_basic &vec1
130 = down_cast<const ImmutableDenseMatrix &>(*term)
132 const vec_basic &vec2 = dense->get_values();
133 vec_basic sum(vec1.size());
134 for (
size_t i = 0; i < vec1.size(); i++) {
135 sum[i] =
add(vec1[i], vec2[i]);
137 dense = make_rcp<const ImmutableDenseMatrix>(
138 dense->nrows(), dense->ncols(), sum);
141 keep.push_back(term);
144 if (!diag.is_null()) {
145 if (!dense.is_null()) {
147 auto vec = dense->get_values();
149 for (
size_t i = 0; i < dense->nrows(); i++) {
150 for (
size_t j = 0; j < dense->ncols(); j++) {
154 sum.push_back(dense->get(i, j));
158 dense = make_rcp<const ImmutableDenseMatrix>(dense->nrows(),
159 dense->ncols(), sum);
161 keep.push_back(diag);
164 if (!dense.is_null()) {
165 keep.push_back(dense);
167 if (keep.size() == 1) {
168 return rcp_static_cast<const MatrixExpr>(keep[0]);
170 if (keep.size() == 0 && !zero.is_null()) {
173 return make_rcp<const MatrixAdd>(keep);
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
bool __eq__(const Basic &o) const override
Test equality.
int compare(const Basic &o) const override
hash_t __hash__() const override
Main namespace for SymEngine package.
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
int unified_compare(const T &a, const T &b)