conjugate_matrix.cpp
1 #include <symengine/basic.h>
2 #include <symengine/matrices/matrix_expr.h>
3 #include <symengine/matrices/conjugate_matrix.h>
4 #include <symengine/visitor.h>
5 
6 namespace SymEngine
7 {
8 
10 {
11  hash_t seed = SYMENGINE_CONJUGATEMATRIX;
12  hash_combine<Basic>(seed, *arg_);
13  return seed;
14 }
15 
16 bool ConjugateMatrix::__eq__(const Basic &o) const
17 {
18  return (is_a<ConjugateMatrix>(o)
19  && arg_->__eq__(*down_cast<const ConjugateMatrix &>(o).arg_));
20 }
21 
22 bool ConjugateMatrix::is_canonical(const RCP<const MatrixExpr> &arg) const
23 {
24  // NOTE: For conjugate transpose always have the conjugate operation first
25  // i.e. transpose(conjugate(A))
26  if (is_a<IdentityMatrix>(*arg) || is_a<ZeroMatrix>(*arg)
27  || is_a<DiagonalMatrix>(*arg) || is_a<ImmutableDenseMatrix>(*arg)
28  || is_a<ConjugateMatrix>(*arg) || is_a<Transpose>(*arg)
29  || is_a<MatrixAdd>(*arg) || is_a<HadamardProduct>(*arg)) {
30  return false;
31  }
32  return true;
33 }
34 
35 int ConjugateMatrix::compare(const Basic &o) const
36 {
37  SYMENGINE_ASSERT(is_a<ConjugateMatrix>(o));
38 
39  return arg_->compare(*down_cast<const ConjugateMatrix &>(o).arg_);
40 }
41 
43 {
44  return {arg_};
45 }
46 
47 class ConjugateMatrixVisitor : public BaseVisitor<ConjugateMatrixVisitor>
48 {
49 private:
50  RCP<const MatrixExpr> conjugate_;
51 
52 public:
54 
55  void bvisit(const Basic &x){};
56 
57  void bvisit(const MatrixExpr &x)
58  {
59  auto arg = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
60  conjugate_ = make_rcp<const ConjugateMatrix>(arg);
61  }
62 
63  void bvisit(const IdentityMatrix &x)
64  {
65  conjugate_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
66  }
67 
68  void bvisit(const ZeroMatrix &x)
69  {
70  conjugate_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
71  }
72 
73  void bvisit(const DiagonalMatrix &x)
74  {
75  auto diag = x.get_container();
76  vec_basic conj(diag.size());
77  for (size_t i = 0; i < diag.size(); i++) {
78  conj[i] = conjugate(diag[i]);
79  }
80  conjugate_ = make_rcp<const DiagonalMatrix>(conj);
81  }
82 
83  void bvisit(const ImmutableDenseMatrix &x)
84  {
85  auto values = x.get_values();
86  vec_basic conj(values.size());
87  for (size_t i = 0; i < values.size(); i++) {
88  conj[i] = conjugate(values[i]);
89  }
90  conjugate_
91  = make_rcp<const ImmutableDenseMatrix>(x.nrows(), x.ncols(), conj);
92  }
93 
94  void bvisit(const ConjugateMatrix &x)
95  {
96  conjugate_ = x.get_arg();
97  }
98 
99  void bvisit(const Transpose &x)
100  {
101  // Shift order to transpose(conj(A))
102  auto arg = x.get_arg();
103  auto conj = make_rcp<const ConjugateMatrix>(arg);
104  conjugate_ = make_rcp<const Transpose>(conj);
105  }
106 
107  void bvisit(const MatrixAdd &x)
108  {
109  vec_basic conj;
110  for (auto &e : x.get_terms()) {
111  e->accept(*this);
112  conj.push_back(conjugate_);
113  }
114  conjugate_ = make_rcp<const MatrixAdd>(conj);
115  }
116 
117  void bvisit(const HadamardProduct &x)
118  {
119  vec_basic conj;
120  for (auto &e : x.get_factors()) {
121  e->accept(*this);
122  conj.push_back(conjugate_);
123  }
124  conjugate_ = make_rcp<const HadamardProduct>(conj);
125  }
126 
127  RCP<const MatrixExpr> apply(const MatrixExpr &s)
128  {
129  s.accept(*this);
130  return conjugate_;
131  }
132 };
133 
134 RCP<const MatrixExpr> conjugate_matrix(const RCP<const MatrixExpr> &arg)
135 {
136  ConjugateMatrixVisitor visitor;
137  return visitor.apply(*arg);
138 }
139 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
vec_basic get_args() const override
Returns the list of arguments.
int compare(const Basic &o) const override
hash_t __hash__() const override
bool __eq__(const Basic &o) const override
Test equality.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.
Definition: functions.cpp:149
T push_back(T... args)