1 #include "symengine/symengine_exception.h"
3 #include <symengine/polys/basic_conversions.h>
6 #define ACCEPT(CLASS) \
7 void CLASS::accept(Visitor &v) const \
15 #define SYMENGINE_ENUM(TypeID, Class) ACCEPT(Class)
16 #include "symengine/type_codes.inc"
19 void preorder_traversal(
const Basic &b, Visitor &v)
22 for (
const auto &p : b.get_args())
23 preorder_traversal(*p, v);
26 void postorder_traversal(
const Basic &b, Visitor &v)
28 for (
const auto &p : b.get_args())
29 postorder_traversal(*p, v);
33 void preorder_traversal_stop(
const Basic &b, StopVisitor &v)
38 for (
const auto &p : b.get_args()) {
39 preorder_traversal_stop(*p, v);
45 void postorder_traversal_stop(
const Basic &b, StopVisitor &v)
47 for (
const auto &p : b.get_args()) {
48 postorder_traversal_stop(*p, v);
55 bool has_basic(
const Basic &b,
const Basic &x)
60 HasBasicVisitor v(ptrFromRef(x));
64 bool has_symbol(
const Basic &b,
const Basic &x)
69 HasSymbolVisitor v(ptrFromRef(x));
73 RCP<const Basic> coeff(
const Basic &b,
const Basic &x,
const Basic &n)
75 if (!(is_a<Symbol>(x) || is_a<FunctionSymbol>(x))) {
76 throw NotImplementedError(
"Not implemented for non (Function)Symbols.");
78 CoeffVisitor v(ptrFromRef(x), ptrFromRef(n));
88 void bvisit(
const Symbol &x)
93 void bvisit(
const Subs &x)
95 set_basic set_ = free_symbols(*x.get_arg());
96 for (
const auto &p : x.get_variables()) {
99 s.insert(set_.begin(), set_.end());
100 for (
const auto &p : x.get_point()) {
101 auto iter = v.insert(p->rcp_from_this());
108 void bvisit(
const Basic &x)
110 for (
const auto &p : x.
get_args()) {
111 auto iter = v.insert(p->rcp_from_this());
118 set_basic apply(
const Basic &b)
126 for (
unsigned i = 0; i < m.nrows(); i++) {
127 for (
unsigned j = 0; j < m.ncols(); j++) {
128 m.get(i, j)->accept(*
this);
138 return visitor.apply(m);
141 set_basic free_symbols(
const Basic &b)
143 FreeSymbolsVisitor visitor;
144 return visitor.apply(b);
147 set_basic function_symbols(
const Basic &b)
149 return atoms<FunctionSymbol>(b);
152 HasBasicVisitor::HasBasicVisitor(Ptr<const Basic> looking_for)
153 : looking_for_(looking_for)
155 if (is_a<Add>(*looking_for) || is_a<Mul>(*looking_for)
156 || is_a<And>(*looking_for) || is_a<Or>(*looking_for)
157 || is_a<Xor>(*looking_for)) {
162 throw NotImplementedError(
163 "Associative classes not yet handled in HasBasicVisitor");
167 RCP<const Basic> TransformVisitor::apply(
const RCP<const Basic> &x)
173 void TransformVisitor::bvisit(
const Basic &x)
175 result_ = x.rcp_from_this();
178 void TransformVisitor::bvisit(
const Add &x)
181 for (
const auto &a : x.get_args()) {
182 newargs.push_back(apply(a));
184 result_ =
add(newargs);
187 void TransformVisitor::bvisit(
const Mul &x)
190 for (
const auto &a : x.get_args()) {
191 newargs.push_back(apply(a));
193 result_ =
mul(newargs);
196 void TransformVisitor::bvisit(
const Pow &x)
198 auto base_ = x.get_base(), exp_ = x.get_exp();
199 auto newarg1 = apply(base_), newarg2 = apply(exp_);
200 if (base_ != newarg1 or exp_ != newarg2) {
201 result_ = pow(newarg1, newarg2);
203 result_ = x.rcp_from_this();
207 void TransformVisitor::bvisit(
const OneArgFunction &x)
209 auto farg = x.get_arg();
210 auto newarg = apply(farg);
211 if (
eq(*newarg, *farg)) {
212 result_ = x.rcp_from_this();
214 result_ = x.create(newarg);
218 void TransformVisitor::bvisit(
const MultiArgFunction &x)
220 auto fargs = x.get_args();
222 for (
const auto &a : fargs) {
223 newargs.push_back(apply(a));
225 auto nbarg = x.create(newargs);
229 void TransformVisitor::bvisit(
const Piecewise &x)
231 auto branch_cond_pairs = x.get_vec();
232 PiecewiseVec new_pairs;
233 for (
const auto &branch_cond : branch_cond_pairs) {
234 auto branch = branch_cond.first;
235 auto cond = branch_cond.second;
236 auto new_branch = apply(branch);
237 auto new_cond = apply(cond);
238 if (!is_a_Boolean(*new_cond)) {
239 new_cond =
Eq(new_cond, boolTrue);
242 {new_branch, rcp_static_cast<const Boolean>(new_cond)});
244 result_ = piecewise(new_pairs);
247 void preorder_traversal_local_stop(
const Basic &b, LocalStopVisitor &v)
250 if (v.stop_ or v.local_stop_)
252 for (
const auto &p : b.get_args()) {
253 preorder_traversal_local_stop(*p, v);
259 void CountOpsVisitor::apply(
const Basic &b)
261 unsigned count_now = count;
262 auto it = v.find(b.rcp_from_this());
265 insert(v, b.rcp_from_this(), count - count_now);
271 void CountOpsVisitor::bvisit(
const Mul &x)
273 if (
neq(*(x.get_coef()), *one)) {
275 apply(*x.get_coef());
278 for (
const auto &p : x.get_dict()) {
279 if (
neq(*p.second, *one)) {
289 void CountOpsVisitor::bvisit(
const Add &x)
291 if (
neq(*(x.get_coef()), *zero)) {
293 apply(*x.get_coef());
296 for (
const auto &p : x.get_dict()) {
297 if (
neq(*p.second, *one)) {
307 void CountOpsVisitor::bvisit(
const Pow &x)
311 apply(*x.get_base());
314 void CountOpsVisitor::bvisit(
const Number &x) {}
316 void CountOpsVisitor::bvisit(
const ComplexBase &x)
318 if (
neq(*x.real_part(), *zero)) {
322 if (
neq(*x.imaginary_part(), *one)) {
327 void CountOpsVisitor::bvisit(
const Symbol &x) {}
329 void CountOpsVisitor::bvisit(
const Constant &x) {}
331 void CountOpsVisitor::bvisit(
const Basic &x)
334 for (
const auto &p : x.get_args()) {
339 unsigned count_ops(
const vec_basic &a)
The lowest unit of symbolic representation.
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
void insert(T1 &m, const T2 &first, const T3 &second)
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b