eval_mpfr.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/eval_mpfr.h>
3 #include <symengine/symengine_exception.h>
4 
5 #ifdef HAVE_SYMENGINE_MPFR
6 
7 namespace SymEngine
8 {
9 
10 class EvalMPFRVisitor : public BaseVisitor<EvalMPFRVisitor>
11 {
12 protected:
13  mpfr_rnd_t rnd_;
14  mpfr_ptr result_;
15 
16 public:
17  EvalMPFRVisitor(mpfr_rnd_t rnd) : rnd_{rnd} {}
18 
19  void apply(mpfr_ptr result, const Basic &b)
20  {
21  mpfr_ptr tmp = result_;
22  result_ = result;
23  b.accept(*this);
24  result_ = tmp;
25  }
26 
27  void bvisit(const Integer &x)
28  {
29  mpfr_set_z(result_, get_mpz_t(x.as_integer_class()), rnd_);
30  }
31 
32  void bvisit(const Rational &x)
33  {
34  mpfr_set_q(result_, get_mpq_t(x.as_rational_class()), rnd_);
35  }
36 
37  void bvisit(const RealDouble &x)
38  {
39  mpfr_set_d(result_, x.i, rnd_);
40  }
41 
42  void bvisit(const RealMPFR &x)
43  {
44  mpfr_set(result_, x.i.get_mpfr_t(), rnd_);
45  }
46 
47  void bvisit(const Add &x)
48  {
49  mpfr_class t(mpfr_get_prec(result_));
50  auto d = x.get_args();
51  auto p = d.begin();
52  apply(result_, *(*p));
53  p++;
54  for (; p != d.end(); p++) {
55  apply(t.get_mpfr_t(), *(*p));
56  mpfr_add(result_, result_, t.get_mpfr_t(), rnd_);
57  }
58  }
59 
60  void bvisit(const Mul &x)
61  {
62  mpfr_class t(mpfr_get_prec(result_));
63  auto d = x.get_args();
64  auto p = d.begin();
65  apply(result_, *(*p));
66  p++;
67  for (; p != d.end(); p++) {
68  apply(t.get_mpfr_t(), *(*p));
69  mpfr_mul(result_, result_, t.get_mpfr_t(), rnd_);
70  }
71  }
72 
73  void bvisit(const Pow &x)
74  {
75  if (eq(*x.get_base(), *E)) {
76  apply(result_, *(x.get_exp()));
77  mpfr_exp(result_, result_, rnd_);
78  } else {
79  mpfr_class b(mpfr_get_prec(result_));
80  apply(b.get_mpfr_t(), *(x.get_base()));
81  apply(result_, *(x.get_exp()));
82  mpfr_pow(result_, b.get_mpfr_t(), result_, rnd_);
83  }
84  }
85 
86  void bvisit(const Equality &x)
87  {
88  mpfr_class t(mpfr_get_prec(result_));
89  apply(t.get_mpfr_t(), *(x.get_arg1()));
90  apply(result_, *(x.get_arg2()));
91  if (mpfr_equal_p(t.get_mpfr_t(), result_)) {
92  mpfr_set_ui(result_, 1, rnd_);
93  } else {
94  mpfr_set_ui(result_, 0, rnd_);
95  }
96  }
97 
98  void bvisit(const Unequality &x)
99  {
100  mpfr_class t(mpfr_get_prec(result_));
101  apply(t.get_mpfr_t(), *(x.get_arg1()));
102  apply(result_, *(x.get_arg2()));
103  if (mpfr_lessgreater_p(t.get_mpfr_t(), result_)) {
104  mpfr_set_ui(result_, 1, rnd_);
105  } else {
106  mpfr_set_ui(result_, 0, rnd_);
107  }
108  }
109 
110  void bvisit(const LessThan &x)
111  {
112  mpfr_class t(mpfr_get_prec(result_));
113  apply(t.get_mpfr_t(), *(x.get_arg1()));
114  apply(result_, *(x.get_arg2()));
115  if (mpfr_lessequal_p(t.get_mpfr_t(), result_)) {
116  mpfr_set_ui(result_, 1, rnd_);
117  } else {
118  mpfr_set_ui(result_, 0, rnd_);
119  }
120  }
121 
122  void bvisit(const StrictLessThan &x)
123  {
124  mpfr_class t(mpfr_get_prec(result_));
125  apply(t.get_mpfr_t(), *(x.get_arg1()));
126  apply(result_, *(x.get_arg2()));
127  if (mpfr_less_p(t.get_mpfr_t(), result_)) {
128  mpfr_set_ui(result_, 1, rnd_);
129  } else {
130  mpfr_set_ui(result_, 0, rnd_);
131  }
132  }
133 
134  void bvisit(const Sin &x)
135  {
136  apply(result_, *(x.get_arg()));
137  mpfr_sin(result_, result_, rnd_);
138  }
139 
140  void bvisit(const Cos &x)
141  {
142  apply(result_, *(x.get_arg()));
143  mpfr_cos(result_, result_, rnd_);
144  }
145 
146  void bvisit(const Tan &x)
147  {
148  apply(result_, *(x.get_arg()));
149  mpfr_tan(result_, result_, rnd_);
150  }
151 
152  void bvisit(const Log &x)
153  {
154  apply(result_, *(x.get_arg()));
155  mpfr_log(result_, result_, rnd_);
156  }
157 
158  void bvisit(const Cot &x)
159  {
160  apply(result_, *(x.get_arg()));
161  mpfr_cot(result_, result_, rnd_);
162  }
163 
164  void bvisit(const Csc &x)
165  {
166  apply(result_, *(x.get_arg()));
167  mpfr_csc(result_, result_, rnd_);
168  }
169 
170  void bvisit(const Sec &x)
171  {
172  apply(result_, *(x.get_arg()));
173  mpfr_sec(result_, result_, rnd_);
174  }
175 
176  void bvisit(const ASin &x)
177  {
178  apply(result_, *(x.get_arg()));
179  mpfr_asin(result_, result_, rnd_);
180  }
181 
182  void bvisit(const ACos &x)
183  {
184  apply(result_, *(x.get_arg()));
185  mpfr_acos(result_, result_, rnd_);
186  }
187 
188  void bvisit(const ASec &x)
189  {
190  apply(result_, *(x.get_arg()));
191  mpfr_ui_div(result_, 1, result_, rnd_);
192  mpfr_asin(result_, result_, rnd_);
193  }
194 
195  void bvisit(const ACsc &x)
196  {
197  apply(result_, *(x.get_arg()));
198  mpfr_ui_div(result_, 1, result_, rnd_);
199  mpfr_acos(result_, result_, rnd_);
200  }
201 
202  void bvisit(const ATan &x)
203  {
204  apply(result_, *(x.get_arg()));
205  mpfr_atan(result_, result_, rnd_);
206  }
207 
208  void bvisit(const ACot &x)
209  {
210  apply(result_, *(x.get_arg()));
211  mpfr_ui_div(result_, 1, result_, rnd_);
212  mpfr_atan(result_, result_, rnd_);
213  }
214 
215  void bvisit(const ATan2 &x)
216  {
217  mpfr_class t(mpfr_get_prec(result_));
218  apply(t.get_mpfr_t(), *(x.get_num()));
219  apply(result_, *(x.get_den()));
220  mpfr_atan2(result_, t.get_mpfr_t(), result_, rnd_);
221  }
222 
223  void bvisit(const Sinh &x)
224  {
225  apply(result_, *(x.get_arg()));
226  mpfr_sinh(result_, result_, rnd_);
227  }
228 
229  void bvisit(const Csch &x)
230  {
231  apply(result_, *(x.get_arg()));
232  mpfr_csch(result_, result_, rnd_);
233  }
234 
235  void bvisit(const Cosh &x)
236  {
237  apply(result_, *(x.get_arg()));
238  mpfr_cosh(result_, result_, rnd_);
239  }
240 
241  void bvisit(const Sech &x)
242  {
243  apply(result_, *(x.get_arg()));
244  mpfr_sech(result_, result_, rnd_);
245  }
246 
247  void bvisit(const Tanh &x)
248  {
249  apply(result_, *(x.get_arg()));
250  mpfr_tanh(result_, result_, rnd_);
251  }
252 
253  void bvisit(const Coth &x)
254  {
255  apply(result_, *(x.get_arg()));
256  mpfr_coth(result_, result_, rnd_);
257  }
258 
259  void bvisit(const ASinh &x)
260  {
261  apply(result_, *(x.get_arg()));
262  mpfr_asinh(result_, result_, rnd_);
263  }
264 
265  void bvisit(const ACsch &x)
266  {
267  apply(result_, *(x.get_arg()));
268  mpfr_ui_div(result_, 1, result_, rnd_);
269  mpfr_asinh(result_, result_, rnd_);
270  };
271 
272  void bvisit(const ACosh &x)
273  {
274  apply(result_, *(x.get_arg()));
275  mpfr_acosh(result_, result_, rnd_);
276  }
277 
278  void bvisit(const ATanh &x)
279  {
280  apply(result_, *(x.get_arg()));
281  mpfr_atanh(result_, result_, rnd_);
282  }
283 
284  void bvisit(const ACoth &x)
285  {
286  apply(result_, *(x.get_arg()));
287  mpfr_ui_div(result_, 1, result_, rnd_);
288  mpfr_atanh(result_, result_, rnd_);
289  }
290 
291  void bvisit(const ASech &x)
292  {
293  apply(result_, *(x.get_arg()));
294  mpfr_ui_div(result_, 1, result_, rnd_);
295  mpfr_acosh(result_, result_, rnd_);
296  };
297 
298  void bvisit(const Gamma &x)
299  {
300  apply(result_, *(x.get_args()[0]));
301  mpfr_gamma(result_, result_, rnd_);
302  };
303 #if MPFR_VERSION_MAJOR > 3
304  void bvisit(const UpperGamma &x)
305  {
306  mpfr_class t(mpfr_get_prec(result_));
307  apply(result_, *(x.get_args()[1]));
308  apply(t.get_mpfr_t(), *(x.get_args()[0]));
309  mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
310  };
311 
312  void bvisit(const LowerGamma &x)
313  {
314  mpfr_class t(mpfr_get_prec(result_));
315  apply(result_, *(x.get_args()[1]));
316  apply(t.get_mpfr_t(), *(x.get_args()[0]));
317  mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
318  mpfr_gamma(t.get_mpfr_t(), t.get_mpfr_t(), rnd_);
319  mpfr_sub(result_, t.get_mpfr_t(), result_, rnd_);
320  };
321 #endif
322  void bvisit(const LogGamma &x)
323  {
324  apply(result_, *(x.get_args()[0]));
325  mpfr_lngamma(result_, result_, rnd_);
326  }
327 
328  void bvisit(const Beta &x)
329  {
330  apply(result_, *(x.rewrite_as_gamma()));
331  };
332 
333  void bvisit(const Constant &x)
334  {
335  if (x.__eq__(*pi)) {
336  mpfr_const_pi(result_, rnd_);
337  } else if (x.__eq__(*E)) {
338  mpfr_t one_;
339  mpfr_init2(one_, mpfr_get_prec(result_));
340  mpfr_set_ui(one_, 1, rnd_);
341  mpfr_exp(result_, one_, rnd_);
342  mpfr_clear(one_);
343  } else if (x.__eq__(*EulerGamma)) {
344  mpfr_const_euler(result_, rnd_);
345  } else if (x.__eq__(*Catalan)) {
346  mpfr_const_catalan(result_, rnd_);
347  } else if (x.__eq__(*GoldenRatio)) {
348  mpfr_sqrt_ui(result_, 5, rnd_);
349  mpfr_add_ui(result_, result_, 1, rnd_);
350  mpfr_div_ui(result_, result_, 2, rnd_);
351  } else {
352  throw NotImplementedError("Constant " + x.get_name()
353  + " is not implemented.");
354  }
355  }
356 
357  void bvisit(const Abs &x)
358  {
359  apply(result_, *(x.get_arg()));
360  mpfr_abs(result_, result_, rnd_);
361  };
362 
363  void bvisit(const NumberWrapper &x)
364  {
365  x.eval(mpfr_get_prec(result_))->accept(*this);
366  }
367 
368  void bvisit(const FunctionWrapper &x)
369  {
370  x.eval(mpfr_get_prec(result_))->accept(*this);
371  }
372  void bvisit(const Erf &x)
373  {
374  apply(result_, *(x.get_args()[0]));
375  mpfr_erf(result_, result_, rnd_);
376  }
377 
378  void bvisit(const Erfc &x)
379  {
380  apply(result_, *(x.get_args()[0]));
381  mpfr_erfc(result_, result_, rnd_);
382  }
383 
384  void bvisit(const Max &x)
385  {
386  mpfr_class t(mpfr_get_prec(result_));
387  auto d = x.get_args();
388  auto p = d.begin();
389  apply(result_, *(*p));
390  p++;
391  for (; p != d.end(); p++) {
392  apply(t.get_mpfr_t(), *(*p));
393  mpfr_max(result_, result_, t.get_mpfr_t(), rnd_);
394  }
395  }
396 
397  void bvisit(const Min &x)
398  {
399  mpfr_class t(mpfr_get_prec(result_));
400  auto d = x.get_args();
401  auto p = d.begin();
402  apply(result_, *(*p));
403  p++;
404  for (; p != d.end(); p++) {
405  apply(t.get_mpfr_t(), *(*p));
406  mpfr_min(result_, result_, t.get_mpfr_t(), rnd_);
407  }
408  }
409 
410  void bvisit(const UnevaluatedExpr &x)
411  {
412  apply(result_, *x.get_arg());
413  }
414 
415  // Classes not implemented are
416  // Subs, Dirichlet_eta, Zeta
417  // LeviCivita, KroneckerDelta, LambertW
418  // Derivative, Complex, ComplexDouble, ComplexMPC
419  void bvisit(const Basic &)
420  {
421  throw NotImplementedError("Not Implemented");
422  };
423 };
424 
425 void eval_mpfr(mpfr_ptr result, const Basic &b, mpfr_rnd_t rnd)
426 {
427  EvalMPFRVisitor v(rnd);
428  v.apply(result, b);
429 }
430 
431 } // namespace SymEngine
432 
433 #endif // HAVE_SYMENGINE_MPFR
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21