Program Listing for File series_visitor.h¶
↰ Return to documentation for file (symengine/symengine/series_visitor.h
)
#ifndef SYMENGINE_SERIES_VISITOR_H
#define SYMENGINE_SERIES_VISITOR_H
#include <symengine/visitor.h>
namespace SymEngine
{
template <typename Poly, typename Coeff, typename Series>
class SeriesVisitor : public BaseVisitor<SeriesVisitor<Poly, Coeff, Series>>
{
private:
Poly p;
const Poly var;
const std::string varname;
const unsigned prec;
public:
inline SeriesVisitor(const Poly &var_, const std::string &varname_,
const unsigned prec_)
: var(var_), varname(varname_), prec(prec_)
{
}
RCP<const Series> series(const RCP<const Basic> &x)
{
return make_rcp<Series>(apply(x), varname, prec);
}
Poly apply(const RCP<const Basic> &x)
{
x->accept(*this);
Poly temp(std::move(p));
return temp;
}
void bvisit(const Add &x)
{
Poly temp(apply(x.get_coef()));
for (const auto &term : x.get_dict()) {
temp += apply(term.first) * apply(term.second);
}
p = temp;
}
void bvisit(const Mul &x)
{
Poly temp(apply(x.get_coef()));
for (const auto &term : x.get_dict()) {
temp = Series::mul(temp, apply(pow(term.first, term.second)), prec);
}
p = temp;
}
void bvisit(const Pow &x)
{
const RCP<const Basic> &base = x.get_base(), exp = x.get_exp();
if (is_a<Integer>(*exp)) {
const Integer &ii = (down_cast<const Integer &>(*exp));
if (not mp_fits_slong_p(ii.as_integer_class()))
throw SymEngineException("series power exponent size");
const int sh = numeric_cast<int>(mp_get_si(ii.as_integer_class()));
base->accept(*this);
if (sh == 1) {
return;
} else if (sh > 0) {
p = Series::pow(p, sh, prec);
} else if (sh == -1) {
p = Series::series_invert(p, var, prec);
} else {
// Invert and then exponentiate to give the correct behavior
// when expanding 1/x**(prec), which should return x**(-prec),
// not 0.
p = Series::pow(Series::series_invert(p, var, prec), -sh, prec);
}
} else if (is_a<Rational>(*exp)) {
const Rational &rat = (down_cast<const Rational &>(*exp));
const integer_class &expnumz = get_num(rat.as_rational_class());
const integer_class &expdenz = get_den(rat.as_rational_class());
if (not mp_fits_slong_p(expnumz) or not mp_fits_slong_p(expdenz))
throw SymEngineException("series rational power exponent size");
const int num = numeric_cast<int>(mp_get_si(expnumz));
const int den = numeric_cast<int>(mp_get_si(expdenz));
base->accept(*this);
const Poly proot(
Series::series_nthroot(apply(base), den, var, prec));
if (num == 1) {
p = proot;
} else if (num > 0) {
p = Series::pow(proot, num, prec);
} else if (num == -1) {
p = Series::series_invert(proot, var, prec);
} else {
p = Series::series_invert(Series::pow(proot, -num, prec), var,
prec);
}
} else if (eq(*E, *base)) {
p = Series::series_exp(apply(exp), var, prec);
} else {
p = Series::series_exp(
Poly(apply(exp) * Series::series_log(apply(base), var, prec)),
var, prec);
}
}
void bvisit(const Function &x)
{
RCP<const Basic> d = x.rcp_from_this();
RCP<const Symbol> s = symbol(varname);
map_basic_basic m({{s, zero}});
RCP<const Basic> const_term = d->subs(m);
if (const_term == d) {
p = Series::convert(*d);
return;
}
Poly res_p(apply(expand(const_term)));
Coeff prod, t;
prod = 1;
for (unsigned int i = 1; i < prec; i++) {
// Workaround for flint
t = i;
prod /= t;
d = d->diff(s);
res_p += Series::pow(var, i, prec)
* (prod * apply(expand(d->subs(m))));
}
p = res_p;
}
void bvisit(const Gamma &x)
{
RCP<const Symbol> s = symbol(varname);
RCP<const Basic> arg = x.get_args()[0];
if (eq(*arg->subs({{s, zero}}), *zero)) {
RCP<const Basic> g = gamma(add(arg, one));
if (is_a<Gamma>(*g)) {
bvisit(down_cast<const Function &>(*g));
p *= Series::pow(var, -1, prec);
} else {
g->accept(*this);
}
} else {
bvisit(implicit_cast<const Function &>(x));
}
}
void bvisit(const Series &x)
{
if (x.get_var() != varname) {
throw NotImplementedError("Multivariate Series not implemented");
}
if (x.get_degree() < prec) {
throw SymEngineException("Series with lesser prec found");
}
p = x.get_poly();
}
void bvisit(const Integer &x)
{
p = Series::convert(x);
}
void bvisit(const Rational &x)
{
p = Series::convert(x);
}
void bvisit(const Complex &x)
{
p = Series::convert(x);
}
void bvisit(const RealDouble &x)
{
p = Series::convert(x);
}
void bvisit(const ComplexDouble &x)
{
p = Series::convert(x);
}
#ifdef HAVE_SYMENGINE_MPFR
void bvisit(const RealMPFR &x)
{
p = Series::convert(x);
}
#endif
#ifdef HAVE_SYMENGINE_MPC
void bvisit(const ComplexMPC &x)
{
p = Series::convert(x);
}
#endif
void bvisit(const Sin &x)
{
x.get_arg()->accept(*this);
p = Series::series_sin(p, var, prec);
}
void bvisit(const Cos &x)
{
x.get_arg()->accept(*this);
p = Series::series_cos(p, var, prec);
}
void bvisit(const Tan &x)
{
x.get_arg()->accept(*this);
p = Series::series_tan(p, var, prec);
}
void bvisit(const Cot &x)
{
x.get_arg()->accept(*this);
p = Series::series_cot(p, var, prec);
}
void bvisit(const Csc &x)
{
x.get_arg()->accept(*this);
p = Series::series_csc(p, var, prec);
}
void bvisit(const Sec &x)
{
x.get_arg()->accept(*this);
p = Series::series_sec(p, var, prec);
}
void bvisit(const Log &x)
{
x.get_arg()->accept(*this);
p = Series::series_log(p, var, prec);
}
void bvisit(const ASin &x)
{
x.get_arg()->accept(*this);
p = Series::series_asin(p, var, prec);
}
void bvisit(const ACos &x)
{
x.get_arg()->accept(*this);
p = Series::series_acos(p, var, prec);
}
void bvisit(const ATan &x)
{
x.get_arg()->accept(*this);
p = Series::series_atan(p, var, prec);
}
void bvisit(const Sinh &x)
{
x.get_arg()->accept(*this);
p = Series::series_sinh(p, var, prec);
}
void bvisit(const Cosh &x)
{
x.get_arg()->accept(*this);
p = Series::series_cosh(p, var, prec);
}
void bvisit(const Tanh &x)
{
x.get_arg()->accept(*this);
p = Series::series_tanh(p, var, prec);
}
void bvisit(const ASinh &x)
{
x.get_arg()->accept(*this);
p = Series::series_asinh(p, var, prec);
}
void bvisit(const ATanh &x)
{
x.get_arg()->accept(*this);
p = Series::series_atanh(p, var, prec);
}
void bvisit(const LambertW &x)
{
x.get_arg()->accept(*this);
p = Series::series_lambertw(p, var, prec);
}
void bvisit(const Symbol &x)
{
if (x.get_name() == varname) {
p = Series::var(x.get_name());
} else {
p = Series::convert(x);
}
}
void bvisit(const Constant &x)
{
p = Series::convert(x);
}
void bvisit(const Basic &x)
{
if (!has_symbol(x, *symbol(varname))) {
p = Series::convert(x);
} else {
throw NotImplementedError("Not Implemented");
}
}
};
class NeedsSymbolicExpansionVisitor
: public BaseVisitor<NeedsSymbolicExpansionVisitor, StopVisitor>
{
protected:
RCP<const Symbol> x_;
bool needs_;
public:
template <typename T,
typename
= enable_if_t<std::is_base_of<TrigBase, T>::value
or std::is_base_of<HyperbolicBase, T>::value>>
void bvisit(const T &f)
{
auto arg = f.get_arg();
map_basic_basic subsx0{{x_, integer(0)}};
if (arg->subs(subsx0)->__neq__(*integer(0))) {
needs_ = true;
stop_ = true;
}
}
void bvisit(const Pow &pow)
{
auto base = pow.get_base();
auto exp = pow.get_exp();
map_basic_basic subsx0{{x_, integer(0)}};
// exp(const) or x^-1
if ((base->__eq__(*E) and exp->subs(subsx0)->__neq__(*integer(0)))
or (is_a_Number(*exp)
and down_cast<const Number &>(*exp).is_negative()
and base->subs(subsx0)->__eq__(*integer(0)))) {
needs_ = true;
stop_ = true;
}
}
void bvisit(const Log &f)
{
auto arg = f.get_arg();
map_basic_basic subsx0{{x_, integer(0)}};
if (arg->subs(subsx0)->__eq__(*integer(0))) {
needs_ = true;
stop_ = true;
}
}
void bvisit(const LambertW &x)
{
needs_ = true;
stop_ = true;
}
void bvisit(const Basic &x)
{
}
bool apply(const Basic &b, const RCP<const Symbol> &x)
{
x_ = x;
needs_ = false;
stop_ = false;
postorder_traversal_stop(b, *this);
return needs_;
}
};
} // SymEngine
#endif // SYMENGINE_SERIES_VISITOR_H