// conjugate_matrix.cpp
1#include <symengine/basic.h>
2#include <symengine/matrices/matrix_expr.h>
3#include <symengine/matrices/conjugate_matrix.h>
4#include <symengine/visitor.h>
5
6namespace SymEngine
7{
8
10{
11 hash_t seed = SYMENGINE_CONJUGATEMATRIX;
12 hash_combine<Basic>(seed, *arg_);
13 return seed;
14}
15
16bool ConjugateMatrix::__eq__(const Basic &o) const
17{
18 return (is_a<ConjugateMatrix>(o)
19 && arg_->__eq__(*down_cast<const ConjugateMatrix &>(o).arg_));
20}
21
22bool ConjugateMatrix::is_canonical(const RCP<const MatrixExpr> &arg) const
23{
24 // NOTE: For conjugate transpose always have the conjugate operation first
25 // i.e. transpose(conjugate(A))
26 if (is_a<IdentityMatrix>(*arg) || is_a<ZeroMatrix>(*arg)
27 || is_a<DiagonalMatrix>(*arg) || is_a<ImmutableDenseMatrix>(*arg)
28 || is_a<ConjugateMatrix>(*arg) || is_a<Transpose>(*arg)
29 || is_a<MatrixAdd>(*arg) || is_a<HadamardProduct>(*arg)) {
30 return false;
31 }
32 return true;
33}
34
36{
37 SYMENGINE_ASSERT(is_a<ConjugateMatrix>(o));
38
39 return arg_->compare(*down_cast<const ConjugateMatrix &>(o).arg_);
40}
41
43{
44 return {arg_};
45}
46
47class ConjugateMatrixVisitor : public BaseVisitor<ConjugateMatrixVisitor>
48{
49private:
50 RCP<const MatrixExpr> conjugate_;
51
52public:
54
55 void bvisit(const Basic &x){};
56
57 void bvisit(const MatrixExpr &x)
58 {
59 auto arg = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
60 conjugate_ = make_rcp<const ConjugateMatrix>(arg);
61 }
62
63 void bvisit(const IdentityMatrix &x)
64 {
65 conjugate_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
66 }
67
68 void bvisit(const ZeroMatrix &x)
69 {
70 conjugate_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
71 }
72
73 void bvisit(const DiagonalMatrix &x)
74 {
75 auto diag = x.get_container();
76 vec_basic conj(diag.size());
77 for (size_t i = 0; i < diag.size(); i++) {
78 conj[i] = conjugate(diag[i]);
79 }
80 conjugate_ = make_rcp<const DiagonalMatrix>(conj);
81 }
82
83 void bvisit(const ImmutableDenseMatrix &x)
84 {
85 auto values = x.get_values();
86 vec_basic conj(values.size());
87 for (size_t i = 0; i < values.size(); i++) {
88 conj[i] = conjugate(values[i]);
89 }
90 conjugate_
91 = make_rcp<const ImmutableDenseMatrix>(x.nrows(), x.ncols(), conj);
92 }
93
94 void bvisit(const ConjugateMatrix &x)
95 {
96 conjugate_ = x.get_arg();
97 }
98
99 void bvisit(const Transpose &x)
100 {
101 // Shift order to transpose(conj(A))
102 auto arg = x.get_arg();
103 auto conj = make_rcp<const ConjugateMatrix>(arg);
104 conjugate_ = make_rcp<const Transpose>(conj);
105 }
106
107 void bvisit(const MatrixAdd &x)
108 {
109 vec_basic conj;
110 for (auto &e : x.get_terms()) {
111 e->accept(*this);
112 conj.push_back(conjugate_);
113 }
114 conjugate_ = make_rcp<const MatrixAdd>(conj);
115 }
116
117 void bvisit(const HadamardProduct &x)
118 {
119 vec_basic conj;
120 for (auto &e : x.get_factors()) {
121 e->accept(*this);
122 conj.push_back(conjugate_);
123 }
124 conjugate_ = make_rcp<const HadamardProduct>(conj);
125 }
126
127 RCP<const MatrixExpr> apply(const MatrixExpr &s)
128 {
129 s.accept(*this);
130 return conjugate_;
131 }
132};
133
134RCP<const MatrixExpr> conjugate_matrix(const RCP<const MatrixExpr> &arg)
135{
137 return visitor.apply(*arg);
138}
139} // namespace SymEngine
/*
 * Cross-reference notes recovered from the generated documentation:
 *
 * - The base class for SymEngine: the lowest unit of symbolic
 *   representation. Defined in basic.h:97.
 * - vec_basic get_args() const override: returns the list of arguments.
 * - int compare(const Basic &o) const override
 * - hash_t __hash__() const override
 * - bool __eq__(const Basic &o) const override: tests equality.
 * - RCP<T> rcp_from_this(): gets an RCP<T> pointer to self (it casts the
 *   pointer to T).
 * - Main namespace for the SymEngine package. Defined in add.cpp:19.
 * - RCP<const Basic> conjugate(const RCP<const Basic> &arg): canonicalizes
 *   Conjugate. Defined in functions.cpp:149.
 * - T push_back(T... args)
 */