visitor.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/polys/basic_conversions.h>
3 #include <symengine/sets.h>
4 
// ACCEPT(CLASS) emits the definition of CLASS::accept(Visitor&), which hands
// *this (with its static type CLASS) to the matching Visitor::visit overload.
#define ACCEPT(CLASS) void CLASS::accept(Visitor &visitor) const { visitor.visit(*this); }
10 
11 namespace SymEngine
12 {
13 
14 #define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
15 #include "symengine/type_codes.inc"
16 #undef SYMENGINE_ENUM
17 
18 void preorder_traversal(const Basic &b, Visitor &v)
19 {
20  b.accept(v);
21  for (const auto &p : b.get_args())
22  preorder_traversal(*p, v);
23 }
24 
25 void postorder_traversal(const Basic &b, Visitor &v)
26 {
27  for (const auto &p : b.get_args())
28  postorder_traversal(*p, v);
29  b.accept(v);
30 }
31 
32 void preorder_traversal_stop(const Basic &b, StopVisitor &v)
33 {
34  b.accept(v);
35  if (v.stop_)
36  return;
37  for (const auto &p : b.get_args()) {
38  preorder_traversal_stop(*p, v);
39  if (v.stop_)
40  return;
41  }
42 }
43 
44 void postorder_traversal_stop(const Basic &b, StopVisitor &v)
45 {
46  for (const auto &p : b.get_args()) {
47  postorder_traversal_stop(*p, v);
48  if (v.stop_)
49  return;
50  }
51  b.accept(v);
52 }
53 
54 bool has_symbol(const Basic &b, const Basic &x)
55 {
56  // We are breaking a rule when using ptrFromRef() here, but since
57  // HasSymbolVisitor is only instantiated and freed from here, the `x` can
58  // never go out of scope, so this is safe.
59  HasSymbolVisitor v(ptrFromRef(x));
60  return v.apply(b);
61 }
62 
63 RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n)
64 {
65  if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
66  throw NotImplementedError("Not implemented for non (Function)Symbols.");
67  }
68  CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
69  return v.apply(b);
70 }
71 
72 class FreeSymbolsVisitor : public BaseVisitor<FreeSymbolsVisitor>
73 {
74 public:
75  set_basic s;
76  uset_basic v;
77 
78  void bvisit(const Symbol &x)
79  {
80  s.insert(x.rcp_from_this());
81  }
82 
83  void bvisit(const Subs &x)
84  {
85  set_basic set_ = free_symbols(*x.get_arg());
86  for (const auto &p : x.get_variables()) {
87  set_.erase(p);
88  }
89  s.insert(set_.begin(), set_.end());
90  for (const auto &p : x.get_point()) {
91  auto iter = v.insert(p->rcp_from_this());
92  if (iter.second) {
93  p->accept(*this);
94  }
95  }
96  }
97 
98  void bvisit(const Basic &x)
99  {
100  for (const auto &p : x.get_args()) {
101  auto iter = v.insert(p->rcp_from_this());
102  if (iter.second) {
103  p->accept(*this);
104  }
105  }
106  }
107 
108  set_basic apply(const Basic &b)
109  {
110  b.accept(*this);
111  return s;
112  }
113 
114  set_basic apply(const MatrixBase &m)
115  {
116  for (unsigned i = 0; i < m.nrows(); i++) {
117  for (unsigned j = 0; j < m.ncols(); j++) {
118  m.get(i, j)->accept(*this);
119  }
120  }
121  return s;
122  }
123 };
124 
125 set_basic free_symbols(const MatrixBase &m)
126 {
127  FreeSymbolsVisitor visitor;
128  return visitor.apply(m);
129 }
130 
131 set_basic free_symbols(const Basic &b)
132 {
133  FreeSymbolsVisitor visitor;
134  return visitor.apply(b);
135 }
136 
137 set_basic function_symbols(const Basic &b)
138 {
139  return atoms<FunctionSymbol>(b);
140 }
141 
142 RCP<const Basic> TransformVisitor::apply(const RCP<const Basic> &x)
143 {
144  x->accept(*this);
145  return result_;
146 }
147 
148 void TransformVisitor::bvisit(const Basic &x)
149 {
150  result_ = x.rcp_from_this();
151 }
152 
153 void TransformVisitor::bvisit(const Add &x)
154 {
155  vec_basic newargs;
156  for (const auto &a : x.get_args()) {
157  newargs.push_back(apply(a));
158  }
159  result_ = add(newargs);
160 }
161 
162 void TransformVisitor::bvisit(const Mul &x)
163 {
164  vec_basic newargs;
165  for (const auto &a : x.get_args()) {
166  newargs.push_back(apply(a));
167  }
168  result_ = mul(newargs);
169 }
170 
171 void TransformVisitor::bvisit(const Pow &x)
172 {
173  auto base_ = x.get_base(), exp_ = x.get_exp();
174  auto newarg1 = apply(base_), newarg2 = apply(exp_);
175  if (base_ != newarg1 or exp_ != newarg2) {
176  result_ = pow(newarg1, newarg2);
177  } else {
178  result_ = x.rcp_from_this();
179  }
180 }
181 
182 void TransformVisitor::bvisit(const OneArgFunction &x)
183 {
184  auto farg = x.get_arg();
185  auto newarg = apply(farg);
186  if (eq(*newarg, *farg)) {
187  result_ = x.rcp_from_this();
188  } else {
189  result_ = x.create(newarg);
190  }
191 }
192 
193 void TransformVisitor::bvisit(const MultiArgFunction &x)
194 {
195  auto fargs = x.get_args();
196  vec_basic newargs;
197  for (const auto &a : fargs) {
198  newargs.push_back(apply(a));
199  }
200  auto nbarg = x.create(newargs);
201  result_ = nbarg;
202 }
203 
204 void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
205 {
206  b.accept(v);
207  if (v.stop_ or v.local_stop_)
208  return;
209  for (const auto &p : b.get_args()) {
210  preorder_traversal_local_stop(*p, v);
211  if (v.stop_)
212  return;
213  }
214 }
215 
216 void CountOpsVisitor::apply(const Basic &b)
217 {
218  unsigned count_now = count;
219  auto it = v.find(b.rcp_from_this());
220  if (it == v.end()) {
221  b.accept(*this);
222  insert(v, b.rcp_from_this(), count - count_now);
223  } else {
224  count += it->second;
225  }
226 }
227 
228 void CountOpsVisitor::bvisit(const Mul &x)
229 {
230  if (neq(*(x.get_coef()), *one)) {
231  count++;
232  apply(*x.get_coef());
233  }
234 
235  for (const auto &p : x.get_dict()) {
236  if (neq(*p.second, *one)) {
237  count++;
238  apply(*p.second);
239  }
240  apply(*p.first);
241  count++;
242  }
243  count--;
244 }
245 
246 void CountOpsVisitor::bvisit(const Add &x)
247 {
248  if (neq(*(x.get_coef()), *zero)) {
249  count++;
250  apply(*x.get_coef());
251  }
252 
253  unsigned i = 0;
254  for (const auto &p : x.get_dict()) {
255  if (neq(*p.second, *one)) {
256  count++;
257  apply(*p.second);
258  }
259  apply(*p.first);
260  count++;
261  i++;
262  }
263  count--;
264 }
265 
266 void CountOpsVisitor::bvisit(const Pow &x)
267 {
268  count++;
269  apply(*x.get_exp());
270  apply(*x.get_base());
271 }
272 
273 void CountOpsVisitor::bvisit(const Number &x) {}
274 
275 void CountOpsVisitor::bvisit(const ComplexBase &x)
276 {
277  if (neq(*x.real_part(), *zero)) {
278  count++;
279  }
280 
281  if (neq(*x.imaginary_part(), *one)) {
282  count++;
283  }
284 }
285 
286 void CountOpsVisitor::bvisit(const Symbol &x) {}
287 
288 void CountOpsVisitor::bvisit(const Constant &x) {}
289 
290 void CountOpsVisitor::bvisit(const Basic &x)
291 {
292  count++;
293  for (const auto &p : x.get_args()) {
294  apply(*p);
295  }
296 }
297 
298 unsigned count_ops(const vec_basic &a)
299 {
300  CountOpsVisitor v;
301  for (auto &p : a) {
302  v.apply(*p);
303  }
304  return v.count;
305 }
306 
307 } // namespace SymEngine
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:95
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
T end(T... args)
T erase(T... args)
T insert(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:347
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
T pow(T... args)
T push_back(T... args)