Program Listing for File visitor.h
Return to documentation for file (symengine/symengine/visitor.h)
#ifndef SYMENGINE_VISITOR_H
#define SYMENGINE_VISITOR_H
#include <symengine/polys/uintpoly_flint.h>
#include <symengine/polys/uintpoly_piranha.h>
#include <symengine/polys/uexprpoly.h>
#include <symengine/polys/msymenginepoly.h>
#include <symengine/polys/uratpoly.h>
#include <symengine/complex_mpc.h>
#include <symengine/series_generic.h>
#include <symengine/series_piranha.h>
#include <symengine/series_flint.h>
#include <symengine/series_generic.h>
#include <symengine/series_piranha.h>
#include <symengine/sets.h>
#include <symengine/fields.h>
#include <symengine/logic.h>
#include <symengine/infinity.h>
#include <symengine/nan.h>
#include <symengine/matrix.h>
#include <symengine/symengine_casts.h>
namespace SymEngine
{
// Abstract visitor interface. The class body is generated from
// symengine/type_codes.inc: every SYMENGINE_ENUM entry in that list becomes
// one pure virtual `visit(const Class &)` overload.
class Visitor
{
public:
    // Virtual so a concrete visitor can be destroyed through a Visitor
    // pointer or reference.
    virtual ~Visitor() {}
// The first macro argument is not used when generating the declarations.
#define SYMENGINE_ENUM(type_id, Class) virtual void visit(const Class &) = 0;
#include "symengine/type_codes.inc"
#undef SYMENGINE_ENUM
};
// Traversal entry points taking the expression `b` and the visitor `v`.
// Only declared in this header. Judging by the names, one visits a node
// before its arguments and the other after them -- TODO confirm against the
// definitions.
void preorder_traversal(const Basic &b, Visitor &v);
void postorder_traversal(const Basic &b, Visitor &v);
// CRTP adapter: provides a `visit` overload for every class listed in
// symengine/type_codes.inc and forwards each one to the `bvisit` overload
// that `Derived` selects for that argument type.
//
// Derived: the concrete visitor supplying the `bvisit` overloads.
// Base:    the class to inherit from (Visitor unless specified).
template <class Derived, class Base = Visitor>
class BaseVisitor : public Base
{
public:
#if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8)                 \
    || defined(__SUNPRO_CC)
    // Compilers without inheriting constructors (GCC gained them in 4.8)
    // cannot use `using Base::Base`, so emulate it: a perfect-forwarding
    // constructor for every argument list Base accepts, plus an explicit
    // default constructor.
    template <typename... CtorArgs,
              typename = enable_if_t<
                  std::is_constructible<Base, CtorArgs...>::value>>
    BaseVisitor(CtorArgs &&... ctor_args)
        : Base(std::forward<CtorArgs>(ctor_args)...)
    {
    }
    BaseVisitor() : Base()
    {
    }
#else
    using Base::Base;
#endif
// One generated `visit` per listed class: cast `this` down to Derived and
// let overload resolution pick the matching `bvisit`.
#define SYMENGINE_ENUM(type_id, Class)                                         \
    virtual void visit(const Class &node)                                      \
    {                                                                          \
        down_cast<Derived *>(this)->bvisit(node);                              \
    };
#include "symengine/type_codes.inc"
#undef SYMENGINE_ENUM
};
// Visitor base that carries a `stop_` flag. HasSymbolVisitor sets the flag
// once its search succeeds; the *_traversal_stop functions declared in this
// header presumably poll it to end the walk early -- confirm in their
// definitions.
class StopVisitor : public Visitor
{
public:
    // Starts out false so that a traversal never reads an indeterminate
    // value when a subclass forgets to assign the flag before use.
    bool stop_ = false;
};
// StopVisitor with an additional `local_stop_` flag, accepted by
// preorder_traversal_local_stop. NOTE(review): presumably the extra flag
// prunes only the current subtree rather than the whole walk -- confirm
// against the traversal's definition.
class LocalStopVisitor : public StopVisitor
{
public:
    // Starts out false so that a traversal never reads an indeterminate
    // value when a subclass forgets to assign the flag before use.
    bool local_stop_ = false;
};
// Traversal variants that take a StopVisitor / LocalStopVisitor instead of a
// plain Visitor. Only declared in this header; presumably they consult the
// visitor's stop_ (and local_stop_) flag to cut the walk short -- TODO
// confirm against the definitions.
void preorder_traversal_stop(const Basic &b, StopVisitor &v);
void postorder_traversal_stop(const Basic &b, StopVisitor &v);
void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v);
// Visitor answering "does the expression contain x_?". Only Symbol and
// FunctionSymbol nodes are compared against x_; every other node type is
// ignored by the catch-all overload. On the first match it records the hit
// and raises stop_.
class HasSymbolVisitor : public BaseVisitor<HasSymbolVisitor, StopVisitor>
{
protected:
    Ptr<const Basic> x_; // object being searched for
    bool has_;           // true once a node equal to *x_ has been visited

public:
    // x: the symbol (or function symbol) to search for.
    // Both flags are put into a defined state here so the visitor is safe
    // to inspect or hand to a traversal even before apply() has run.
    HasSymbolVisitor(Ptr<const Basic> x) : x_(x), has_(false)
    {
        stop_ = false;
    }
    void bvisit(const Symbol &x)
    {
        record_if_match(x);
    }
    void bvisit(const FunctionSymbol &x)
    {
        record_if_match(x);
    }
    // Any other node type can never equal a symbol: nothing to do.
    void bvisit(const Basic &)
    {
    }
    // Resets both flags, runs preorder_traversal_stop over `b`, and returns
    // whether a node equal to *x_ was visited.
    bool apply(const Basic &b)
    {
        has_ = false;
        stop_ = false;
        preorder_traversal_stop(b, *this);
        return has_;
    }

private:
    // Shared body of the Symbol / FunctionSymbol overloads.
    void record_if_match(const Basic &candidate)
    {
        if (eq(*x_, candidate)) {
            has_ = true;
            stop_ = true;
        }
    }
};
// Declared here, defined elsewhere. CoeffVisitor calls it to test whether
// `x` occurs in an expression; presumably implemented with HasSymbolVisitor
// -- confirm in the implementation file.
bool has_symbol(const Basic &b, const Basic &x);
// Visitor that extracts the coefficient multiplying `x_ ** n_` in an
// expression. Each bvisit overload leaves its answer in coeff_; apply()
// returns it.
class CoeffVisitor : public BaseVisitor<CoeffVisitor, StopVisitor>
{
protected:
    Ptr<const Basic> x_;     // the term whose coefficient is requested
    Ptr<const Basic> n_;     // the power of x_ that must match
    RCP<const Basic> coeff_; // result of the most recent bvisit call

public:
    CoeffVisitor(Ptr<const Basic> x, Ptr<const Basic> n) : x_(x), n_(n)
    {
    }
    // Sum: visit every term, keep the non-zero per-term results scaled by
    // the term's own numeric factor, and rebuild an Add from them.
    void bvisit(const Add &x)
    {
        umap_basic_num terms;
        RCP<const Number> constant = zero;
        for (const auto &term : x.get_dict()) {
            term.first->accept(*this);
            if (neq(*coeff_, *zero)) {
                Add::coef_dict_add_term(outArg(constant), terms, term.second,
                                        coeff_);
            }
        }
        // The numeric part of the sum only contributes when n_ is zero.
        if (eq(*zero, *n_)) {
            iaddnum(outArg(constant), x.get_coef());
        }
        coeff_ = Add::from_dict(constant, std::move(terms));
    }
    // Product: if one factor is exactly x_ ** n_, the coefficient is the
    // product with that factor removed.
    void bvisit(const Mul &x)
    {
        for (const auto &factor : x.get_dict()) {
            const bool is_wanted_power
                = eq(*factor.first, *x_) and eq(*factor.second, *n_);
            if (not is_wanted_power) {
                continue;
            }
            map_basic_basic remaining = x.get_dict();
            remaining.erase(factor.first);
            coeff_ = Mul::from_dict(x.get_coef(), std::move(remaining));
            return;
        }
        // No matching factor: the whole product counts only for n_ == 0 and
        // only if it does not mention x_ at all.
        const bool is_free_term = eq(*zero, *n_) and not has_symbol(x, *x_);
        if (is_free_term) {
            coeff_ = x.rcp_from_this();
        } else {
            coeff_ = zero;
        }
    }
    // Power. NOTE(review): unlike the Mul and Basic overloads this does not
    // call has_symbol, so for n_ == 0 a power whose base merely differs from
    // x_ is returned unchanged even if x_ occurs inside it -- confirm that
    // this is intended.
    void bvisit(const Pow &x)
    {
        const bool base_is_target = eq(*x.get_base(), *x_);
        if (base_is_target and eq(*x.get_exp(), *n_)) {
            coeff_ = one;
        } else if (not base_is_target and eq(*zero, *n_)) {
            coeff_ = x.rcp_from_this();
        } else {
            coeff_ = zero;
        }
    }
    void bvisit(const Symbol &x)
    {
        visit_atom(x);
    }
    void bvisit(const FunctionSymbol &x)
    {
        visit_atom(x);
    }
    // Everything else: counts only as a term free of x_, i.e. for n_ == 0.
    void bvisit(const Basic &x)
    {
        if (neq(*zero, *n_) or has_symbol(x, *x_)) {
            coeff_ = zero;
        } else {
            coeff_ = x.rcp_from_this();
        }
    }
    // Resets the result to zero, dispatches on `b`, and returns whatever the
    // selected overload stored.
    RCP<const Basic> apply(const Basic &b)
    {
        coeff_ = zero;
        b.accept(*this);
        return coeff_;
    }

private:
    // Rule shared by Symbol and FunctionSymbol: the node itself is x_ ** 1;
    // a node different from x_ is returned as-is when n_ is zero; anything
    // else yields zero.
    void visit_atom(const Basic &atom)
    {
        const bool is_target = eq(atom, *x_);
        if (is_target and eq(*one, *n_)) {
            coeff_ = one;
        } else if (not is_target and eq(*zero, *n_)) {
            coeff_ = atom.rcp_from_this();
        } else {
            coeff_ = zero;
        }
    }
};
// Free functions declared here and defined elsewhere.
// coeff: takes the same three roles as CoeffVisitor (expression, x, n);
// presumably a thin wrapper around it -- confirm in the implementation.
RCP<const Basic> coeff(const Basic &b, const Basic &x, const Basic &n);
// Each of the following returns a set_basic. By their names they collect the
// free symbols of an expression / matrix and the function symbols of an
// expression -- TODO confirm exact semantics against the definitions.
set_basic free_symbols(const Basic &b);
set_basic free_symbols(const MatrixBase &m);
set_basic function_symbols(const Basic &b);
// Visitor intended as a base for expression rewrites: bvisit overloads store
// their output in result_. Only the TwoArgBasic overload is defined in this
// header; the remaining members are declared here and defined elsewhere.
class TransformVisitor : public BaseVisitor<TransformVisitor>
{
protected:
    RCP<const Basic> result_; // written by the bvisit overloads

public:
    TransformVisitor()
    {
    }
    // Virtual entry point; the TwoArgBasic overload below calls it on each
    // argument and uses the returned expression.
    virtual RCP<const Basic> apply(const RCP<const Basic> &x);
    void bvisit(const Basic &x);
    void bvisit(const Add &x);
    void bvisit(const Mul &x);
    void bvisit(const Pow &x);
    void bvisit(const OneArgFunction &x);
    // Two-argument nodes: transform both arguments, then rebuild the node
    // through x.create() only if at least one argument came back as a
    // different object; otherwise reuse the node itself.
    template <class T>
    void bvisit(const TwoArgBasic<T> &x)
    {
        auto old_first = x.get_arg1();
        auto old_second = x.get_arg2();
        auto new_first = apply(old_first);
        auto new_second = apply(old_second);
        if (old_first == new_first and old_second == new_second) {
            result_ = x.rcp_from_this();
            return;
        }
        result_ = x.create(new_first, new_second);
    }
    void bvisit(const MultiArgFunction &x);
};
// Compile-time predicate: `value` is true when at least one of the listed
// candidate types (Head, Tail...) is a base of -- or the same class as --
// Derived, as determined by std::is_base_of.
template <typename Derived, typename Head, typename... Tail>
struct is_base_of_multiple {
    static const bool value
        = std::is_base_of<Head, Derived>::value
              ? true
              : is_base_of_multiple<Derived, Tail...>::value;
};
// Recursion end: a single candidate left to test.
template <typename Derived, typename Only>
struct is_base_of_multiple<Derived, Only> {
    static const bool value = std::is_base_of<Only, Derived>::value;
};
// Visitor that gathers into `s` every node whose type derives from one of
// Args..., walking the arguments of each node and using `visited` to avoid
// descending into the same argument twice.
template <typename... Args>
class AtomsVisitor : public BaseVisitor<AtomsVisitor<Args...>>
{
public:
    set_basic s;        // nodes collected so far
    uset_basic visited; // nodes already recorded or descended into

    // Chosen only for node types derived from one of Args...: record the
    // node, then continue into its arguments via the generic overload.
    template <typename T,
              typename = enable_if_t<is_base_of_multiple<T, Args...>::value>>
    void bvisit(const T &x)
    {
        const auto self = x.rcp_from_this();
        s.insert(self);
        visited.insert(self);
        bvisit(static_cast<const Basic &>(x));
    }
    // Generic case: descend into each argument not seen before.
    void bvisit(const Basic &x)
    {
        for (const auto &child : x.get_args()) {
            const bool first_visit
                = visited.insert(child->rcp_from_this()).second;
            if (first_visit) {
                child->accept(*this);
            }
        }
    }
    // Dispatches on `b` and returns a copy of the collected set.
    set_basic apply(const Basic &b)
    {
        b.accept(*this);
        return s;
    }
};
// Convenience wrapper: runs a fresh AtomsVisitor<Args...> over `b` and
// returns the set it collected.
template <typename... Args>
inline set_basic atoms(const Basic &b)
{
    AtomsVisitor<Args...> collector;
    return collector.apply(b);
}
// Visitor exposing a public `count` accumulator and one bvisit overload per
// handled node category. All member functions are only declared here; their
// definitions live elsewhere.
class CountOpsVisitor : public BaseVisitor<CountOpsVisitor>
{
protected:
// Map from expression to an unsigned value, keyed with the RCP-aware hash
// and equality functors. NOTE(review): presumably a per-node cache of
// already computed counts -- confirm against the implementation.
std::unordered_map<RCP<const Basic>, unsigned, RCPBasicHash, RCPBasicKeyEq>
v;
public:
// Running total, starting at zero. apply() returns void, so callers
// presumably read the result from here -- confirm.
unsigned count = 0;
void apply(const Basic &b);
void bvisit(const Mul &x);
void bvisit(const Add &x);
void bvisit(const Pow &x);
void bvisit(const Number &x);
void bvisit(const ComplexBase &x);
void bvisit(const Symbol &x);
void bvisit(const Constant &x);
void bvisit(const Basic &x);
};
// Declared here, defined elsewhere. Takes a vector of expressions and returns
// an unsigned total; presumably driven by CountOpsVisitor -- confirm in the
// implementation file.
unsigned count_ops(const vec_basic &a);
} // SymEngine
#endif