// visitor.cpp
1#include <symengine/visitor.h>
2#include <symengine/polys/basic_conversions.h>
3#include <symengine/sets.h>
4
// ACCEPT(CLASS) emits the out-of-line definition of CLASS::accept(), which
// forwards to the visitor overload taking `const CLASS &` (double dispatch).
#define ACCEPT(CLASS)                                                          \
    void CLASS::accept(Visitor &v) const { v.visit(*this); }
10
11namespace SymEngine
12{
13
14#define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
15#include "symengine/type_codes.inc"
16#undef SYMENGINE_ENUM
17
18void preorder_traversal(const Basic &b, Visitor &v)
19{
20 b.accept(v);
21 for (const auto &p : b.get_args())
22 preorder_traversal(*p, v);
23}
24
25void postorder_traversal(const Basic &b, Visitor &v)
26{
27 for (const auto &p : b.get_args())
28 postorder_traversal(*p, v);
29 b.accept(v);
30}
31
32void preorder_traversal_stop(const Basic &b, StopVisitor &v)
33{
34 b.accept(v);
35 if (v.stop_)
36 return;
37 for (const auto &p : b.get_args()) {
38 preorder_traversal_stop(*p, v);
39 if (v.stop_)
40 return;
41 }
42}
43
44void postorder_traversal_stop(const Basic &b, StopVisitor &v)
45{
46 for (const auto &p : b.get_args()) {
47 postorder_traversal_stop(*p, v);
48 if (v.stop_)
49 return;
50 }
51 b.accept(v);
52}
53
54bool has_symbol(const Basic &b, const Basic &x)
55{
56 // We are breaking a rule when using ptrFromRef() here, but since
57 // HasSymbolVisitor is only instantiated and freed from here, the `x` can
58 // never go out of scope, so this is safe.
59 HasSymbolVisitor v(ptrFromRef(x));
60 return v.apply(b);
61}
62
63RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n)
64{
65 if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
66 throw NotImplementedError("Not implemented for non (Function)Symbols.");
67 }
68 CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
69 return v.apply(b);
70}
71
72class FreeSymbolsVisitor : public BaseVisitor<FreeSymbolsVisitor>
73{
74public:
75 set_basic s;
76 uset_basic v;
77
78 void bvisit(const Symbol &x)
79 {
80 s.insert(x.rcp_from_this());
81 }
82
83 void bvisit(const Subs &x)
84 {
85 set_basic set_ = free_symbols(*x.get_arg());
86 for (const auto &p : x.get_variables()) {
87 set_.erase(p);
88 }
89 s.insert(set_.begin(), set_.end());
90 for (const auto &p : x.get_point()) {
91 auto iter = v.insert(p->rcp_from_this());
92 if (iter.second) {
93 p->accept(*this);
94 }
95 }
96 }
97
98 void bvisit(const Basic &x)
99 {
100 for (const auto &p : x.get_args()) {
101 auto iter = v.insert(p->rcp_from_this());
102 if (iter.second) {
103 p->accept(*this);
104 }
105 }
106 }
107
108 set_basic apply(const Basic &b)
109 {
110 b.accept(*this);
111 return s;
112 }
113
114 set_basic apply(const MatrixBase &m)
115 {
116 for (unsigned i = 0; i < m.nrows(); i++) {
117 for (unsigned j = 0; j < m.ncols(); j++) {
118 m.get(i, j)->accept(*this);
119 }
120 }
121 return s;
122 }
123};
124
125set_basic free_symbols(const MatrixBase &m)
126{
127 FreeSymbolsVisitor visitor;
128 return visitor.apply(m);
129}
130
131set_basic free_symbols(const Basic &b)
132{
133 FreeSymbolsVisitor visitor;
134 return visitor.apply(b);
135}
136
137set_basic function_symbols(const Basic &b)
138{
139 return atoms<FunctionSymbol>(b);
140}
141
142RCP<const Basic> TransformVisitor::apply(const RCP<const Basic> &x)
143{
144 x->accept(*this);
145 return result_;
146}
147
148void TransformVisitor::bvisit(const Basic &x)
149{
150 result_ = x.rcp_from_this();
151}
152
153void TransformVisitor::bvisit(const Add &x)
154{
155 vec_basic newargs;
156 for (const auto &a : x.get_args()) {
157 newargs.push_back(apply(a));
158 }
159 result_ = add(newargs);
160}
161
162void TransformVisitor::bvisit(const Mul &x)
163{
164 vec_basic newargs;
165 for (const auto &a : x.get_args()) {
166 newargs.push_back(apply(a));
167 }
168 result_ = mul(newargs);
169}
170
171void TransformVisitor::bvisit(const Pow &x)
172{
173 auto base_ = x.get_base(), exp_ = x.get_exp();
174 auto newarg1 = apply(base_), newarg2 = apply(exp_);
175 if (base_ != newarg1 or exp_ != newarg2) {
176 result_ = pow(newarg1, newarg2);
177 } else {
178 result_ = x.rcp_from_this();
179 }
180}
181
182void TransformVisitor::bvisit(const OneArgFunction &x)
183{
184 auto farg = x.get_arg();
185 auto newarg = apply(farg);
186 if (eq(*newarg, *farg)) {
187 result_ = x.rcp_from_this();
188 } else {
189 result_ = x.create(newarg);
190 }
191}
192
193void TransformVisitor::bvisit(const MultiArgFunction &x)
194{
195 auto fargs = x.get_args();
196 vec_basic newargs;
197 for (const auto &a : fargs) {
198 newargs.push_back(apply(a));
199 }
200 auto nbarg = x.create(newargs);
201 result_ = nbarg;
202}
203
204void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
205{
206 b.accept(v);
207 if (v.stop_ or v.local_stop_)
208 return;
209 for (const auto &p : b.get_args()) {
210 preorder_traversal_local_stop(*p, v);
211 if (v.stop_)
212 return;
213 }
214}
215
216void CountOpsVisitor::apply(const Basic &b)
217{
218 unsigned count_now = count;
219 auto it = v.find(b.rcp_from_this());
220 if (it == v.end()) {
221 b.accept(*this);
222 insert(v, b.rcp_from_this(), count - count_now);
223 } else {
224 count += it->second;
225 }
226}
227
228void CountOpsVisitor::bvisit(const Mul &x)
229{
230 if (neq(*(x.get_coef()), *one)) {
231 count++;
232 apply(*x.get_coef());
233 }
234
235 for (const auto &p : x.get_dict()) {
236 if (neq(*p.second, *one)) {
237 count++;
238 apply(*p.second);
239 }
240 apply(*p.first);
241 count++;
242 }
243 count--;
244}
245
246void CountOpsVisitor::bvisit(const Add &x)
247{
248 if (neq(*(x.get_coef()), *zero)) {
249 count++;
250 apply(*x.get_coef());
251 }
252
253 unsigned i = 0;
254 for (const auto &p : x.get_dict()) {
255 if (neq(*p.second, *one)) {
256 count++;
257 apply(*p.second);
258 }
259 apply(*p.first);
260 count++;
261 i++;
262 }
263 count--;
264}
265
266void CountOpsVisitor::bvisit(const Pow &x)
267{
268 count++;
269 apply(*x.get_exp());
270 apply(*x.get_base());
271}
272
273void CountOpsVisitor::bvisit(const Number &x) {}
274
275void CountOpsVisitor::bvisit(const ComplexBase &x)
276{
277 if (neq(*x.real_part(), *zero)) {
278 count++;
279 }
280
281 if (neq(*x.imaginary_part(), *one)) {
282 count++;
283 }
284}
285
286void CountOpsVisitor::bvisit(const Symbol &x) {}
287
288void CountOpsVisitor::bvisit(const Constant &x) {}
289
290void CountOpsVisitor::bvisit(const Basic &x)
291{
292 count++;
293 for (const auto &p : x.get_args()) {
294 apply(*p);
295 }
296}
297
298unsigned count_ops(const vec_basic &a)
299{
300 CountOpsVisitor v;
301 for (auto &p : a) {
302 v.apply(*p);
303 }
304 return v.count;
305}
306
307} // namespace SymEngine
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
T end(T... args)
T erase(T... args)
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
T pow(T... args)
T push_back(T... args)