transpose.cpp
1 #include <symengine/basic.h>
2 #include <symengine/matrices/matrix_expr.h>
3 #include <symengine/matrices/transpose.h>
4 #include <symengine/visitor.h>
5 
6 namespace SymEngine
7 {
8 
9 hash_t Transpose::__hash__() const
10 {
11  hash_t seed = SYMENGINE_TRANSPOSE;
12  hash_combine<Basic>(seed, *arg_);
13  return seed;
14 }
15 
16 bool Transpose::__eq__(const Basic &o) const
17 {
18  return (is_a<Transpose>(o)
19  && arg_->__eq__(*down_cast<const Transpose &>(o).arg_));
20 }
21 
22 bool Transpose::is_canonical(const RCP<const MatrixExpr> &arg) const
23 {
24  if (is_a<IdentityMatrix>(*arg) || is_a<ZeroMatrix>(*arg)
25  || is_a<DiagonalMatrix>(*arg) || is_a<ImmutableDenseMatrix>(*arg)
26  || is_a<Transpose>(*arg) || is_a<MatrixAdd>(*arg)
27  || is_a<HadamardProduct>(*arg)) {
28  return false;
29  }
30  return true;
31 }
32 
33 int Transpose::compare(const Basic &o) const
34 {
35  SYMENGINE_ASSERT(is_a<Transpose>(o));
36 
37  return arg_->compare(*down_cast<const Transpose &>(o).arg_);
38 }
39 
41 {
42  return {arg_};
43 }
44 
45 class TransposeVisitor : public BaseVisitor<TransposeVisitor>
46 {
47 private:
48  RCP<const MatrixExpr> transpose_;
49 
50 public:
51  TransposeVisitor() {}
52 
53  void bvisit(const Basic &x){};
54 
55  void bvisit(const MatrixExpr &x)
56  {
57  auto arg = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
58  transpose_ = make_rcp<const Transpose>(arg);
59  }
60 
61  void bvisit(const IdentityMatrix &x)
62  {
63  transpose_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
64  }
65 
66  void bvisit(const ZeroMatrix &x)
67  {
68  transpose_ = make_rcp<const ZeroMatrix>(x.ncols(), x.nrows());
69  }
70 
71  void bvisit(const DiagonalMatrix &x)
72  {
73  transpose_ = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
74  }
75 
76  void bvisit(const ImmutableDenseMatrix &x)
77  {
78  auto values = x.get_values();
79  vec_basic t(values.size());
80 
81  for (size_t i = 0; i < x.nrows(); i++)
82  for (size_t j = 0; j < x.ncols(); j++)
83  t[j * x.ncols() + i] = x.get(i, j);
84 
85  transpose_
86  = make_rcp<const ImmutableDenseMatrix>(x.ncols(), x.nrows(), t);
87  }
88 
89  void bvisit(const Transpose &x)
90  {
91  transpose_ = x.get_arg();
92  }
93 
94  void bvisit(const MatrixAdd &x)
95  {
96  vec_basic t;
97  for (auto &e : x.get_terms()) {
98  e->accept(*this);
99  t.push_back(transpose_);
100  }
101  transpose_ = make_rcp<const MatrixAdd>(t);
102  }
103 
104  void bvisit(const HadamardProduct &x)
105  {
106  vec_basic t;
107  for (auto &e : x.get_factors()) {
108  e->accept(*this);
109  t.push_back(transpose_);
110  }
111  transpose_ = make_rcp<const HadamardProduct>(t);
112  }
113 
114  RCP<const MatrixExpr> apply(const MatrixExpr &s)
115  {
116  s.accept(*this);
117  return transpose_;
118  }
119 };
120 
121 RCP<const MatrixExpr> transpose(const RCP<const MatrixExpr> &arg)
122 {
123  TransposeVisitor visitor;
124  return visitor.apply(*arg);
125 }
126 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
vec_basic get_args() const override
Returns the list of arguments.
Definition: transpose.cpp:40
bool __eq__(const Basic &o) const override
Test equality.
Definition: transpose.cpp:16
int compare(const Basic &o) const override
Definition: transpose.cpp:33
hash_t __hash__() const override
Definition: transpose.cpp:9
Main namespace for SymEngine package.
Definition: add.cpp:19
T push_back(T... args)