visitor.h
Go to the documentation of this file.
1 
6 #ifndef SYMENGINE_VISITOR_H
7 #define SYMENGINE_VISITOR_H
8 
11 #include <symengine/polys/uexprpoly.h>
12 #include <symengine/polys/msymenginepoly.h>
14 #include <symengine/complex_mpc.h>
16 #include <symengine/series_piranha.h>
17 #include <symengine/series_flint.h>
19 #include <symengine/series_piranha.h>
20 #include <symengine/sets.h>
21 #include <symengine/fields.h>
22 #include <symengine/logic.h>
23 #include <symengine/infinity.h>
24 #include <symengine/nan.h>
25 #include <symengine/matrix.h>
26 #include <symengine/ntheory_funcs.h>
27 #include <symengine/symengine_casts.h>
28 #include <symengine/tuple.h>
29 #include <symengine/matrix_expressions.h>
30 
31 namespace SymEngine
32 {
33 
34 class Visitor
35 {
36 public:
37  virtual ~Visitor(){};
38 #define SYMENGINE_ENUM(TypeID, Class) virtual void visit(const Class &) = 0;
39 #include "symengine/type_codes.inc"
40 #undef SYMENGINE_ENUM
41 };
42 
43 void preorder_traversal(const Basic &b, Visitor &v);
44 void postorder_traversal(const Basic &b, Visitor &v);
45 
46 template <class Derived, class Base = Visitor>
47 class BaseVisitor : public Base
48 {
49 
50 public:
51 #if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8 \
52  || defined(__SUNPRO_CC))
53  // Following two ctors can be replaced by `using Base::Base` if inheriting
54  // constructors are allowed by the compiler. GCC 4.8 is the earliest
55  // version supporting this.
56  template <typename... Args,
57  typename
58  = enable_if_t<std::is_constructible<Base, Args...>::value>>
59  BaseVisitor(Args &&...args) : Base(std::forward<Args>(args)...)
60  {
61  }
62 
63  BaseVisitor() : Base() {}
64 #else
65  using Base::Base;
66 #endif
67 
68 #define SYMENGINE_ENUM(TypeID, Class) \
69  virtual void visit(const Class &x) \
70  { \
71  down_cast<Derived *>(this)->bvisit(x); \
72  };
73 #include "symengine/type_codes.inc"
74 #undef SYMENGINE_ENUM
75 };
76 
77 class StopVisitor : public Visitor
78 {
79 public:
80  bool stop_;
81 };
82 
84 {
85 public:
86  bool local_stop_;
87 };
88 
89 void preorder_traversal_stop(const Basic &b, StopVisitor &v);
90 void postorder_traversal_stop(const Basic &b, StopVisitor &v);
91 void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v);
92 
94 {
95 protected:
96  Ptr<const Basic> x_;
97  bool has_;
98 
99 public:
100  HasSymbolVisitor(Ptr<const Basic> x) : x_(x) {}
101 
102  void bvisit(const Symbol &x)
103  {
104  if (eq(*x_, x)) {
105  has_ = true;
106  stop_ = true;
107  }
108  }
109 
110  void bvisit(const FunctionSymbol &x)
111  {
112  if (eq(*x_, x)) {
113  has_ = true;
114  stop_ = true;
115  }
116  }
117 
118  void bvisit(const Basic &x){};
119 
120  bool apply(const Basic &b)
121  {
122  has_ = false;
123  stop_ = false;
124  preorder_traversal_stop(b, *this);
125  return has_;
126  }
127 };
128 
129 bool has_symbol(const Basic &b, const Basic &x);
130 
132 {
133 protected:
134  Ptr<const Basic> x_;
135  Ptr<const Basic> n_;
136  RCP<const Basic> coeff_;
137 
138 public:
139  CoeffVisitor(Ptr<const Basic> x, Ptr<const Basic> n) : x_(x), n_(n) {}
140 
141  void bvisit(const Add &x)
142  {
143  umap_basic_num dict;
144  RCP<const Number> coef = zero;
145  for (auto &p : x.get_dict()) {
146  p.first->accept(*this);
147  if (neq(*coeff_, *zero)) {
148  Add::coef_dict_add_term(outArg(coef), dict, p.second, coeff_);
149  }
150  }
151  if (eq(*zero, *n_)) {
152  iaddnum(outArg(coef), x.get_coef());
153  }
154  coeff_ = Add::from_dict(coef, std::move(dict));
155  }
156 
157  void bvisit(const Mul &x)
158  {
159  for (auto &p : x.get_dict()) {
160  if (eq(*p.first, *x_) and eq(*p.second, *n_)) {
161  map_basic_basic dict = x.get_dict();
162  dict.erase(p.first);
163  coeff_ = Mul::from_dict(x.get_coef(), std::move(dict));
164  return;
165  }
166  }
167  if (eq(*zero, *n_) and not has_symbol(x, *x_)) {
168  coeff_ = x.rcp_from_this();
169  } else {
170  coeff_ = zero;
171  }
172  }
173 
174  void bvisit(const Pow &x)
175  {
176  if (eq(*x.get_base(), *x_) and eq(*x.get_exp(), *n_)) {
177  coeff_ = one;
178  } else if (neq(*x.get_base(), *x_) and eq(*zero, *n_)) {
179  coeff_ = x.rcp_from_this();
180  } else {
181  coeff_ = zero;
182  }
183  }
184 
185  void bvisit(const Symbol &x)
186  {
187  if (eq(x, *x_) and eq(*one, *n_)) {
188  coeff_ = one;
189  } else if (neq(x, *x_) and eq(*zero, *n_)) {
190  coeff_ = x.rcp_from_this();
191  } else {
192  coeff_ = zero;
193  }
194  }
195 
196  void bvisit(const FunctionSymbol &x)
197  {
198  if (eq(x, *x_) and eq(*one, *n_)) {
199  coeff_ = one;
200  } else if (neq(x, *x_) and eq(*zero, *n_)) {
201  coeff_ = x.rcp_from_this();
202  } else {
203  coeff_ = zero;
204  }
205  }
206 
207  void bvisit(const Basic &x)
208  {
209  if (neq(*zero, *n_)) {
210  coeff_ = zero;
211  return;
212  }
213  if (has_symbol(x, *x_)) {
214  coeff_ = zero;
215  } else {
216  coeff_ = x.rcp_from_this();
217  }
218  }
219 
220  RCP<const Basic> apply(const Basic &b)
221  {
222  coeff_ = zero;
223  b.accept(*this);
224  return coeff_;
225  }
226 };
227 
228 RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n);
229 
230 set_basic free_symbols(const Basic &b);
231 
232 set_basic free_symbols(const MatrixBase &m);
233 
234 set_basic function_symbols(const Basic &b);
235 
237 {
238 protected:
239  RCP<const Basic> result_;
240 
241 public:
242  TransformVisitor() {}
243 
244  virtual RCP<const Basic> apply(const RCP<const Basic> &x);
245 
246  void bvisit(const Basic &x);
247  void bvisit(const Add &x);
248  void bvisit(const Mul &x);
249  void bvisit(const Pow &x);
250  void bvisit(const OneArgFunction &x);
251 
252  template <class T>
253  void bvisit(const TwoArgBasic<T> &x)
254  {
255  auto farg1 = x.get_arg1(), farg2 = x.get_arg2();
256  auto newarg1 = apply(farg1), newarg2 = apply(farg2);
257  if (farg1 != newarg1 or farg2 != newarg2) {
258  result_ = x.create(newarg1, newarg2);
259  } else {
260  result_ = x.rcp_from_this();
261  }
262  }
263 
264  void bvisit(const MultiArgFunction &x);
265  void bvisit(const Piecewise &x);
266 };
267 
/// `is_base_of_multiple<Derived, B1, B2, ...>::value` is true when at least
/// one of `B1, B2, ...` is a base of `Derived` or is `Derived` itself, with
/// the same meaning as std::is_base_of. Modelled as a std::integral_constant,
/// so `::value`, `::type` and the bool conversion are all available.
template <typename Derived, typename First, typename... Rest>
struct is_base_of_multiple
    : std::integral_constant<bool,
                             std::is_base_of<First, Derived>::value
                                 or is_base_of_multiple<Derived,
                                                        Rest...>::value> {
};

/// Single-candidate case; terminates the recursion above.
template <typename Derived, typename First>
struct is_base_of_multiple<Derived, First>
    : std::integral_constant<bool, std::is_base_of<First, Derived>::value> {
};
278 
279 template <typename... Args>
280 class AtomsVisitor : public BaseVisitor<AtomsVisitor<Args...>>
281 {
282 public:
283  set_basic s;
284  uset_basic visited;
285 
286  template <typename T,
287  typename = enable_if_t<is_base_of_multiple<T, Args...>::value>>
288  void bvisit(const T &x)
289  {
290  s.insert(x.rcp_from_this());
291  visited.insert(x.rcp_from_this());
292  bvisit((const Basic &)x);
293  }
294 
295  void bvisit(const Basic &x)
296  {
297  for (const auto &p : x.get_args()) {
298  auto iter = visited.insert(p->rcp_from_this());
299  if (iter.second) {
300  p->accept(*this);
301  }
302  }
303  }
304 
305  set_basic apply(const Basic &b)
306  {
307  b.accept(*this);
308  return s;
309  }
310 };
311 
312 template <typename... Args>
313 inline set_basic atoms(const Basic &b)
314 {
315  AtomsVisitor<Args...> visitor;
316  return visitor.apply(b);
317 };
318 
319 class CountOpsVisitor : public BaseVisitor<CountOpsVisitor>
320 {
321 protected:
323  v;
324 
325 public:
326  unsigned count = 0;
327  void apply(const Basic &b);
328  void bvisit(const Mul &x);
329  void bvisit(const Add &x);
330  void bvisit(const Pow &x);
331  void bvisit(const Number &x);
332  void bvisit(const ComplexBase &x);
333  void bvisit(const Symbol &x);
334  void bvisit(const Constant &x);
335  void bvisit(const Basic &x);
336 };
337 
338 unsigned count_ops(const vec_basic &a);
339 
340 } // namespace SymEngine
341 
342 #endif
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Create an appropriate instance from dictionary quickly.
Definition: add.cpp:140
static void coef_dict_add_term(const Ptr< RCP< const Number >> &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
const RCP< const Number > & get_coef() const
Definition: add.h:142
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
ComplexBase: the base class from which all complex-number classes derive.
Definition: complex.h:16
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg1() const
Definition: functions.h:91
RCP< const Basic > get_arg2() const
Definition: functions.h:96
T erase(T... args)
T insert(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks whether a and b are equal.
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks whether a and b are not equal.
Definition: basic-inl.h:29
Our comparison (==)
Definition: basic.h:219