3 #include <symengine/symengine_exception.h>
5 #ifdef HAVE_SYMENGINE_MPFR
10 class EvalMPFRVisitor :
public BaseVisitor<EvalMPFRVisitor>
17 EvalMPFRVisitor(mpfr_rnd_t rnd) : rnd_{rnd} {}
19 void apply(mpfr_ptr result,
const Basic &b)
21 mpfr_ptr tmp = result_;
27 void bvisit(
const Integer &x)
29 mpfr_set_z(result_, get_mpz_t(x.as_integer_class()), rnd_);
32 void bvisit(
const Rational &x)
34 mpfr_set_q(result_, get_mpq_t(x.as_rational_class()), rnd_);
37 void bvisit(
const RealDouble &x)
39 mpfr_set_d(result_, x.i, rnd_);
42 void bvisit(
const RealMPFR &x)
44 mpfr_set(result_, x.i.get_mpfr_t(), rnd_);
47 void bvisit(
const Add &x)
49 mpfr_class t(mpfr_get_prec(result_));
50 auto d = x.get_args();
52 apply(result_, *(*p));
54 for (; p != d.end(); p++) {
55 apply(t.get_mpfr_t(), *(*p));
56 mpfr_add(result_, result_, t.get_mpfr_t(), rnd_);
60 void bvisit(
const Mul &x)
62 mpfr_class t(mpfr_get_prec(result_));
63 auto d = x.get_args();
65 apply(result_, *(*p));
67 for (; p != d.end(); p++) {
68 apply(t.get_mpfr_t(), *(*p));
69 mpfr_mul(result_, result_, t.get_mpfr_t(), rnd_);
73 void bvisit(
const Pow &x)
75 if (
eq(*x.get_base(), *E)) {
76 apply(result_, *(x.get_exp()));
77 mpfr_exp(result_, result_, rnd_);
79 mpfr_class b(mpfr_get_prec(result_));
80 apply(b.get_mpfr_t(), *(x.get_base()));
81 apply(result_, *(x.get_exp()));
82 mpfr_pow(result_, b.get_mpfr_t(), result_, rnd_);
86 void bvisit(
const Equality &x)
88 mpfr_class t(mpfr_get_prec(result_));
89 apply(t.get_mpfr_t(), *(x.get_arg1()));
90 apply(result_, *(x.get_arg2()));
91 if (mpfr_equal_p(t.get_mpfr_t(), result_)) {
92 mpfr_set_ui(result_, 1, rnd_);
94 mpfr_set_ui(result_, 0, rnd_);
98 void bvisit(
const Unequality &x)
100 mpfr_class t(mpfr_get_prec(result_));
101 apply(t.get_mpfr_t(), *(x.get_arg1()));
102 apply(result_, *(x.get_arg2()));
103 if (mpfr_lessgreater_p(t.get_mpfr_t(), result_)) {
104 mpfr_set_ui(result_, 1, rnd_);
106 mpfr_set_ui(result_, 0, rnd_);
110 void bvisit(
const LessThan &x)
112 mpfr_class t(mpfr_get_prec(result_));
113 apply(t.get_mpfr_t(), *(x.get_arg1()));
114 apply(result_, *(x.get_arg2()));
115 if (mpfr_lessequal_p(t.get_mpfr_t(), result_)) {
116 mpfr_set_ui(result_, 1, rnd_);
118 mpfr_set_ui(result_, 0, rnd_);
122 void bvisit(
const StrictLessThan &x)
124 mpfr_class t(mpfr_get_prec(result_));
125 apply(t.get_mpfr_t(), *(x.get_arg1()));
126 apply(result_, *(x.get_arg2()));
127 if (mpfr_less_p(t.get_mpfr_t(), result_)) {
128 mpfr_set_ui(result_, 1, rnd_);
130 mpfr_set_ui(result_, 0, rnd_);
134 void bvisit(
const Sin &x)
136 apply(result_, *(x.get_arg()));
137 mpfr_sin(result_, result_, rnd_);
140 void bvisit(
const Cos &x)
142 apply(result_, *(x.get_arg()));
143 mpfr_cos(result_, result_, rnd_);
146 void bvisit(
const Tan &x)
148 apply(result_, *(x.get_arg()));
149 mpfr_tan(result_, result_, rnd_);
152 void bvisit(
const Log &x)
154 apply(result_, *(x.get_arg()));
155 mpfr_log(result_, result_, rnd_);
158 void bvisit(
const Cot &x)
160 apply(result_, *(x.get_arg()));
161 mpfr_cot(result_, result_, rnd_);
164 void bvisit(
const Csc &x)
166 apply(result_, *(x.get_arg()));
167 mpfr_csc(result_, result_, rnd_);
170 void bvisit(
const Sec &x)
172 apply(result_, *(x.get_arg()));
173 mpfr_sec(result_, result_, rnd_);
176 void bvisit(
const ASin &x)
178 apply(result_, *(x.get_arg()));
179 mpfr_asin(result_, result_, rnd_);
182 void bvisit(
const ACos &x)
184 apply(result_, *(x.get_arg()));
185 mpfr_acos(result_, result_, rnd_);
188 void bvisit(
const ASec &x)
190 apply(result_, *(x.get_arg()));
191 mpfr_ui_div(result_, 1, result_, rnd_);
192 mpfr_asin(result_, result_, rnd_);
195 void bvisit(
const ACsc &x)
197 apply(result_, *(x.get_arg()));
198 mpfr_ui_div(result_, 1, result_, rnd_);
199 mpfr_acos(result_, result_, rnd_);
202 void bvisit(
const ATan &x)
204 apply(result_, *(x.get_arg()));
205 mpfr_atan(result_, result_, rnd_);
208 void bvisit(
const ACot &x)
210 apply(result_, *(x.get_arg()));
211 mpfr_ui_div(result_, 1, result_, rnd_);
212 mpfr_atan(result_, result_, rnd_);
215 void bvisit(
const ATan2 &x)
217 mpfr_class t(mpfr_get_prec(result_));
218 apply(t.get_mpfr_t(), *(x.get_num()));
219 apply(result_, *(x.get_den()));
220 mpfr_atan2(result_, t.get_mpfr_t(), result_, rnd_);
223 void bvisit(
const Sinh &x)
225 apply(result_, *(x.get_arg()));
226 mpfr_sinh(result_, result_, rnd_);
229 void bvisit(
const Csch &x)
231 apply(result_, *(x.get_arg()));
232 mpfr_csch(result_, result_, rnd_);
235 void bvisit(
const Cosh &x)
237 apply(result_, *(x.get_arg()));
238 mpfr_cosh(result_, result_, rnd_);
241 void bvisit(
const Sech &x)
243 apply(result_, *(x.get_arg()));
244 mpfr_sech(result_, result_, rnd_);
247 void bvisit(
const Tanh &x)
249 apply(result_, *(x.get_arg()));
250 mpfr_tanh(result_, result_, rnd_);
253 void bvisit(
const Coth &x)
255 apply(result_, *(x.get_arg()));
256 mpfr_coth(result_, result_, rnd_);
259 void bvisit(
const ASinh &x)
261 apply(result_, *(x.get_arg()));
262 mpfr_asinh(result_, result_, rnd_);
265 void bvisit(
const ACsch &x)
267 apply(result_, *(x.get_arg()));
268 mpfr_ui_div(result_, 1, result_, rnd_);
269 mpfr_asinh(result_, result_, rnd_);
272 void bvisit(
const ACosh &x)
274 apply(result_, *(x.get_arg()));
275 mpfr_acosh(result_, result_, rnd_);
278 void bvisit(
const ATanh &x)
280 apply(result_, *(x.get_arg()));
281 mpfr_atanh(result_, result_, rnd_);
284 void bvisit(
const ACoth &x)
286 apply(result_, *(x.get_arg()));
287 mpfr_ui_div(result_, 1, result_, rnd_);
288 mpfr_atanh(result_, result_, rnd_);
291 void bvisit(
const ASech &x)
293 apply(result_, *(x.get_arg()));
294 mpfr_ui_div(result_, 1, result_, rnd_);
295 mpfr_acosh(result_, result_, rnd_);
298 void bvisit(
const Gamma &x)
300 apply(result_, *(x.get_args()[0]));
301 mpfr_gamma(result_, result_, rnd_);
303 #if MPFR_VERSION_MAJOR > 3
304 void bvisit(
const UpperGamma &x)
306 mpfr_class t(mpfr_get_prec(result_));
307 apply(result_, *(x.get_args()[1]));
308 apply(t.get_mpfr_t(), *(x.get_args()[0]));
309 mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
312 void bvisit(
const LowerGamma &x)
314 mpfr_class t(mpfr_get_prec(result_));
315 apply(result_, *(x.get_args()[1]));
316 apply(t.get_mpfr_t(), *(x.get_args()[0]));
317 mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
318 mpfr_gamma(t.get_mpfr_t(), t.get_mpfr_t(), rnd_);
319 mpfr_sub(result_, t.get_mpfr_t(), result_, rnd_);
322 void bvisit(
const LogGamma &x)
324 apply(result_, *(x.get_args()[0]));
325 mpfr_lngamma(result_, result_, rnd_);
328 void bvisit(
const Beta &x)
330 apply(result_, *(x.rewrite_as_gamma()));
333 void bvisit(
const Constant &x)
336 mpfr_const_pi(result_, rnd_);
337 }
else if (x.__eq__(*E)) {
339 mpfr_init2(one_, mpfr_get_prec(result_));
340 mpfr_set_ui(one_, 1, rnd_);
341 mpfr_exp(result_, one_, rnd_);
343 }
else if (x.__eq__(*EulerGamma)) {
344 mpfr_const_euler(result_, rnd_);
345 }
else if (x.__eq__(*Catalan)) {
346 mpfr_const_catalan(result_, rnd_);
347 }
else if (x.__eq__(*GoldenRatio)) {
348 mpfr_sqrt_ui(result_, 5, rnd_);
349 mpfr_add_ui(result_, result_, 1, rnd_);
350 mpfr_div_ui(result_, result_, 2, rnd_);
352 throw NotImplementedError(
"Constant " + x.get_name()
353 +
" is not implemented.");
357 void bvisit(
const Abs &x)
359 apply(result_, *(x.get_arg()));
360 mpfr_abs(result_, result_, rnd_);
363 void bvisit(
const NumberWrapper &x)
365 x.eval(mpfr_get_prec(result_))->accept(*
this);
368 void bvisit(
const FunctionWrapper &x)
370 x.eval(mpfr_get_prec(result_))->accept(*
this);
372 void bvisit(
const Erf &x)
374 apply(result_, *(x.get_args()[0]));
375 mpfr_erf(result_, result_, rnd_);
378 void bvisit(
const Erfc &x)
380 apply(result_, *(x.get_args()[0]));
381 mpfr_erfc(result_, result_, rnd_);
384 void bvisit(
const Max &x)
386 mpfr_class t(mpfr_get_prec(result_));
387 auto d = x.get_args();
389 apply(result_, *(*p));
391 for (; p != d.end(); p++) {
392 apply(t.get_mpfr_t(), *(*p));
393 mpfr_max(result_, result_, t.get_mpfr_t(), rnd_);
397 void bvisit(
const Min &x)
399 mpfr_class t(mpfr_get_prec(result_));
400 auto d = x.get_args();
402 apply(result_, *(*p));
404 for (; p != d.end(); p++) {
405 apply(t.get_mpfr_t(), *(*p));
406 mpfr_min(result_, result_, t.get_mpfr_t(), rnd_);
410 void bvisit(
const UnevaluatedExpr &x)
412 apply(result_, *x.get_arg());
419 void bvisit(
const Basic &)
421 throw NotImplementedError(
"Not Implemented");
425 void eval_mpfr(mpfr_ptr result,
const Basic &b, mpfr_rnd_t rnd)
427 EvalMPFRVisitor v(rnd);
// Main namespace for the SymEngine package.
// bool eq(const Basic &a, const Basic &b)
//     Checks equality for a and b.