Loading...
Searching...
No Matches
trace.cpp
1#include <symengine/basic.h>
2#include <symengine/matrices/matrix_expr.h>
3#include <symengine/matrices/trace.h>
4#include <symengine/visitor.h>
5
6namespace SymEngine
7{
8
9hash_t Trace::__hash__() const
10{
11 hash_t seed = SYMENGINE_TRACE;
12 hash_combine<Basic>(seed, *arg_);
13 return seed;
14}
15
16bool Trace::__eq__(const Basic &o) const
17{
18 return (is_a<Trace>(o) && arg_->__eq__(*down_cast<const Trace &>(o).arg_));
19}
20
21int Trace::compare(const Basic &o) const
22{
23 SYMENGINE_ASSERT(is_a<Trace>(o));
24
25 return arg_->compare(*down_cast<const Trace &>(o).arg_);
26}
27
29{
30 return {arg_};
31}
32
33class MatrixTraceVisitor : public BaseVisitor<MatrixTraceVisitor>
34{
35private:
36 RCP<const Basic> trace_;
37
38 void trace_error()
39 {
40 throw DomainError("Trace is only valid for square matrices");
41 }
42
43public:
45
46 void bvisit(const Basic &x){};
47
48 void bvisit(const MatrixExpr &x)
49 {
51 trace_ = make_rcp<const Trace>(arg);
52 }
53
54 void bvisit(const IdentityMatrix &x)
55 {
56 trace_ = x.size();
57 }
58
59 void bvisit(const ZeroMatrix &x)
60 {
61 tribool sq = is_square(x);
62 if (is_true(sq)) {
63 trace_ = zero;
64 } else if (is_false(sq)) {
65 trace_error();
66 } else {
68 trace_ = make_rcp<const Trace>(arg);
69 }
70 }
71
72 void bvisit(const DiagonalMatrix &x)
73 {
74 trace_ = add(x.get_container());
75 }
76
77 void bvisit(const ImmutableDenseMatrix &x)
78 {
79 if (x.nrows() != x.ncols()) {
80 trace_error();
81 }
82 vec_basic diag;
83 for (size_t i = 0; i < x.nrows(); i++) {
84 diag.push_back(x.get(i, i));
85 }
86 trace_ = add(diag);
87 }
88
89 void bvisit(const MatrixAdd &x)
90 {
91 // Trace is a linear function so trace(A + B) = trace(A) + trace(B)
92 RCP<const Basic> sum = zero;
93 for (auto &e : x.get_terms()) {
94 e->accept(*this);
95 sum = add(sum, trace_);
96 }
97 trace_ = sum;
98 }
99
100 RCP<const Basic> apply(const MatrixExpr &s)
101 {
102 s.accept(*this);
103 return trace_;
104 }
105};
106
107RCP<const Basic> trace(const RCP<const MatrixExpr> &arg)
108{
110 return visitor.apply(*arg);
111}
112} // namespace SymEngine
The base class for SymEngine.
The lowest unit of symbolic representation.
Definition basic.h:97
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
hash_t __hash__() const override
Definition trace.cpp:9
bool __eq__(const Basic &o) const override
Test equality.
Definition trace.cpp:16
int compare(const Basic &o) const override
Definition trace.cpp:21
vec_basic get_args() const override
Returns the list of arguments.
Definition trace.cpp:28
Main namespace for SymEngine package.
Definition add.cpp:19
void hash_combine(hash_t &seed, const T &v)
Definition basic-inl.h:95
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition add.cpp:425