visitor.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/polys/basic_conversions.h>
3 #include <symengine/sets.h>
4 
// ACCEPT(TYPE) defines TYPE::accept(): it passes *this, statically typed as
// const TYPE &, to visitor.visit(), so overload resolution selects the
// visit() overload matching TYPE (the first half of double dispatch).
#define ACCEPT(TYPE)                                                           \
    void TYPE::accept(Visitor &visitor) const                                  \
    {                                                                          \
        visitor.visit(*this);                                                  \
    }
10 
11 namespace SymEngine
12 {
13 
14 #define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
15 #include "symengine/type_codes.inc"
16 #undef SYMENGINE_ENUM
17 
18 void preorder_traversal(const Basic &b, Visitor &v)
19 {
20  b.accept(v);
21  for (const auto &p : b.get_args())
22  preorder_traversal(*p, v);
23 }
24 
25 void postorder_traversal(const Basic &b, Visitor &v)
26 {
27  for (const auto &p : b.get_args())
28  postorder_traversal(*p, v);
29  b.accept(v);
30 }
31 
32 void preorder_traversal_stop(const Basic &b, StopVisitor &v)
33 {
34  b.accept(v);
35  if (v.stop_)
36  return;
37  for (const auto &p : b.get_args()) {
38  preorder_traversal_stop(*p, v);
39  if (v.stop_)
40  return;
41  }
42 }
43 
44 void postorder_traversal_stop(const Basic &b, StopVisitor &v)
45 {
46  for (const auto &p : b.get_args()) {
47  postorder_traversal_stop(*p, v);
48  if (v.stop_)
49  return;
50  }
51  b.accept(v);
52 }
53 
54 bool has_symbol(const Basic &b, const Basic &x)
55 {
56  // We are breaking a rule when using ptrFromRef() here, but since
57  // HasSymbolVisitor is only instantiated and freed from here, the `x` can
58  // never go out of scope, so this is safe.
59  HasSymbolVisitor v(ptrFromRef(x));
60  return v.apply(b);
61 }
62 
63 RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n)
64 {
65  if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
66  throw NotImplementedError("Not implemented for non (Function)Symbols.");
67  }
68  CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
69  return v.apply(b);
70 }
71 
72 class FreeSymbolsVisitor : public BaseVisitor<FreeSymbolsVisitor>
73 {
74 public:
75  set_basic s;
76  uset_basic v;
77 
78  void bvisit(const Symbol &x)
79  {
80  s.insert(x.rcp_from_this());
81  }
82 
83  void bvisit(const Subs &x)
84  {
85  set_basic set_ = free_symbols(*x.get_arg());
86  for (const auto &p : x.get_variables()) {
87  set_.erase(p);
88  }
89  s.insert(set_.begin(), set_.end());
90  for (const auto &p : x.get_point()) {
91  auto iter = v.insert(p->rcp_from_this());
92  if (iter.second) {
93  p->accept(*this);
94  }
95  }
96  }
97 
98  void bvisit(const Basic &x)
99  {
100  for (const auto &p : x.get_args()) {
101  auto iter = v.insert(p->rcp_from_this());
102  if (iter.second) {
103  p->accept(*this);
104  }
105  }
106  }
107 
108  set_basic apply(const Basic &b)
109  {
110  b.accept(*this);
111  return s;
112  }
113 
114  set_basic apply(const MatrixBase &m)
115  {
116  for (unsigned i = 0; i < m.nrows(); i++) {
117  for (unsigned j = 0; j < m.ncols(); j++) {
118  m.get(i, j)->accept(*this);
119  }
120  }
121  return s;
122  }
123 };
124 
125 set_basic free_symbols(const MatrixBase &m)
126 {
127  FreeSymbolsVisitor visitor;
128  return visitor.apply(m);
129 }
130 
131 set_basic free_symbols(const Basic &b)
132 {
133  FreeSymbolsVisitor visitor;
134  return visitor.apply(b);
135 }
136 
137 set_basic function_symbols(const Basic &b)
138 {
139  return atoms<FunctionSymbol>(b);
140 }
141 
142 RCP<const Basic> TransformVisitor::apply(const RCP<const Basic> &x)
143 {
144  x->accept(*this);
145  return result_;
146 }
147 
148 void TransformVisitor::bvisit(const Basic &x)
149 {
150  result_ = x.rcp_from_this();
151 }
152 
153 void TransformVisitor::bvisit(const Add &x)
154 {
155  vec_basic newargs;
156  for (const auto &a : x.get_args()) {
157  newargs.push_back(apply(a));
158  }
159  result_ = add(newargs);
160 }
161 
162 void TransformVisitor::bvisit(const Mul &x)
163 {
164  vec_basic newargs;
165  for (const auto &a : x.get_args()) {
166  newargs.push_back(apply(a));
167  }
168  result_ = mul(newargs);
169 }
170 
171 void TransformVisitor::bvisit(const Pow &x)
172 {
173  auto base_ = x.get_base(), exp_ = x.get_exp();
174  auto newarg1 = apply(base_), newarg2 = apply(exp_);
175  if (base_ != newarg1 or exp_ != newarg2) {
176  result_ = pow(newarg1, newarg2);
177  } else {
178  result_ = x.rcp_from_this();
179  }
180 }
181 
182 void TransformVisitor::bvisit(const OneArgFunction &x)
183 {
184  auto farg = x.get_arg();
185  auto newarg = apply(farg);
186  if (eq(*newarg, *farg)) {
187  result_ = x.rcp_from_this();
188  } else {
189  result_ = x.create(newarg);
190  }
191 }
192 
193 void TransformVisitor::bvisit(const MultiArgFunction &x)
194 {
195  auto fargs = x.get_args();
196  vec_basic newargs;
197  for (const auto &a : fargs) {
198  newargs.push_back(apply(a));
199  }
200  auto nbarg = x.create(newargs);
201  result_ = nbarg;
202 }
203 
204 void TransformVisitor::bvisit(const Piecewise &x)
205 {
206  auto branch_cond_pairs = x.get_vec();
207  PiecewiseVec new_pairs;
208  for (const auto &branch_cond : branch_cond_pairs) {
209  auto branch = branch_cond.first;
210  auto cond = branch_cond.second;
211  auto new_branch = apply(branch);
212  auto new_cond = apply(cond);
213  if (!is_a_Boolean(*new_cond)) {
214  new_cond = Eq(new_cond, boolTrue);
215  }
216  new_pairs.push_back(
217  {new_branch, rcp_static_cast<const Boolean>(new_cond)});
218  }
219  result_ = piecewise(new_pairs);
220 }
221 
222 void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
223 {
224  b.accept(v);
225  if (v.stop_ or v.local_stop_)
226  return;
227  for (const auto &p : b.get_args()) {
228  preorder_traversal_local_stop(*p, v);
229  if (v.stop_)
230  return;
231  }
232 }
233 
234 void CountOpsVisitor::apply(const Basic &b)
235 {
236  unsigned count_now = count;
237  auto it = v.find(b.rcp_from_this());
238  if (it == v.end()) {
239  b.accept(*this);
240  insert(v, b.rcp_from_this(), count - count_now);
241  } else {
242  count += it->second;
243  }
244 }
245 
246 void CountOpsVisitor::bvisit(const Mul &x)
247 {
248  if (neq(*(x.get_coef()), *one)) {
249  count++;
250  apply(*x.get_coef());
251  }
252 
253  for (const auto &p : x.get_dict()) {
254  if (neq(*p.second, *one)) {
255  count++;
256  apply(*p.second);
257  }
258  apply(*p.first);
259  count++;
260  }
261  count--;
262 }
263 
264 void CountOpsVisitor::bvisit(const Add &x)
265 {
266  if (neq(*(x.get_coef()), *zero)) {
267  count++;
268  apply(*x.get_coef());
269  }
270 
271  for (const auto &p : x.get_dict()) {
272  if (neq(*p.second, *one)) {
273  count++;
274  apply(*p.second);
275  }
276  apply(*p.first);
277  count++;
278  }
279  count--;
280 }
281 
282 void CountOpsVisitor::bvisit(const Pow &x)
283 {
284  count++;
285  apply(*x.get_exp());
286  apply(*x.get_base());
287 }
288 
289 void CountOpsVisitor::bvisit(const Number &x) {}
290 
291 void CountOpsVisitor::bvisit(const ComplexBase &x)
292 {
293  if (neq(*x.real_part(), *zero)) {
294  count++;
295  }
296 
297  if (neq(*x.imaginary_part(), *one)) {
298  count++;
299  }
300 }
301 
302 void CountOpsVisitor::bvisit(const Symbol &x) {}
303 
304 void CountOpsVisitor::bvisit(const Constant &x) {}
305 
306 void CountOpsVisitor::bvisit(const Basic &x)
307 {
308  count++;
309  for (const auto &p : x.get_args()) {
310  apply(*p);
311  }
312 }
313 
314 unsigned count_ops(const vec_basic &a)
315 {
316  CountOpsVisitor v;
317  for (auto &p : a) {
318  v.apply(*p);
319  }
320  return v.count;
321 }
322 
323 } // namespace SymEngine
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
T end(T... args)
T erase(T... args)
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:653
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
T pow(T... args)
T push_back(T... args)