visitor.cpp
1 #include "symengine/symengine_exception.h"
2 #include <symengine/visitor.h>
3 #include <symengine/polys/basic_conversions.h>
4 #include <symengine/sets.h>
5 
// Emits the out-of-line definition of KLASS::accept(): double dispatch that
// hands the concrete object back to the visitor's matching visit() overload.
#define ACCEPT(KLASS)                                                          \
    void KLASS::accept(Visitor &visitor) const                                 \
    {                                                                          \
        visitor.visit(*this);                                                  \
    }
11 
12 namespace SymEngine
13 {
14 
15 #define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
16 #include "symengine/type_codes.inc"
17 #undef SYMENGINE_ENUM
18 
19 void preorder_traversal(const Basic &b, Visitor &v)
20 {
21  b.accept(v);
22  for (const auto &p : b.get_args())
23  preorder_traversal(*p, v);
24 }
25 
26 void postorder_traversal(const Basic &b, Visitor &v)
27 {
28  for (const auto &p : b.get_args())
29  postorder_traversal(*p, v);
30  b.accept(v);
31 }
32 
33 void preorder_traversal_stop(const Basic &b, StopVisitor &v)
34 {
35  b.accept(v);
36  if (v.stop_)
37  return;
38  for (const auto &p : b.get_args()) {
39  preorder_traversal_stop(*p, v);
40  if (v.stop_)
41  return;
42  }
43 }
44 
45 void postorder_traversal_stop(const Basic &b, StopVisitor &v)
46 {
47  for (const auto &p : b.get_args()) {
48  postorder_traversal_stop(*p, v);
49  if (v.stop_)
50  return;
51  }
52  b.accept(v);
53 }
54 
55 bool has_basic(const Basic &b, const Basic &x)
56 {
57  // We are breaking a rule when using ptrFromRef() here, but since
58  // HasBasicVisitor is only instantiated and freed from here, the `x` can
59  // never go out of scope, so this is safe.
60  HasBasicVisitor v(ptrFromRef(x));
61  return v.apply(b);
62 }
63 
64 bool has_symbol(const Basic &b, const Basic &x)
65 {
66  // We are breaking a rule when using ptrFromRef() here, but since
67  // HasSymbolVisitor is only instantiated and freed from here, the `x` can
68  // never go out of scope, so this is safe.
69  HasSymbolVisitor v(ptrFromRef(x));
70  return v.apply(b);
71 }
72 
73 RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n)
74 {
75  if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
76  throw NotImplementedError("Not implemented for non (Function)Symbols.");
77  }
78  CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
79  return v.apply(b);
80 }
81 
82 class FreeSymbolsVisitor : public BaseVisitor<FreeSymbolsVisitor>
83 {
84 public:
85  set_basic s;
86  uset_basic v;
87 
88  void bvisit(const Symbol &x)
89  {
90  s.insert(x.rcp_from_this());
91  }
92 
93  void bvisit(const Subs &x)
94  {
95  set_basic set_ = free_symbols(*x.get_arg());
96  for (const auto &p : x.get_variables()) {
97  set_.erase(p);
98  }
99  s.insert(set_.begin(), set_.end());
100  for (const auto &p : x.get_point()) {
101  auto iter = v.insert(p->rcp_from_this());
102  if (iter.second) {
103  p->accept(*this);
104  }
105  }
106  }
107 
108  void bvisit(const Basic &x)
109  {
110  for (const auto &p : x.get_args()) {
111  auto iter = v.insert(p->rcp_from_this());
112  if (iter.second) {
113  p->accept(*this);
114  }
115  }
116  }
117 
118  set_basic apply(const Basic &b)
119  {
120  b.accept(*this);
121  return s;
122  }
123 
124  set_basic apply(const MatrixBase &m)
125  {
126  for (unsigned i = 0; i < m.nrows(); i++) {
127  for (unsigned j = 0; j < m.ncols(); j++) {
128  m.get(i, j)->accept(*this);
129  }
130  }
131  return s;
132  }
133 };
134 
135 set_basic free_symbols(const MatrixBase &m)
136 {
137  FreeSymbolsVisitor visitor;
138  return visitor.apply(m);
139 }
140 
141 set_basic free_symbols(const Basic &b)
142 {
143  FreeSymbolsVisitor visitor;
144  return visitor.apply(b);
145 }
146 
147 set_basic function_symbols(const Basic &b)
148 {
149  return atoms<FunctionSymbol>(b);
150 }
151 
152 HasBasicVisitor::HasBasicVisitor(Ptr<const Basic> looking_for)
153  : looking_for_(looking_for)
154 {
155  if (is_a<Add>(*looking_for) || is_a<Mul>(*looking_for)
156  || is_a<And>(*looking_for) || is_a<Or>(*looking_for)
157  || is_a<Xor>(*looking_for)) {
158  // To avoid confusion with how subtree matching would behave in the
159  // current state of this visitor, associative operators are for now
160  // disallowed. If there is a need for this, a more advanced (and more
161  // expensive) visitor could be created.
162  throw NotImplementedError(
163  "Associative classes not yet handled in HasBasicVisitor");
164  }
165 }
166 
167 RCP<const Basic> TransformVisitor::apply(const RCP<const Basic> &x)
168 {
169  x->accept(*this);
170  return result_;
171 }
172 
173 void TransformVisitor::bvisit(const Basic &x)
174 {
175  result_ = x.rcp_from_this();
176 }
177 
178 void TransformVisitor::bvisit(const Add &x)
179 {
180  vec_basic newargs;
181  for (const auto &a : x.get_args()) {
182  newargs.push_back(apply(a));
183  }
184  result_ = add(newargs);
185 }
186 
187 void TransformVisitor::bvisit(const Mul &x)
188 {
189  vec_basic newargs;
190  for (const auto &a : x.get_args()) {
191  newargs.push_back(apply(a));
192  }
193  result_ = mul(newargs);
194 }
195 
196 void TransformVisitor::bvisit(const Pow &x)
197 {
198  auto base_ = x.get_base(), exp_ = x.get_exp();
199  auto newarg1 = apply(base_), newarg2 = apply(exp_);
200  if (base_ != newarg1 or exp_ != newarg2) {
201  result_ = pow(newarg1, newarg2);
202  } else {
203  result_ = x.rcp_from_this();
204  }
205 }
206 
207 void TransformVisitor::bvisit(const OneArgFunction &x)
208 {
209  auto farg = x.get_arg();
210  auto newarg = apply(farg);
211  if (eq(*newarg, *farg)) {
212  result_ = x.rcp_from_this();
213  } else {
214  result_ = x.create(newarg);
215  }
216 }
217 
218 void TransformVisitor::bvisit(const MultiArgFunction &x)
219 {
220  auto fargs = x.get_args();
221  vec_basic newargs;
222  for (const auto &a : fargs) {
223  newargs.push_back(apply(a));
224  }
225  auto nbarg = x.create(newargs);
226  result_ = nbarg;
227 }
228 
229 void TransformVisitor::bvisit(const Piecewise &x)
230 {
231  auto branch_cond_pairs = x.get_vec();
232  PiecewiseVec new_pairs;
233  for (const auto &branch_cond : branch_cond_pairs) {
234  auto branch = branch_cond.first;
235  auto cond = branch_cond.second;
236  auto new_branch = apply(branch);
237  auto new_cond = apply(cond);
238  if (!is_a_Boolean(*new_cond)) {
239  new_cond = Eq(new_cond, boolTrue);
240  }
241  new_pairs.push_back(
242  {new_branch, rcp_static_cast<const Boolean>(new_cond)});
243  }
244  result_ = piecewise(new_pairs);
245 }
246 
247 void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
248 {
249  b.accept(v);
250  if (v.stop_ or v.local_stop_)
251  return;
252  for (const auto &p : b.get_args()) {
253  preorder_traversal_local_stop(*p, v);
254  if (v.stop_)
255  return;
256  }
257 }
258 
259 void CountOpsVisitor::apply(const Basic &b)
260 {
261  unsigned count_now = count;
262  auto it = v.find(b.rcp_from_this());
263  if (it == v.end()) {
264  b.accept(*this);
265  insert(v, b.rcp_from_this(), count - count_now);
266  } else {
267  count += it->second;
268  }
269 }
270 
271 void CountOpsVisitor::bvisit(const Mul &x)
272 {
273  if (neq(*(x.get_coef()), *one)) {
274  count++;
275  apply(*x.get_coef());
276  }
277 
278  for (const auto &p : x.get_dict()) {
279  if (neq(*p.second, *one)) {
280  count++;
281  apply(*p.second);
282  }
283  apply(*p.first);
284  count++;
285  }
286  count--;
287 }
288 
289 void CountOpsVisitor::bvisit(const Add &x)
290 {
291  if (neq(*(x.get_coef()), *zero)) {
292  count++;
293  apply(*x.get_coef());
294  }
295 
296  for (const auto &p : x.get_dict()) {
297  if (neq(*p.second, *one)) {
298  count++;
299  apply(*p.second);
300  }
301  apply(*p.first);
302  count++;
303  }
304  count--;
305 }
306 
307 void CountOpsVisitor::bvisit(const Pow &x)
308 {
309  count++;
310  apply(*x.get_exp());
311  apply(*x.get_base());
312 }
313 
314 void CountOpsVisitor::bvisit(const Number &x) {}
315 
316 void CountOpsVisitor::bvisit(const ComplexBase &x)
317 {
318  if (neq(*x.real_part(), *zero)) {
319  count++;
320  }
321 
322  if (neq(*x.imaginary_part(), *one)) {
323  count++;
324  }
325 }
326 
327 void CountOpsVisitor::bvisit(const Symbol &x) {}
328 
329 void CountOpsVisitor::bvisit(const Constant &x) {}
330 
331 void CountOpsVisitor::bvisit(const Basic &x)
332 {
333  count++;
334  for (const auto &p : x.get_args()) {
335  apply(*p);
336  }
337 }
338 
339 unsigned count_ops(const vec_basic &a)
340 {
341  CountOpsVisitor v;
342  for (auto &p : a) {
343  v.apply(*p);
344  }
345  return v.count;
346 }
347 
348 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29