Loading...
Searching...
No Matches
visitor.cpp
1#include <symengine/visitor.h>
2#include <symengine/polys/basic_conversions.h>
3#include <symengine/sets.h>
4
// Generates CLASS::accept(), which forwards the node to the visitor. Inside
// the generated body `*this` has static type `const CLASS &`, so the
// `visit` overload chosen is the one for CLASS.
#define ACCEPT(CLASS)                                                          \
    void CLASS::accept(Visitor &visitor) const                                 \
    {                                                                          \
        visitor.visit(*this);                                                  \
    }
10
11namespace SymEngine
12{
13
14#define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
15#include "symengine/type_codes.inc"
16#undef SYMENGINE_ENUM
17
18void preorder_traversal(const Basic &b, Visitor &v)
19{
20 b.accept(v);
21 for (const auto &p : b.get_args())
22 preorder_traversal(*p, v);
23}
24
25void postorder_traversal(const Basic &b, Visitor &v)
26{
27 for (const auto &p : b.get_args())
28 postorder_traversal(*p, v);
29 b.accept(v);
30}
31
32void preorder_traversal_stop(const Basic &b, StopVisitor &v)
33{
34 b.accept(v);
35 if (v.stop_)
36 return;
37 for (const auto &p : b.get_args()) {
38 preorder_traversal_stop(*p, v);
39 if (v.stop_)
40 return;
41 }
42}
43
44void postorder_traversal_stop(const Basic &b, StopVisitor &v)
45{
46 for (const auto &p : b.get_args()) {
47 postorder_traversal_stop(*p, v);
48 if (v.stop_)
49 return;
50 }
51 b.accept(v);
52}
53
54bool has_symbol(const Basic &b, const Basic &x)
55{
56 // We are breaking a rule when using ptrFromRef() here, but since
57 // HasSymbolVisitor is only instantiated and freed from here, the `x` can
58 // never go out of scope, so this is safe.
59 HasSymbolVisitor v(ptrFromRef(x));
60 return v.apply(b);
61}
62
63RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n)
64{
65 if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
66 throw NotImplementedError("Not implemented for non (Function)Symbols.");
67 }
68 CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
69 return v.apply(b);
70}
71
72class FreeSymbolsVisitor : public BaseVisitor<FreeSymbolsVisitor>
73{
74public:
75 set_basic s;
76 uset_basic v;
77
78 void bvisit(const Symbol &x)
79 {
80 s.insert(x.rcp_from_this());
81 }
82
83 void bvisit(const Subs &x)
84 {
85 set_basic set_ = free_symbols(*x.get_arg());
86 for (const auto &p : x.get_variables()) {
87 set_.erase(p);
88 }
89 s.insert(set_.begin(), set_.end());
90 for (const auto &p : x.get_point()) {
91 auto iter = v.insert(p->rcp_from_this());
92 if (iter.second) {
93 p->accept(*this);
94 }
95 }
96 }
97
98 void bvisit(const Basic &x)
99 {
100 for (const auto &p : x.get_args()) {
101 auto iter = v.insert(p->rcp_from_this());
102 if (iter.second) {
103 p->accept(*this);
104 }
105 }
106 }
107
108 set_basic apply(const Basic &b)
109 {
110 b.accept(*this);
111 return s;
112 }
113
114 set_basic apply(const MatrixBase &m)
115 {
116 for (unsigned i = 0; i < m.nrows(); i++) {
117 for (unsigned j = 0; j < m.ncols(); j++) {
118 m.get(i, j)->accept(*this);
119 }
120 }
121 return s;
122 }
123};
124
125set_basic free_symbols(const MatrixBase &m)
126{
127 FreeSymbolsVisitor visitor;
128 return visitor.apply(m);
129}
130
131set_basic free_symbols(const Basic &b)
132{
133 FreeSymbolsVisitor visitor;
134 return visitor.apply(b);
135}
136
137set_basic function_symbols(const Basic &b)
138{
139 return atoms<FunctionSymbol>(b);
140}
141
142RCP<const Basic> TransformVisitor::apply(const RCP<const Basic> &x)
143{
144 x->accept(*this);
145 return result_;
146}
147
148void TransformVisitor::bvisit(const Basic &x)
149{
150 result_ = x.rcp_from_this();
151}
152
153void TransformVisitor::bvisit(const Add &x)
154{
155 vec_basic newargs;
156 for (const auto &a : x.get_args()) {
157 newargs.push_back(apply(a));
158 }
159 result_ = add(newargs);
160}
161
162void TransformVisitor::bvisit(const Mul &x)
163{
164 vec_basic newargs;
165 for (const auto &a : x.get_args()) {
166 newargs.push_back(apply(a));
167 }
168 result_ = mul(newargs);
169}
170
171void TransformVisitor::bvisit(const Pow &x)
172{
173 auto base_ = x.get_base(), exp_ = x.get_exp();
174 auto newarg1 = apply(base_), newarg2 = apply(exp_);
175 if (base_ != newarg1 or exp_ != newarg2) {
176 result_ = pow(newarg1, newarg2);
177 } else {
178 result_ = x.rcp_from_this();
179 }
180}
181
182void TransformVisitor::bvisit(const OneArgFunction &x)
183{
184 auto farg = x.get_arg();
185 auto newarg = apply(farg);
186 if (eq(*newarg, *farg)) {
187 result_ = x.rcp_from_this();
188 } else {
189 result_ = x.create(newarg);
190 }
191}
192
193void TransformVisitor::bvisit(const MultiArgFunction &x)
194{
195 auto fargs = x.get_args();
196 vec_basic newargs;
197 for (const auto &a : fargs) {
198 newargs.push_back(apply(a));
199 }
200 auto nbarg = x.create(newargs);
201 result_ = nbarg;
202}
203
204void TransformVisitor::bvisit(const Piecewise &x)
205{
206 auto branch_cond_pairs = x.get_vec();
207 PiecewiseVec new_pairs;
208 for (const auto &branch_cond : branch_cond_pairs) {
209 auto branch = branch_cond.first;
210 auto cond = branch_cond.second;
211 auto new_branch = apply(branch);
212 auto new_cond = apply(cond);
213 if (!is_a_Boolean(*new_cond)) {
214 new_cond = Eq(new_cond, boolTrue);
215 }
216 new_pairs.push_back(
217 {new_branch, rcp_static_cast<const Boolean>(new_cond)});
218 }
219 result_ = piecewise(new_pairs);
220}
221
222void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
223{
224 b.accept(v);
225 if (v.stop_ or v.local_stop_)
226 return;
227 for (const auto &p : b.get_args()) {
228 preorder_traversal_local_stop(*p, v);
229 if (v.stop_)
230 return;
231 }
232}
233
234void CountOpsVisitor::apply(const Basic &b)
235{
236 unsigned count_now = count;
237 auto it = v.find(b.rcp_from_this());
238 if (it == v.end()) {
239 b.accept(*this);
240 insert(v, b.rcp_from_this(), count - count_now);
241 } else {
242 count += it->second;
243 }
244}
245
246void CountOpsVisitor::bvisit(const Mul &x)
247{
248 if (neq(*(x.get_coef()), *one)) {
249 count++;
250 apply(*x.get_coef());
251 }
252
253 for (const auto &p : x.get_dict()) {
254 if (neq(*p.second, *one)) {
255 count++;
256 apply(*p.second);
257 }
258 apply(*p.first);
259 count++;
260 }
261 count--;
262}
263
264void CountOpsVisitor::bvisit(const Add &x)
265{
266 if (neq(*(x.get_coef()), *zero)) {
267 count++;
268 apply(*x.get_coef());
269 }
270
271 for (const auto &p : x.get_dict()) {
272 if (neq(*p.second, *one)) {
273 count++;
274 apply(*p.second);
275 }
276 apply(*p.first);
277 count++;
278 }
279 count--;
280}
281
282void CountOpsVisitor::bvisit(const Pow &x)
283{
284 count++;
285 apply(*x.get_exp());
286 apply(*x.get_base());
287}
288
289void CountOpsVisitor::bvisit(const Number &x) {}
290
291void CountOpsVisitor::bvisit(const ComplexBase &x)
292{
293 if (neq(*x.real_part(), *zero)) {
294 count++;
295 }
296
297 if (neq(*x.imaginary_part(), *one)) {
298 count++;
299 }
300}
301
302void CountOpsVisitor::bvisit(const Symbol &x) {}
303
304void CountOpsVisitor::bvisit(const Constant &x) {}
305
306void CountOpsVisitor::bvisit(const Basic &x)
307{
308 count++;
309 for (const auto &p : x.get_args()) {
310 apply(*p);
311 }
312}
313
314unsigned count_ops(const vec_basic &a)
315{
316 CountOpsVisitor v;
317 for (auto &p : a) {
318 v.apply(*p);
319 }
320 return v.count;
321}
322
323} // namespace SymEngine
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
T end(T... args)
T erase(T... args)
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:653
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
T pow(T... args)
T push_back(T... args)