trace.cpp
1 #include <symengine/basic.h>
2 #include <symengine/matrices/matrix_expr.h>
3 #include <symengine/matrices/trace.h>
4 #include <symengine/visitor.h>
5 
6 namespace SymEngine
7 {
8 
9 hash_t Trace::__hash__() const
10 {
11  hash_t seed = SYMENGINE_TRACE;
12  hash_combine<Basic>(seed, *arg_);
13  return seed;
14 }
15 
16 bool Trace::__eq__(const Basic &o) const
17 {
18  return (is_a<Trace>(o) && arg_->__eq__(*down_cast<const Trace &>(o).arg_));
19 }
20 
21 int Trace::compare(const Basic &o) const
22 {
23  SYMENGINE_ASSERT(is_a<Trace>(o));
24 
25  return arg_->compare(*down_cast<const Trace &>(o).arg_);
26 }
27 
29 {
30  return {arg_};
31 }
32 
33 class MatrixTraceVisitor : public BaseVisitor<MatrixTraceVisitor>
34 {
35 private:
36  RCP<const Basic> trace_;
37 
38  void trace_error()
39  {
40  throw DomainError("Trace is only valid for square matrices");
41  }
42 
43 public:
45 
46  void bvisit(const Basic &x){};
47 
48  void bvisit(const MatrixExpr &x)
49  {
50  auto arg = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
51  trace_ = make_rcp<const Trace>(arg);
52  }
53 
54  void bvisit(const IdentityMatrix &x)
55  {
56  trace_ = x.size();
57  }
58 
59  void bvisit(const ZeroMatrix &x)
60  {
61  tribool sq = is_square(x);
62  if (is_true(sq)) {
63  trace_ = zero;
64  } else if (is_false(sq)) {
65  trace_error();
66  } else {
67  auto arg = rcp_static_cast<const MatrixExpr>(x.rcp_from_this());
68  trace_ = make_rcp<const Trace>(arg);
69  }
70  }
71 
72  void bvisit(const DiagonalMatrix &x)
73  {
74  trace_ = add(x.get_container());
75  }
76 
77  void bvisit(const ImmutableDenseMatrix &x)
78  {
79  if (x.nrows() != x.ncols()) {
80  trace_error();
81  }
82  vec_basic diag;
83  for (size_t i = 0; i < x.nrows(); i++) {
84  diag.push_back(x.get(i, i));
85  }
86  trace_ = add(diag);
87  }
88 
89  void bvisit(const MatrixAdd &x)
90  {
91  // Trace is a linear function so trace(A + B) = trace(A) + trace(B)
92  RCP<const Basic> sum = zero;
93  for (auto &e : x.get_terms()) {
94  e->accept(*this);
95  sum = add(sum, trace_);
96  }
97  trace_ = sum;
98  }
99 
100  RCP<const Basic> apply(const MatrixExpr &s)
101  {
102  s.accept(*this);
103  return trace_;
104  }
105 };
106 
107 RCP<const Basic> trace(const RCP<const MatrixExpr> &arg)
108 {
109  MatrixTraceVisitor visitor;
110  return visitor.apply(*arg);
111 }
112 } // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition: basic.h:97
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
hash_t __hash__() const override
Definition: trace.cpp:9
bool __eq__(const Basic &o) const override
Test equality.
Definition: trace.cpp:16
int compare(const Basic &o) const override
Definition: trace.cpp:21
vec_basic get_args() const override
Returns the list of arguments.
Definition: trace.cpp:28
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425