matrix_add.cpp
1 #include <symengine/add.h>
2 #include <symengine/matrices/matrix_add.h>
3 #include <symengine/matrices/zero_matrix.h>
4 #include <symengine/matrices/diagonal_matrix.h>
5 #include <symengine/matrices/immutable_dense_matrix.h>
6 
7 namespace SymEngine
8 {
9 
10 hash_t MatrixAdd::__hash__() const
11 {
12  hash_t seed = SYMENGINE_MATRIXADD;
13  for (const auto &a : terms_) {
14  hash_combine<Basic>(seed, *a);
15  }
16  return seed;
17 }
18 
19 bool MatrixAdd::__eq__(const Basic &o) const
20 {
21  if (is_a<MatrixAdd>(o)) {
22  const MatrixAdd &other = down_cast<const MatrixAdd &>(o);
23  return unified_eq(terms_, other.terms_);
24  }
25  return false;
26 }
27 
28 int MatrixAdd::compare(const Basic &o) const
29 {
30  SYMENGINE_ASSERT(is_a<MatrixAdd>(o));
31  const MatrixAdd &other = down_cast<const MatrixAdd &>(o);
32  return unified_compare(terms_, other.terms_);
33 }
34 
35 bool MatrixAdd::is_canonical(const vec_basic &terms) const
36 {
37  if (terms.size() < 2) {
38  return false;
39  }
40  size_t num_diag = 0;
41  size_t num_dense = 0;
42  for (auto term : terms) {
43  if (is_a<ZeroMatrix>(*term) || is_a<MatrixAdd>(*term)) {
44  return false;
45  } else if (is_a<DiagonalMatrix>(*term)) {
46  num_diag++;
47  } else if (is_a<ImmutableDenseMatrix>(*term)) {
48  num_dense++;
49  }
50  }
51  if (num_diag > 1 || num_dense > 1) {
52  return false;
53  }
54  if (num_diag == 1 && num_dense == 1) {
55  return false;
56  }
57  return true;
58 }
59 
60 void check_matching_sizes(const vec_basic &vec)
61 {
62  for (size_t i = 0; i < vec.size() - 1; i++) {
63  auto first_size = size(down_cast<const MatrixExpr &>(*vec[i]));
64  if (first_size.first.is_null()) {
65  continue;
66  }
67  for (size_t j = 1; j < vec.size(); j++) {
68  auto second_size = size(down_cast<const MatrixExpr &>(*vec[j]));
69  if (second_size.first.is_null()) {
70  continue;
71  }
72  auto rowdiff = sub(first_size.first, second_size.first);
73  tribool rowmatch = is_zero(*rowdiff);
74  if (is_false(rowmatch)) {
75  throw DomainError("Matrix dimension mismatch");
76  }
77  auto coldiff = sub(first_size.second, second_size.second);
78  tribool colmatch = is_zero(*coldiff);
79  if (is_false(colmatch)) {
80  throw DomainError("Matrix dimension mismatch");
81  }
82  }
83  }
84 }
85 
86 RCP<const MatrixExpr> matrix_add(const vec_basic &terms)
87 {
88  if (terms.size() == 0) {
89  throw DomainError("Empty sum of matrices");
90  }
91  if (terms.size() == 1) {
92  return rcp_static_cast<const MatrixExpr>(terms[0]);
93  }
94  // extract nested MatrixAdd
95  vec_basic expanded;
96  for (auto &term : terms) {
97  if (is_a<const MatrixAdd>(*term)) {
98  auto container = down_cast<const MatrixAdd &>(*term).get_terms();
99  expanded.insert(expanded.end(), container.begin(), container.end());
100  } else {
101  expanded.push_back(term);
102  }
103  }
104  check_matching_sizes(expanded);
105  vec_basic keep;
106  RCP<const DiagonalMatrix> diag;
107  RCP<const ImmutableDenseMatrix> dense;
108  RCP<const ZeroMatrix> zero;
109  for (auto &term : expanded) {
110  if (is_a<ZeroMatrix>(*term)) {
111  zero = rcp_static_cast<const ZeroMatrix>(term);
112  } else if (is_a<DiagonalMatrix>(*term)) {
113  if (diag.is_null()) {
114  diag = rcp_static_cast<const DiagonalMatrix>(term);
115  } else {
116  vec_basic container;
117  for (size_t i = 0; i < diag->get_container().size(); i++) {
118  container.push_back(
119  add(diag->get_container()[i],
120  down_cast<const DiagonalMatrix &>(*term)
121  .get_container()[i]));
122  }
123  diag = make_rcp<const DiagonalMatrix>(container);
124  }
125  } else if (is_a<ImmutableDenseMatrix>(*term)) {
126  if (dense.is_null()) {
127  dense = rcp_static_cast<const ImmutableDenseMatrix>(term);
128  } else {
129  const vec_basic &vec1
130  = down_cast<const ImmutableDenseMatrix &>(*term)
131  .get_values();
132  const vec_basic &vec2 = dense->get_values();
133  vec_basic sum(vec1.size());
134  for (size_t i = 0; i < vec1.size(); i++) {
135  sum[i] = add(vec1[i], vec2[i]);
136  }
137  dense = make_rcp<const ImmutableDenseMatrix>(
138  dense->nrows(), dense->ncols(), sum);
139  }
140  } else {
141  keep.push_back(term);
142  }
143  }
144  if (!diag.is_null()) {
145  if (!dense.is_null()) {
146  // Add diagonal with dense matrix
147  auto vec = dense->get_values();
148  vec_basic sum;
149  for (size_t i = 0; i < dense->nrows(); i++) {
150  for (size_t j = 0; j < dense->ncols(); j++) {
151  if (i == j) {
152  sum.push_back(add(dense->get(i, j), diag->get(i)));
153  } else {
154  sum.push_back(dense->get(i, j));
155  }
156  }
157  }
158  dense = make_rcp<const ImmutableDenseMatrix>(dense->nrows(),
159  dense->ncols(), sum);
160  } else {
161  keep.push_back(diag);
162  }
163  }
164  if (!dense.is_null()) {
165  keep.push_back(dense);
166  }
167  if (keep.size() == 1) {
168  return rcp_static_cast<const MatrixExpr>(keep[0]);
169  }
170  if (keep.size() == 0 && !zero.is_null()) {
171  return zero;
172  }
173  return make_rcp<const MatrixAdd>(keep);
174 }
175 
176 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
Definition: basic.h:97
bool __eq__(const Basic &o) const override
Test equality.
Definition: matrix_add.cpp:19
int compare(const Basic &o) const override
Definition: matrix_add.cpp:28
hash_t __hash__() const override
Definition: matrix_add.cpp:10
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Subtracts b from a.
Definition: add.cpp:495
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T push_back(T... args)
T size(T... args)