series_visitor.h
1 #ifndef SYMENGINE_SERIES_VISITOR_H
2 #define SYMENGINE_SERIES_VISITOR_H
3 
4 #include <symengine/visitor.h>
5 
6 namespace SymEngine
7 {
8 
9 template <typename Poly, typename Coeff, typename Series>
10 class SeriesVisitor : public BaseVisitor<SeriesVisitor<Poly, Coeff, Series>>
11 {
12 private:
13  Poly p;
14  const Poly var;
15  const std::string varname;
16  const unsigned prec;
17 
18 public:
19  inline SeriesVisitor(const Poly &var_, const std::string &varname_,
20  const unsigned prec_)
21  : var(var_), varname(varname_), prec(prec_)
22  {
23  }
24  RCP<const Series> series(const RCP<const Basic> &x)
25  {
26  return make_rcp<Series>(apply(x), varname, prec);
27  }
28 
29  Poly apply(const RCP<const Basic> &x)
30  {
31  x->accept(*this);
32  Poly temp(std::move(p));
33  return temp;
34  }
35 
36  void bvisit(const Add &x)
37  {
38  Poly temp(apply(x.get_coef()));
39  for (const auto &term : x.get_dict()) {
40  temp += apply(term.first) * apply(term.second);
41  }
42  p = temp;
43  }
44  void bvisit(const Mul &x)
45  {
46  Poly temp(apply(x.get_coef()));
47  for (const auto &term : x.get_dict()) {
48  temp = Series::mul(temp, apply(pow(term.first, term.second)), prec);
49  }
50  p = temp;
51  }
52  void bvisit(const Pow &x)
53  {
54  const RCP<const Basic> &base = x.get_base(), exp = x.get_exp();
55  if (is_a<Integer>(*exp)) {
56  const Integer &ii = (down_cast<const Integer &>(*exp));
57  if (not mp_fits_slong_p(ii.as_integer_class()))
58  throw SymEngineException("series power exponent size");
59  const int sh = numeric_cast<int>(mp_get_si(ii.as_integer_class()));
60  base->accept(*this);
61  if (sh == 1) {
62  return;
63  } else if (sh > 0) {
64  p = Series::pow(p, sh, prec);
65  } else if (sh == -1) {
66  p = Series::series_invert(p, var, prec);
67  } else {
68  // Invert and then exponentiate to give the correct behavior
69  // when expanding 1/x**(prec), which should return x**(-prec),
70  // not 0.
71  p = Series::pow(Series::series_invert(p, var, prec), -sh, prec);
72  }
73 
74  } else if (is_a<Rational>(*exp)) {
75  const Rational &rat = (down_cast<const Rational &>(*exp));
76  const integer_class &expnumz = get_num(rat.as_rational_class());
77  const integer_class &expdenz = get_den(rat.as_rational_class());
78  if (not mp_fits_slong_p(expnumz) or not mp_fits_slong_p(expdenz))
79  throw SymEngineException("series rational power exponent size");
80  const int num = numeric_cast<int>(mp_get_si(expnumz));
81  const int den = numeric_cast<int>(mp_get_si(expdenz));
82  base->accept(*this);
83  const Poly proot(
84  Series::series_nthroot(apply(base), den, var, prec));
85  if (num == 1) {
86  p = proot;
87  } else if (num > 0) {
88  p = Series::pow(proot, num, prec);
89  } else if (num == -1) {
90  p = Series::series_invert(proot, var, prec);
91  } else {
92  p = Series::series_invert(Series::pow(proot, -num, prec), var,
93  prec);
94  }
95  } else if (eq(*E, *base)) {
96  p = Series::series_exp(apply(exp), var, prec);
97  } else {
98  p = Series::series_exp(
99  Poly(apply(exp) * Series::series_log(apply(base), var, prec)),
100  var, prec);
101  }
102  }
103 
104  void bvisit(const Function &x)
105  {
106  RCP<const Basic> d = x.rcp_from_this();
107  RCP<const Symbol> s = symbol(varname);
108 
109  map_basic_basic m({{s, zero}});
110  RCP<const Basic> const_term = d->subs(m);
111  if (const_term == d) {
112  p = Series::convert(*d);
113  return;
114  }
115  Poly res_p(apply(expand(const_term)));
116  Coeff prod, t;
117  prod = 1;
118 
119  for (unsigned int i = 1; i < prec; i++) {
120  // Workaround for flint
121  t = i;
122  prod /= t;
123  d = d->diff(s);
124  res_p += Series::pow(var, i, prec)
125  * (prod * apply(expand(d->subs(m))));
126  }
127  p = res_p;
128  }
129 
130  void bvisit(const Gamma &x)
131  {
132  RCP<const Symbol> s = symbol(varname);
133  RCP<const Basic> arg = x.get_args()[0];
134  if (eq(*arg->subs({{s, zero}}), *zero)) {
135  RCP<const Basic> g = gamma(add(arg, one));
136  if (is_a<Gamma>(*g)) {
137  bvisit(down_cast<const Function &>(*g));
138  p *= Series::pow(var, -1, prec);
139  } else {
140  g->accept(*this);
141  }
142  } else {
143  bvisit(implicit_cast<const Function &>(x));
144  }
145  }
146 
147  void bvisit(const Series &x)
148  {
149  if (x.get_var() != varname) {
150  throw NotImplementedError("Multivariate Series not implemented");
151  }
152  if (x.get_degree() < prec) {
153  throw SymEngineException("Series with lesser prec found");
154  }
155  p = x.get_poly();
156  }
157  void bvisit(const Integer &x)
158  {
159  p = Series::convert(x);
160  }
161  void bvisit(const Rational &x)
162  {
163  p = Series::convert(x);
164  }
165  void bvisit(const Complex &x)
166  {
167  p = Series::convert(x);
168  }
169  void bvisit(const RealDouble &x)
170  {
171  p = Series::convert(x);
172  }
173  void bvisit(const ComplexDouble &x)
174  {
175  p = Series::convert(x);
176  }
177 #ifdef HAVE_SYMENGINE_MPFR
178  void bvisit(const RealMPFR &x)
179  {
180  p = Series::convert(x);
181  }
182 #endif
183 #ifdef HAVE_SYMENGINE_MPC
184  void bvisit(const ComplexMPC &x)
185  {
186  p = Series::convert(x);
187  }
188 #endif
189  void bvisit(const Sin &x)
190  {
191  x.get_arg()->accept(*this);
192  p = Series::series_sin(p, var, prec);
193  }
194  void bvisit(const Cos &x)
195  {
196  x.get_arg()->accept(*this);
197  p = Series::series_cos(p, var, prec);
198  }
199  void bvisit(const Tan &x)
200  {
201  x.get_arg()->accept(*this);
202  p = Series::series_tan(p, var, prec);
203  }
204  void bvisit(const Cot &x)
205  {
206  x.get_arg()->accept(*this);
207  p = Series::series_cot(p, var, prec);
208  }
209  void bvisit(const Csc &x)
210  {
211  x.get_arg()->accept(*this);
212  p = Series::series_csc(p, var, prec);
213  }
214  void bvisit(const Sec &x)
215  {
216  x.get_arg()->accept(*this);
217  p = Series::series_sec(p, var, prec);
218  }
219  void bvisit(const Log &x)
220  {
221  x.get_arg()->accept(*this);
222  p = Series::series_log(p, var, prec);
223  }
224  void bvisit(const ASin &x)
225  {
226  x.get_arg()->accept(*this);
227  p = Series::series_asin(p, var, prec);
228  }
229  void bvisit(const ACos &x)
230  {
231  x.get_arg()->accept(*this);
232  p = Series::series_acos(p, var, prec);
233  }
234  void bvisit(const ATan &x)
235  {
236  x.get_arg()->accept(*this);
237  p = Series::series_atan(p, var, prec);
238  }
239  void bvisit(const Sinh &x)
240  {
241  x.get_arg()->accept(*this);
242  p = Series::series_sinh(p, var, prec);
243  }
244  void bvisit(const Cosh &x)
245  {
246  x.get_arg()->accept(*this);
247  p = Series::series_cosh(p, var, prec);
248  }
249  void bvisit(const Tanh &x)
250  {
251  x.get_arg()->accept(*this);
252  p = Series::series_tanh(p, var, prec);
253  }
254  void bvisit(const ASinh &x)
255  {
256  x.get_arg()->accept(*this);
257  p = Series::series_asinh(p, var, prec);
258  }
259  void bvisit(const ATanh &x)
260  {
261  x.get_arg()->accept(*this);
262  p = Series::series_atanh(p, var, prec);
263  }
264  void bvisit(const LambertW &x)
265  {
266  x.get_arg()->accept(*this);
267  p = Series::series_lambertw(p, var, prec);
268  }
269  void bvisit(const Symbol &x)
270  {
271  if (x.get_name() == varname) {
272  p = Series::var(x.get_name());
273  } else {
274  p = Series::convert(x);
275  }
276  }
277  void bvisit(const Constant &x)
278  {
279  p = Series::convert(x);
280  }
281  void bvisit(const Basic &x)
282  {
283  if (!has_symbol(x, *symbol(varname))) {
284  p = Series::convert(x);
285  } else {
286  throw NotImplementedError("Not Implemented");
287  }
288  }
289 };
290 
292  : public BaseVisitor<NeedsSymbolicExpansionVisitor, StopVisitor>
293 {
294 protected:
295  RCP<const Symbol> x_;
296  bool needs_;
297 
298 public:
299  template <typename T,
300  typename
301  = enable_if_t<std::is_base_of<TrigBase, T>::value
303  void bvisit(const T &f)
304  {
305  auto arg = f.get_arg();
306  map_basic_basic subsx0{{x_, integer(0)}};
307  if (arg->subs(subsx0)->__neq__(*integer(0))) {
308  needs_ = true;
309  stop_ = true;
310  }
311  }
312 
313  void bvisit(const Pow &pow)
314  {
315  auto base = pow.get_base();
316  auto exp = pow.get_exp();
317  map_basic_basic subsx0{{x_, integer(0)}};
318  // exp(const) or x^-1
319  if ((base->__eq__(*E) and exp->subs(subsx0)->__neq__(*integer(0)))
320  or (is_a_Number(*exp)
321  and down_cast<const Number &>(*exp).is_negative()
322  and base->subs(subsx0)->__eq__(*integer(0)))) {
323  needs_ = true;
324  stop_ = true;
325  }
326  }
327 
328  void bvisit(const Log &f)
329  {
330  auto arg = f.get_arg();
331  map_basic_basic subsx0{{x_, integer(0)}};
332  if (arg->subs(subsx0)->__eq__(*integer(0))) {
333  needs_ = true;
334  stop_ = true;
335  }
336  }
337 
338  void bvisit(const LambertW &x)
339  {
340  needs_ = true;
341  stop_ = true;
342  }
343 
344  void bvisit(const Basic &x) {}
345 
346  bool apply(const Basic &b, const RCP<const Symbol> &x)
347  {
348  x_ = x;
349  needs_ = false;
350  stop_ = false;
351  postorder_traversal_stop(b, *this);
352  return needs_;
353  }
354 };
355 
356 } // namespace SymEngine
357 #endif // SYMENGINE_SERIES_VISITOR_H
The base class for representing addition in symbolic expressions.
Definition: add.h:27
const RCP< const Number > & get_coef() const
Definition: add.h:142
The lowest unit of symbolic representation.
Definition: basic.h:97
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
Integer Class.
Definition: integer.h:19
const integer_class & as_integer_class() const
Convert to integer_class.
Definition: integer.h:45
vec_basic get_args() const override
Returns the list of arguments.
Definition: functions.h:40
RCP< const Basic > get_arg() const
Definition: functions.h:36
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
Rational Class.
Definition: rational.h:16
const rational_class & as_rational_class() const
Convert to rational_class.
Definition: rational.h:50
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
RCP< const Symbol > symbol(const std::string &name)
Inline helper that returns a Symbol with the given name.
Definition: symbol.h:82
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > gamma(const RCP< const Basic > &arg)
Canonicalize Gamma:
Definition: functions.cpp:3014
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Basic > expand(const RCP< const Basic > &self, bool deep=true)
Expands self
Definition: expand.cpp:369