Loading...
Searching...
No Matches
visitor.h
Go to the documentation of this file.
1
6#ifndef SYMENGINE_VISITOR_H
7#define SYMENGINE_VISITOR_H
8
11#include <symengine/polys/uexprpoly.h>
12#include <symengine/polys/msymenginepoly.h>
14#include <symengine/complex_mpc.h>
16#include <symengine/series_piranha.h>
17#include <symengine/series_flint.h>
19#include <symengine/series_piranha.h>
20#include <symengine/sets.h>
21#include <symengine/fields.h>
22#include <symengine/logic.h>
23#include <symengine/infinity.h>
24#include <symengine/nan.h>
25#include <symengine/matrix.h>
26#include <symengine/ntheory_funcs.h>
27#include <symengine/symengine_casts.h>
28#include <symengine/tuple.h>
29#include <symengine/matrix_expressions.h>
30
31namespace SymEngine
32{
33
// Body of the abstract Visitor interface (its `class Visitor` header line is
// not present in this listing).
35{
36public:
// Virtual destructor so visitors can be destroyed through a Visitor pointer.
37 virtual ~Visitor(){};
// Expands to one pure-virtual visit(const Class &) for every entry listed in
// type_codes.inc, so a concrete visitor must implement all of them
// (BaseVisitor below generates forwarding implementations).
38#define SYMENGINE_ENUM(TypeID, Class) virtual void visit(const Class &) = 0;
39#include "symengine/type_codes.inc"
40#undef SYMENGINE_ENUM
41};
42
// Traversal helpers declared here and defined elsewhere; they drive visitor
// `v` over expression `b`. Going by the names, the preorder variant visits a
// node before its arguments and the postorder variant after them — confirm
// against the implementation.
43void preorder_traversal(const Basic &b, Visitor &v);
44void postorder_traversal(const Basic &b, Visitor &v);
45
46template <class Derived, class Base = Visitor>
47class BaseVisitor : public Base
48{
49
50public:
51#if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8 \
52 || defined(__SUNPRO_CC))
53 // Following two ctors can be replaced by `using Base::Base` if inheriting
54 // constructors are allowed by the compiler. GCC 4.8 is the earliest
55 // version supporting this.
56 template <typename... Args,
57 typename
58 = enable_if_t<std::is_constructible<Base, Args...>::value>>
59 BaseVisitor(Args &&...args) : Base(std::forward<Args>(args)...)
60 {
61 }
62
63 BaseVisitor() : Base() {}
64#else
65 using Base::Base;
66#endif
67
68#define SYMENGINE_ENUM(TypeID, Class) \
69 virtual void visit(const Class &x) \
70 { \
71 down_cast<Derived *>(this)->bvisit(x); \
72 };
73#include "symengine/type_codes.inc"
74#undef SYMENGINE_ENUM
75};
76
77class StopVisitor : public Visitor
78{
79public:
80 bool stop_;
81};
82
// Body of LocalStopVisitor (its class header line is not present in this
// listing); it is the visitor type taken by preorder_traversal_local_stop.
84{
85public:
// Per-node flag. Presumably lets a visit skip the current node's arguments
// without ending the whole traversal — NOTE(review): confirm against
// preorder_traversal_local_stop. Not initialized here.
86 bool local_stop_;
87};
88
// Stop-aware traversal variants (defined elsewhere). HasSymbolVisitor::apply
// clears stop_ before calling preorder_traversal_stop, and its bvisit overloads
// set stop_ on a match, so these presumably end the walk once stop_ is set —
// confirm. The local_stop variant presumably also honours local_stop_ to skip
// a single subtree — NOTE(review): verify.
89void preorder_traversal_stop(const Basic &b, StopVisitor &v);
90void postorder_traversal_stop(const Basic &b, StopVisitor &v);
91void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v);
92
94{
95protected:
96 Ptr<const Basic> x_;
97 bool has_;
98
99public:
100 HasSymbolVisitor(Ptr<const Basic> x) : x_(x) {}
101
102 void bvisit(const Symbol &x)
103 {
104 if (eq(*x_, x)) {
105 has_ = true;
106 stop_ = true;
107 }
108 }
109
110 void bvisit(const FunctionSymbol &x)
111 {
112 if (eq(*x_, x)) {
113 has_ = true;
114 stop_ = true;
115 }
116 }
117
118 void bvisit(const Basic &x){};
119
120 bool apply(const Basic &b)
121 {
122 has_ = false;
123 stop_ = false;
124 preorder_traversal_stop(b, *this);
125 return has_;
126 }
127};
128
129bool has_symbol(const Basic &b, const Basic &x);
130
132{
133protected:
134 Ptr<const Basic> x_;
135 Ptr<const Basic> n_;
136 RCP<const Basic> coeff_;
137
138public:
139 CoeffVisitor(Ptr<const Basic> x, Ptr<const Basic> n) : x_(x), n_(n) {}
140
141 void bvisit(const Add &x)
142 {
143 umap_basic_num dict;
144 RCP<const Number> coef = zero;
145 for (auto &p : x.get_dict()) {
146 p.first->accept(*this);
147 if (neq(*coeff_, *zero)) {
148 Add::coef_dict_add_term(outArg(coef), dict, p.second, coeff_);
149 }
150 }
151 if (eq(*zero, *n_)) {
152 iaddnum(outArg(coef), x.get_coef());
153 }
154 coeff_ = Add::from_dict(coef, std::move(dict));
155 }
156
157 void bvisit(const Mul &x)
158 {
159 for (auto &p : x.get_dict()) {
160 if (eq(*p.first, *x_) and eq(*p.second, *n_)) {
161 map_basic_basic dict = x.get_dict();
162 dict.erase(p.first);
163 coeff_ = Mul::from_dict(x.get_coef(), std::move(dict));
164 return;
165 }
166 }
167 if (eq(*zero, *n_) and not has_symbol(x, *x_)) {
168 coeff_ = x.rcp_from_this();
169 } else {
170 coeff_ = zero;
171 }
172 }
173
174 void bvisit(const Pow &x)
175 {
176 if (eq(*x.get_base(), *x_) and eq(*x.get_exp(), *n_)) {
177 coeff_ = one;
178 } else if (neq(*x.get_base(), *x_) and eq(*zero, *n_)) {
179 coeff_ = x.rcp_from_this();
180 } else {
181 coeff_ = zero;
182 }
183 }
184
185 void bvisit(const Symbol &x)
186 {
187 if (eq(x, *x_) and eq(*one, *n_)) {
188 coeff_ = one;
189 } else if (neq(x, *x_) and eq(*zero, *n_)) {
190 coeff_ = x.rcp_from_this();
191 } else {
192 coeff_ = zero;
193 }
194 }
195
196 void bvisit(const FunctionSymbol &x)
197 {
198 if (eq(x, *x_) and eq(*one, *n_)) {
199 coeff_ = one;
200 } else if (neq(x, *x_) and eq(*zero, *n_)) {
201 coeff_ = x.rcp_from_this();
202 } else {
203 coeff_ = zero;
204 }
205 }
206
207 void bvisit(const Basic &x)
208 {
209 if (neq(*zero, *n_)) {
210 coeff_ = zero;
211 return;
212 }
213 if (has_symbol(x, *x_)) {
214 coeff_ = zero;
215 } else {
216 coeff_ = x.rcp_from_this();
217 }
218 }
219
220 RCP<const Basic> apply(const Basic &b)
221 {
222 coeff_ = zero;
223 b.accept(*this);
224 return coeff_;
225 }
226};
227
// Coefficient of x**n in b (defined elsewhere; presumably a wrapper around
// CoeffVisitor above — confirm).
228RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n);
229
// Symbol sets extracted from an expression or a matrix (defined elsewhere;
// the exact meaning of "free" is not visible in this header).
230set_basic free_symbols(const Basic &b);
231
232set_basic free_symbols(const MatrixBase &m);
233
// Presumably the set of FunctionSymbol nodes occurring in b — confirm.
234set_basic function_symbols(const Basic &b);
235
237{
238protected:
239 RCP<const Basic> result_;
240
241public:
243
244 virtual RCP<const Basic> apply(const RCP<const Basic> &x);
245
246 void bvisit(const Basic &x);
247 void bvisit(const Add &x);
248 void bvisit(const Mul &x);
249 void bvisit(const Pow &x);
250 void bvisit(const OneArgFunction &x);
251
252 template <class T>
253 void bvisit(const TwoArgBasic<T> &x)
254 {
255 auto farg1 = x.get_arg1(), farg2 = x.get_arg2();
256 auto newarg1 = apply(farg1), newarg2 = apply(farg2);
257 if (farg1 != newarg1 or farg2 != newarg2) {
258 result_ = x.create(newarg1, newarg2);
259 } else {
260 result_ = x.rcp_from_this();
261 }
262 }
263
264 void bvisit(const MultiArgFunction &x);
265 void bvisit(const Piecewise &x);
266};
267
// Trait: `value` is true when Derived derives from (or is) at least one of
// First, Rest... . This is the recursive case (its `struct` line is not present
// in this listing); the single-candidate specialization follows.
268template <typename Derived, typename First, typename... Rest>
270 static const bool value = std::is_base_of<First, Derived>::value
271 or is_base_of_multiple<Derived, Rest...>::value;
272};
273
274template <typename Derived, typename First>
275struct is_base_of_multiple<Derived, First> {
276 static const bool value = std::is_base_of<First, Derived>::value;
277};
278
279template <typename... Args>
280class AtomsVisitor : public BaseVisitor<AtomsVisitor<Args...>>
281{
282public:
283 set_basic s;
284 uset_basic visited;
285
286 template <typename T,
287 typename = enable_if_t<is_base_of_multiple<T, Args...>::value>>
288 void bvisit(const T &x)
289 {
290 s.insert(x.rcp_from_this());
291 visited.insert(x.rcp_from_this());
292 bvisit((const Basic &)x);
293 }
294
295 void bvisit(const Basic &x)
296 {
297 for (const auto &p : x.get_args()) {
298 auto iter = visited.insert(p->rcp_from_this());
299 if (iter.second) {
300 p->accept(*this);
301 }
302 }
303 }
304
305 set_basic apply(const Basic &b)
306 {
307 b.accept(*this);
308 return s;
309 }
310};
311
312template <typename... Args>
313inline set_basic atoms(const Basic &b)
314{
315 AtomsVisitor<Args...> visitor;
316 return visitor.apply(b);
317};
318
// Visitor that presumably tallies operations into `count` (which starts at 0).
// apply() and the per-type bvisit overloads are defined elsewhere.
319class CountOpsVisitor : public BaseVisitor<CountOpsVisitor>
320{
321protected:
// NOTE(review): the line declaring the type of `v` is missing from this
// listing, so its type is not visible here.
323 v;
324
325public:
326 unsigned count = 0;
327 void apply(const Basic &b);
328 void bvisit(const Mul &x);
329 void bvisit(const Add &x);
330 void bvisit(const Pow &x);
331 void bvisit(const Number &x);
332 void bvisit(const ComplexBase &x);
333 void bvisit(const Symbol &x);
334 void bvisit(const Constant &x);
335 void bvisit(const Basic &x);
336};
337
// Presumably the total operation count over all expressions in `a` (defined
// elsewhere; CountOpsVisitor above keeps a running `count`) — confirm.
338unsigned count_ops(const vec_basic &a);
339
340} // namespace SymEngine
341
342#endif
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Quickly creates an appropriate instance from a dictionary.
Definition: add.cpp:140
const RCP< const Number > & get_coef() const
Definition: add.h:142
static void coef_dict_add_term(const Ptr< RCP< const Number > > &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
ComplexBase: base class from which all complex-number classes derive.
Definition: complex.h:16
RCP< T > rcp_from_this()
Returns an RCP<T> pointer to this object (the pointer is cast to T).
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
RCP< const Basic > get_base() const
Definition: pow.h:37
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_arg2() const
Definition: functions.h:96
RCP< const Basic > get_arg1() const
Definition: functions.h:91
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Constructs an instance of the class, applying canonicalization.
T erase(T... args)
T insert(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks whether a and b are equal.
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks whether a and b are not equal.
Definition: basic-inl.h:29
Our comparison (==)
Definition: basic.h:219