hadamard_product.cpp
1 #include <symengine/mul.h>
2 #include <symengine/matrices/hadamard_product.h>
3 #include <symengine/matrices/zero_matrix.h>
4 #include <symengine/matrices/diagonal_matrix.h>
5 #include <symengine/matrices/immutable_dense_matrix.h>
6 #include <symengine/matrices/identity_matrix.h>
7 
8 namespace SymEngine
9 {
10 
11 void check_matching_sizes(const vec_basic &vec);
12 
14 {
15  hash_t seed = SYMENGINE_HADAMARDPRODUCT;
16  for (const auto &a : factors_) {
17  hash_combine<Basic>(seed, *a);
18  }
19  return seed;
20 }
21 
22 bool HadamardProduct::__eq__(const Basic &o) const
23 {
24  if (is_a<HadamardProduct>(o)) {
25  const HadamardProduct &other = down_cast<const HadamardProduct &>(o);
26  return unified_eq(factors_, other.factors_);
27  }
28  return false;
29 }
30 
31 int HadamardProduct::compare(const Basic &o) const
32 {
33  SYMENGINE_ASSERT(is_a<HadamardProduct>(o));
34  const HadamardProduct &other = down_cast<const HadamardProduct &>(o);
35  return unified_compare(factors_, other.factors_);
36 }
37 
38 bool HadamardProduct::is_canonical(const vec_basic &factors) const
39 {
40  if (factors.size() < 2) {
41  return false;
42  }
43  size_t num_diag = 0;
44  size_t num_dense = 0;
45  size_t num_ident = 0;
46  for (auto factor : factors) {
47  if (is_a<ZeroMatrix>(*factor) || is_a<HadamardProduct>(*factor)) {
48  return false;
49  } else if (is_a<DiagonalMatrix>(*factor)) {
50  num_diag++;
51  } else if (is_a<ImmutableDenseMatrix>(*factor)) {
52  num_dense++;
53  } else if (is_a<IdentityMatrix>(*factor)) {
54  num_ident++;
55  }
56  }
57  if (num_diag > 1 || num_ident > 1 || num_dense > 1) {
58  return false;
59  }
60  if (num_diag == 1 && num_dense == 1) {
61  return false;
62  }
63  return true;
64 }
65 
66 RCP<const MatrixExpr> hadamard_product(const vec_basic &factors)
67 {
68  if (factors.size() == 0) {
69  throw DomainError("Empty hadamard product");
70  }
71  if (factors.size() == 1) {
72  return rcp_static_cast<const MatrixExpr>(factors[0]);
73  }
74  // extract nested HadamardProduct
75  vec_basic expanded;
76  for (auto &factor : factors) {
77  if (is_a<const HadamardProduct>(*factor)) {
78  auto container
79  = down_cast<const HadamardProduct &>(*factor).get_factors();
80  expanded.insert(expanded.end(), container.begin(), container.end());
81  } else {
82  expanded.push_back(factor);
83  }
84  }
85  check_matching_sizes(expanded);
86  vec_basic keep;
87  RCP<const DiagonalMatrix> diag;
88  RCP<const ImmutableDenseMatrix> dense;
89  bool have_identity = false;
90  for (auto &factor : expanded) {
91  if (is_a<ZeroMatrix>(*factor)) {
92  return rcp_static_cast<const MatrixExpr>(factor);
93  } else if (is_a<IdentityMatrix>(*factor)) {
94  if (!have_identity) {
95  have_identity = true;
96  keep.push_back(factor);
97  }
98  } else if (is_a<DiagonalMatrix>(*factor)) {
99  if (diag.is_null()) {
100  diag = rcp_static_cast<const DiagonalMatrix>(factor);
101  } else {
102  vec_basic container;
103  for (size_t i = 0; i < diag->get_container().size(); i++) {
104  container.push_back(
105  mul(diag->get_container()[i],
106  down_cast<const DiagonalMatrix &>(*factor)
107  .get_container()[i]));
108  }
109  diag = make_rcp<const DiagonalMatrix>(container);
110  }
111  } else if (is_a<ImmutableDenseMatrix>(*factor)) {
112  if (dense.is_null()) {
113  dense = rcp_static_cast<const ImmutableDenseMatrix>(factor);
114  } else {
115  const vec_basic &vec1
116  = down_cast<const ImmutableDenseMatrix &>(*factor)
117  .get_values();
118  const vec_basic &vec2 = dense->get_values();
119  vec_basic product(vec1.size());
120  for (size_t i = 0; i < vec1.size(); i++) {
121  product[i] = mul(vec1[i], vec2[i]);
122  }
123  dense = make_rcp<const ImmutableDenseMatrix>(
124  dense->nrows(), dense->ncols(), product);
125  }
126  } else {
127  keep.push_back(factor);
128  }
129  }
130  if (!dense.is_null()) {
131  if (!diag.is_null()) {
132  // Multiply diagonal with dense matrix
133  vec_basic product;
134  for (size_t i = 0; i < dense->nrows(); i++) {
135  product.push_back(mul(dense->get(i, i), diag->get(i)));
136  }
137  diag = make_rcp<const DiagonalMatrix>(product);
138  } else {
139  keep.push_back(dense);
140  }
141  }
142  if (!diag.is_null()) {
143  keep.push_back(diag);
144  }
145  if (keep.size() == 1) {
146  return rcp_static_cast<const MatrixExpr>(keep[0]);
147  }
148  return make_rcp<const HadamardProduct>(keep);
149 }
150 
151 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
hash_t __hash__() const override
int compare(const Basic &o) const override
bool __eq__(const Basic &o) const override
Test equality.
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
Definition: ntheory.cpp:371
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T push_back(T... args)
T size(T... args)