2 #include <symengine/matrices/hadamard_product.h>
3 #include <symengine/matrices/zero_matrix.h>
4 #include <symengine/matrices/diagonal_matrix.h>
5 #include <symengine/matrices/immutable_dense_matrix.h>
6 #include <symengine/matrices/identity_matrix.h>
11 void check_matching_sizes(
const vec_basic &vec);
15 hash_t seed = SYMENGINE_HADAMARDPRODUCT;
16 for (
const auto &a : factors_) {
17 hash_combine<Basic>(seed, *a);
24 if (is_a<HadamardProduct>(o)) {
26 return unified_eq(factors_, other.factors_);
33 SYMENGINE_ASSERT(is_a<HadamardProduct>(o));
38 bool HadamardProduct::is_canonical(
const vec_basic &factors)
const
40 if (factors.
size() < 2) {
46 for (
auto factor : factors) {
47 if (is_a<ZeroMatrix>(*
factor) || is_a<HadamardProduct>(*
factor)) {
49 }
else if (is_a<DiagonalMatrix>(*
factor)) {
51 }
else if (is_a<ImmutableDenseMatrix>(*
factor)) {
53 }
else if (is_a<IdentityMatrix>(*
factor)) {
57 if (num_diag > 1 || num_ident > 1 || num_dense > 1) {
60 if (num_diag == 1 && num_dense == 1) {
66 RCP<const MatrixExpr> hadamard_product(
const vec_basic &factors)
68 if (factors.size() == 0) {
69 throw DomainError(
"Empty hadamard product");
71 if (factors.size() == 1) {
72 return rcp_static_cast<const MatrixExpr>(factors[0]);
76 for (
auto &
factor : factors) {
77 if (is_a<const HadamardProduct>(*
factor)) {
79 = down_cast<const HadamardProduct &>(*factor).get_factors();
80 expanded.
insert(expanded.end(), container.begin(), container.end());
82 expanded.push_back(
factor);
85 check_matching_sizes(expanded);
87 RCP<const DiagonalMatrix> diag;
88 RCP<const ImmutableDenseMatrix> dense;
89 bool have_identity =
false;
90 for (
auto &
factor : expanded) {
91 if (is_a<ZeroMatrix>(*
factor)) {
92 return rcp_static_cast<const MatrixExpr>(
factor);
93 }
else if (is_a<IdentityMatrix>(*
factor)) {
98 }
else if (is_a<DiagonalMatrix>(*
factor)) {
100 diag = rcp_static_cast<const DiagonalMatrix>(
factor);
103 for (
size_t i = 0; i < diag->get_container().size(); i++) {
105 mul(diag->get_container()[i],
106 down_cast<const DiagonalMatrix &>(*
factor)
107 .get_container()[i]));
109 diag = make_rcp<const DiagonalMatrix>(container);
111 }
else if (is_a<ImmutableDenseMatrix>(*
factor)) {
112 if (dense.is_null()) {
113 dense = rcp_static_cast<const ImmutableDenseMatrix>(
factor);
115 const vec_basic &vec1
116 = down_cast<const ImmutableDenseMatrix &>(*
factor)
118 const vec_basic &vec2 = dense->get_values();
119 vec_basic product(vec1.size());
120 for (
size_t i = 0; i < vec1.size(); i++) {
121 product[i] =
mul(vec1[i], vec2[i]);
123 dense = make_rcp<const ImmutableDenseMatrix>(
124 dense->nrows(), dense->ncols(), product);
130 if (!dense.is_null()) {
131 if (!diag.is_null()) {
134 for (
size_t i = 0; i < dense->nrows(); i++) {
137 diag = make_rcp<const DiagonalMatrix>(product);
139 keep.push_back(dense);
142 if (!diag.is_null()) {
143 keep.push_back(diag);
145 if (keep.size() == 1) {
146 return rcp_static_cast<const MatrixExpr>(keep[0]);
148 return make_rcp<const HadamardProduct>(keep);
The lowest unit of symbolic representation.
hash_t __hash__() const override
int compare(const Basic &o) const override
bool __eq__(const Basic &o) const override
Test equality.
Main namespace for SymEngine package.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
int unified_compare(const T &a, const T &b)