visitor.h
Go to the documentation of this file.
1 
6 #ifndef SYMENGINE_VISITOR_H
7 #define SYMENGINE_VISITOR_H
8 
11 #include <symengine/polys/uexprpoly.h>
12 #include <symengine/polys/msymenginepoly.h>
14 #include <symengine/complex_mpc.h>
16 #include <symengine/series_piranha.h>
17 #include <symengine/series_flint.h>
19 #include <symengine/series_piranha.h>
20 #include <symengine/sets.h>
21 #include <symengine/fields.h>
22 #include <symengine/logic.h>
23 #include <symengine/infinity.h>
24 #include <symengine/nan.h>
25 #include <symengine/matrix.h>
26 #include <symengine/ntheory_funcs.h>
27 #include <symengine/symengine_casts.h>
28 
29 namespace SymEngine
30 {
31 
32 class Visitor
33 {
34 public:
35  virtual ~Visitor(){};
36 #define SYMENGINE_ENUM(TypeID, Class) virtual void visit(const Class &) = 0;
37 #include "symengine/type_codes.inc"
38 #undef SYMENGINE_ENUM
39 };
40 
41 void preorder_traversal(const Basic &b, Visitor &v);
42 void postorder_traversal(const Basic &b, Visitor &v);
43 
44 template <class Derived, class Base = Visitor>
45 class BaseVisitor : public Base
46 {
47 
48 public:
49 #if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8 \
50  || defined(__SUNPRO_CC))
51  // Following two ctors can be replaced by `using Base::Base` if inheriting
52  // constructors are allowed by the compiler. GCC 4.8 is the earliest
53  // version supporting this.
54  template <typename... Args,
55  typename
56  = enable_if_t<std::is_constructible<Base, Args...>::value>>
57  BaseVisitor(Args &&...args) : Base(std::forward<Args>(args)...)
58  {
59  }
60 
61  BaseVisitor() : Base() {}
62 #else
63  using Base::Base;
64 #endif
65 
66 #define SYMENGINE_ENUM(TypeID, Class) \
67  virtual void visit(const Class &x) \
68  { \
69  down_cast<Derived *>(this)->bvisit(x); \
70  };
71 #include "symengine/type_codes.inc"
72 #undef SYMENGINE_ENUM
73 };
74 
75 class StopVisitor : public Visitor
76 {
77 public:
78  bool stop_;
79 };
80 
82 {
83 public:
84  bool local_stop_;
85 };
86 
87 void preorder_traversal_stop(const Basic &b, StopVisitor &v);
88 void postorder_traversal_stop(const Basic &b, StopVisitor &v);
89 void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v);
90 
92 {
93 protected:
94  Ptr<const Basic> x_;
95  bool has_;
96 
97 public:
98  HasSymbolVisitor(Ptr<const Basic> x) : x_(x) {}
99 
100  void bvisit(const Symbol &x)
101  {
102  if (eq(*x_, x)) {
103  has_ = true;
104  stop_ = true;
105  }
106  }
107 
108  void bvisit(const FunctionSymbol &x)
109  {
110  if (eq(*x_, x)) {
111  has_ = true;
112  stop_ = true;
113  }
114  }
115 
116  void bvisit(const Basic &x){};
117 
118  bool apply(const Basic &b)
119  {
120  has_ = false;
121  stop_ = false;
122  preorder_traversal_stop(b, *this);
123  return has_;
124  }
125 };
126 
127 bool has_symbol(const Basic &b, const Basic &x);
128 
130 {
131 protected:
132  Ptr<const Basic> x_;
133  Ptr<const Basic> n_;
134  RCP<const Basic> coeff_;
135 
136 public:
137  CoeffVisitor(Ptr<const Basic> x, Ptr<const Basic> n) : x_(x), n_(n) {}
138 
139  void bvisit(const Add &x)
140  {
141  umap_basic_num dict;
142  RCP<const Number> coef = zero;
143  for (auto &p : x.get_dict()) {
144  p.first->accept(*this);
145  if (neq(*coeff_, *zero)) {
146  Add::coef_dict_add_term(outArg(coef), dict, p.second, coeff_);
147  }
148  }
149  if (eq(*zero, *n_)) {
150  iaddnum(outArg(coef), x.get_coef());
151  }
152  coeff_ = Add::from_dict(coef, std::move(dict));
153  }
154 
155  void bvisit(const Mul &x)
156  {
157  for (auto &p : x.get_dict()) {
158  if (eq(*p.first, *x_) and eq(*p.second, *n_)) {
159  map_basic_basic dict = x.get_dict();
160  dict.erase(p.first);
161  coeff_ = Mul::from_dict(x.get_coef(), std::move(dict));
162  return;
163  }
164  }
165  if (eq(*zero, *n_) and not has_symbol(x, *x_)) {
166  coeff_ = x.rcp_from_this();
167  } else {
168  coeff_ = zero;
169  }
170  }
171 
172  void bvisit(const Pow &x)
173  {
174  if (eq(*x.get_base(), *x_) and eq(*x.get_exp(), *n_)) {
175  coeff_ = one;
176  } else if (neq(*x.get_base(), *x_) and eq(*zero, *n_)) {
177  coeff_ = x.rcp_from_this();
178  } else {
179  coeff_ = zero;
180  }
181  }
182 
183  void bvisit(const Symbol &x)
184  {
185  if (eq(x, *x_) and eq(*one, *n_)) {
186  coeff_ = one;
187  } else if (neq(x, *x_) and eq(*zero, *n_)) {
188  coeff_ = x.rcp_from_this();
189  } else {
190  coeff_ = zero;
191  }
192  }
193 
194  void bvisit(const FunctionSymbol &x)
195  {
196  if (eq(x, *x_) and eq(*one, *n_)) {
197  coeff_ = one;
198  } else if (neq(x, *x_) and eq(*zero, *n_)) {
199  coeff_ = x.rcp_from_this();
200  } else {
201  coeff_ = zero;
202  }
203  }
204 
205  void bvisit(const Basic &x)
206  {
207  if (neq(*zero, *n_)) {
208  coeff_ = zero;
209  return;
210  }
211  if (has_symbol(x, *x_)) {
212  coeff_ = zero;
213  } else {
214  coeff_ = x.rcp_from_this();
215  }
216  }
217 
218  RCP<const Basic> apply(const Basic &b)
219  {
220  coeff_ = zero;
221  b.accept(*this);
222  return coeff_;
223  }
224 };
225 
226 RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n);
227 
228 set_basic free_symbols(const Basic &b);
229 
230 set_basic free_symbols(const MatrixBase &m);
231 
232 set_basic function_symbols(const Basic &b);
233 
235 {
236 protected:
237  RCP<const Basic> result_;
238 
239 public:
240  TransformVisitor() {}
241 
242  virtual RCP<const Basic> apply(const RCP<const Basic> &x);
243 
244  void bvisit(const Basic &x);
245  void bvisit(const Add &x);
246  void bvisit(const Mul &x);
247  void bvisit(const Pow &x);
248  void bvisit(const OneArgFunction &x);
249 
250  template <class T>
251  void bvisit(const TwoArgBasic<T> &x)
252  {
253  auto farg1 = x.get_arg1(), farg2 = x.get_arg2();
254  auto newarg1 = apply(farg1), newarg2 = apply(farg2);
255  if (farg1 != newarg1 or farg2 != newarg2) {
256  result_ = x.create(newarg1, newarg2);
257  } else {
258  result_ = x.rcp_from_this();
259  }
260  }
261 
262  void bvisit(const MultiArgFunction &x);
263 };
264 
//! Type trait: `value` is true when `Derived` derives from (or is) at least
//! one of the listed candidate types `First, Rest...`.
//!
//! Primary template: test `First`, then recurse over the remaining
//! candidates.
template <typename Derived, typename First, typename... Rest>
struct is_base_of_multiple {
    static const bool value = std::is_base_of<First, Derived>::value
                              or is_base_of_multiple<Derived, Rest...>::value;
};

//! Recursion end: a single candidate type is left.
template <typename Derived, typename First>
struct is_base_of_multiple<Derived, First> {
    static const bool value = std::is_base_of<First, Derived>::value;
};
275 
276 template <typename... Args>
277 class AtomsVisitor : public BaseVisitor<AtomsVisitor<Args...>>
278 {
279 public:
280  set_basic s;
281  uset_basic visited;
282 
283  template <typename T,
284  typename = enable_if_t<is_base_of_multiple<T, Args...>::value>>
285  void bvisit(const T &x)
286  {
287  s.insert(x.rcp_from_this());
288  visited.insert(x.rcp_from_this());
289  bvisit((const Basic &)x);
290  }
291 
292  void bvisit(const Basic &x)
293  {
294  for (const auto &p : x.get_args()) {
295  auto iter = visited.insert(p->rcp_from_this());
296  if (iter.second) {
297  p->accept(*this);
298  }
299  }
300  }
301 
302  set_basic apply(const Basic &b)
303  {
304  b.accept(*this);
305  return s;
306  }
307 };
308 
309 template <typename... Args>
310 inline set_basic atoms(const Basic &b)
311 {
312  AtomsVisitor<Args...> visitor;
313  return visitor.apply(b);
314 };
315 
316 class CountOpsVisitor : public BaseVisitor<CountOpsVisitor>
317 {
318 protected:
320  v;
321 
322 public:
323  unsigned count = 0;
324  void apply(const Basic &b);
325  void bvisit(const Mul &x);
326  void bvisit(const Add &x);
327  void bvisit(const Pow &x);
328  void bvisit(const Number &x);
329  void bvisit(const ComplexBase &x);
330  void bvisit(const Symbol &x);
331  void bvisit(const Constant &x);
332  void bvisit(const Basic &x);
333 };
334 
335 unsigned count_ops(const vec_basic &a);
336 
337 } // namespace SymEngine
338 
339 #endif
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Create an appropriate instance from dictionary quickly.
Definition: add.cpp:140
static void coef_dict_add_term(const Ptr< RCP< const Number >> &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
const RCP< const Number > & get_coef() const
Definition: add.h:142
The lowest unit of symbolic representation.
Definition: basic.h:95
virtual vec_basic get_args() const =0
Returns the list of arguments.
ComplexBase: the base class from which all complex number classes derive.
Definition: complex.h:16
RCP< T > rcp_from_this()
Returns an RCP<T> pointing to this object (the pointer is cast to T).
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:116
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg1() const
Definition: functions.h:91
RCP< const Basic > get_arg2() const
Definition: functions.h:96
T erase(T... args)
T insert(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
Our comparison (==)
Definition: basic.h:211