Loading...
Searching...
No Matches
eval_mpfr.cpp
1#include <symengine/visitor.h>
3#include <symengine/symengine_exception.h>
4
5#ifdef HAVE_SYMENGINE_MPFR
6
7namespace SymEngine
8{
9
10class EvalMPFRVisitor : public BaseVisitor<EvalMPFRVisitor>
11{
12protected:
13 mpfr_rnd_t rnd_;
14 mpfr_ptr result_;
15
16public:
17 EvalMPFRVisitor(mpfr_rnd_t rnd) : rnd_{rnd} {}
18
19 void apply(mpfr_ptr result, const Basic &b)
20 {
21 mpfr_ptr tmp = result_;
22 result_ = result;
23 b.accept(*this);
24 result_ = tmp;
25 }
26
27 void bvisit(const Integer &x)
28 {
29 mpfr_set_z(result_, get_mpz_t(x.as_integer_class()), rnd_);
30 }
31
32 void bvisit(const Rational &x)
33 {
34 mpfr_set_q(result_, get_mpq_t(x.as_rational_class()), rnd_);
35 }
36
37 void bvisit(const RealDouble &x)
38 {
39 mpfr_set_d(result_, x.i, rnd_);
40 }
41
42 void bvisit(const RealMPFR &x)
43 {
44 mpfr_set(result_, x.i.get_mpfr_t(), rnd_);
45 }
46
47 void bvisit(const Add &x)
48 {
49 mpfr_class t(mpfr_get_prec(result_));
50 auto d = x.get_args();
51 auto p = d.begin();
52 apply(result_, *(*p));
53 p++;
54 for (; p != d.end(); p++) {
55 apply(t.get_mpfr_t(), *(*p));
56 mpfr_add(result_, result_, t.get_mpfr_t(), rnd_);
57 }
58 }
59
60 void bvisit(const Mul &x)
61 {
62 mpfr_class t(mpfr_get_prec(result_));
63 auto d = x.get_args();
64 auto p = d.begin();
65 apply(result_, *(*p));
66 p++;
67 for (; p != d.end(); p++) {
68 apply(t.get_mpfr_t(), *(*p));
69 mpfr_mul(result_, result_, t.get_mpfr_t(), rnd_);
70 }
71 }
72
73 void bvisit(const Pow &x)
74 {
75 if (eq(*x.get_base(), *E)) {
76 apply(result_, *(x.get_exp()));
77 mpfr_exp(result_, result_, rnd_);
78 } else {
79 mpfr_class b(mpfr_get_prec(result_));
80 apply(b.get_mpfr_t(), *(x.get_base()));
81 apply(result_, *(x.get_exp()));
82 mpfr_pow(result_, b.get_mpfr_t(), result_, rnd_);
83 }
84 }
85
86 void bvisit(const Equality &x)
87 {
88 mpfr_class t(mpfr_get_prec(result_));
89 apply(t.get_mpfr_t(), *(x.get_arg1()));
90 apply(result_, *(x.get_arg2()));
91 if (mpfr_equal_p(t.get_mpfr_t(), result_)) {
92 mpfr_set_ui(result_, 1, rnd_);
93 } else {
94 mpfr_set_ui(result_, 0, rnd_);
95 }
96 }
97
98 void bvisit(const Unequality &x)
99 {
100 mpfr_class t(mpfr_get_prec(result_));
101 apply(t.get_mpfr_t(), *(x.get_arg1()));
102 apply(result_, *(x.get_arg2()));
103 if (mpfr_lessgreater_p(t.get_mpfr_t(), result_)) {
104 mpfr_set_ui(result_, 1, rnd_);
105 } else {
106 mpfr_set_ui(result_, 0, rnd_);
107 }
108 }
109
110 void bvisit(const LessThan &x)
111 {
112 mpfr_class t(mpfr_get_prec(result_));
113 apply(t.get_mpfr_t(), *(x.get_arg1()));
114 apply(result_, *(x.get_arg2()));
115 if (mpfr_lessequal_p(t.get_mpfr_t(), result_)) {
116 mpfr_set_ui(result_, 1, rnd_);
117 } else {
118 mpfr_set_ui(result_, 0, rnd_);
119 }
120 }
121
122 void bvisit(const StrictLessThan &x)
123 {
124 mpfr_class t(mpfr_get_prec(result_));
125 apply(t.get_mpfr_t(), *(x.get_arg1()));
126 apply(result_, *(x.get_arg2()));
127 if (mpfr_less_p(t.get_mpfr_t(), result_)) {
128 mpfr_set_ui(result_, 1, rnd_);
129 } else {
130 mpfr_set_ui(result_, 0, rnd_);
131 }
132 }
133
134 void bvisit(const Sin &x)
135 {
136 apply(result_, *(x.get_arg()));
137 mpfr_sin(result_, result_, rnd_);
138 }
139
140 void bvisit(const Cos &x)
141 {
142 apply(result_, *(x.get_arg()));
143 mpfr_cos(result_, result_, rnd_);
144 }
145
146 void bvisit(const Tan &x)
147 {
148 apply(result_, *(x.get_arg()));
149 mpfr_tan(result_, result_, rnd_);
150 }
151
152 void bvisit(const Log &x)
153 {
154 apply(result_, *(x.get_arg()));
155 mpfr_log(result_, result_, rnd_);
156 }
157
158 void bvisit(const Cot &x)
159 {
160 apply(result_, *(x.get_arg()));
161 mpfr_cot(result_, result_, rnd_);
162 }
163
164 void bvisit(const Csc &x)
165 {
166 apply(result_, *(x.get_arg()));
167 mpfr_csc(result_, result_, rnd_);
168 }
169
170 void bvisit(const Sec &x)
171 {
172 apply(result_, *(x.get_arg()));
173 mpfr_sec(result_, result_, rnd_);
174 }
175
176 void bvisit(const ASin &x)
177 {
178 apply(result_, *(x.get_arg()));
179 mpfr_asin(result_, result_, rnd_);
180 }
181
182 void bvisit(const ACos &x)
183 {
184 apply(result_, *(x.get_arg()));
185 mpfr_acos(result_, result_, rnd_);
186 }
187
188 void bvisit(const ASec &x)
189 {
190 apply(result_, *(x.get_arg()));
191 mpfr_ui_div(result_, 1, result_, rnd_);
192 mpfr_asin(result_, result_, rnd_);
193 }
194
195 void bvisit(const ACsc &x)
196 {
197 apply(result_, *(x.get_arg()));
198 mpfr_ui_div(result_, 1, result_, rnd_);
199 mpfr_acos(result_, result_, rnd_);
200 }
201
202 void bvisit(const ATan &x)
203 {
204 apply(result_, *(x.get_arg()));
205 mpfr_atan(result_, result_, rnd_);
206 }
207
208 void bvisit(const ACot &x)
209 {
210 apply(result_, *(x.get_arg()));
211 mpfr_ui_div(result_, 1, result_, rnd_);
212 mpfr_atan(result_, result_, rnd_);
213 }
214
215 void bvisit(const ATan2 &x)
216 {
217 mpfr_class t(mpfr_get_prec(result_));
218 apply(t.get_mpfr_t(), *(x.get_num()));
219 apply(result_, *(x.get_den()));
220 mpfr_atan2(result_, t.get_mpfr_t(), result_, rnd_);
221 }
222
223 void bvisit(const Sinh &x)
224 {
225 apply(result_, *(x.get_arg()));
226 mpfr_sinh(result_, result_, rnd_);
227 }
228
229 void bvisit(const Csch &x)
230 {
231 apply(result_, *(x.get_arg()));
232 mpfr_csch(result_, result_, rnd_);
233 }
234
235 void bvisit(const Cosh &x)
236 {
237 apply(result_, *(x.get_arg()));
238 mpfr_cosh(result_, result_, rnd_);
239 }
240
241 void bvisit(const Sech &x)
242 {
243 apply(result_, *(x.get_arg()));
244 mpfr_sech(result_, result_, rnd_);
245 }
246
247 void bvisit(const Tanh &x)
248 {
249 apply(result_, *(x.get_arg()));
250 mpfr_tanh(result_, result_, rnd_);
251 }
252
253 void bvisit(const Coth &x)
254 {
255 apply(result_, *(x.get_arg()));
256 mpfr_coth(result_, result_, rnd_);
257 }
258
259 void bvisit(const ASinh &x)
260 {
261 apply(result_, *(x.get_arg()));
262 mpfr_asinh(result_, result_, rnd_);
263 }
264
265 void bvisit(const ACsch &x)
266 {
267 apply(result_, *(x.get_arg()));
268 mpfr_ui_div(result_, 1, result_, rnd_);
269 mpfr_asinh(result_, result_, rnd_);
270 };
271
272 void bvisit(const ACosh &x)
273 {
274 apply(result_, *(x.get_arg()));
275 mpfr_acosh(result_, result_, rnd_);
276 }
277
278 void bvisit(const ATanh &x)
279 {
280 apply(result_, *(x.get_arg()));
281 mpfr_atanh(result_, result_, rnd_);
282 }
283
284 void bvisit(const ACoth &x)
285 {
286 apply(result_, *(x.get_arg()));
287 mpfr_ui_div(result_, 1, result_, rnd_);
288 mpfr_atanh(result_, result_, rnd_);
289 }
290
291 void bvisit(const ASech &x)
292 {
293 apply(result_, *(x.get_arg()));
294 mpfr_ui_div(result_, 1, result_, rnd_);
295 mpfr_acosh(result_, result_, rnd_);
296 };
297
298 void bvisit(const Gamma &x)
299 {
300 apply(result_, *(x.get_args()[0]));
301 mpfr_gamma(result_, result_, rnd_);
302 };
303#if MPFR_VERSION_MAJOR > 3
304 void bvisit(const UpperGamma &x)
305 {
306 mpfr_class t(mpfr_get_prec(result_));
307 apply(result_, *(x.get_args()[1]));
308 apply(t.get_mpfr_t(), *(x.get_args()[0]));
309 mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
310 };
311
312 void bvisit(const LowerGamma &x)
313 {
314 mpfr_class t(mpfr_get_prec(result_));
315 apply(result_, *(x.get_args()[1]));
316 apply(t.get_mpfr_t(), *(x.get_args()[0]));
317 mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
318 mpfr_gamma(t.get_mpfr_t(), t.get_mpfr_t(), rnd_);
319 mpfr_sub(result_, t.get_mpfr_t(), result_, rnd_);
320 };
321#endif
322 void bvisit(const LogGamma &x)
323 {
324 apply(result_, *(x.get_args()[0]));
325 mpfr_lngamma(result_, result_, rnd_);
326 }
327
328 void bvisit(const Beta &x)
329 {
330 apply(result_, *(x.rewrite_as_gamma()));
331 };
332
333 void bvisit(const Constant &x)
334 {
335 if (x.__eq__(*pi)) {
336 mpfr_const_pi(result_, rnd_);
337 } else if (x.__eq__(*E)) {
338 mpfr_t one_;
339 mpfr_init2(one_, mpfr_get_prec(result_));
340 mpfr_set_ui(one_, 1, rnd_);
341 mpfr_exp(result_, one_, rnd_);
342 mpfr_clear(one_);
343 } else if (x.__eq__(*EulerGamma)) {
344 mpfr_const_euler(result_, rnd_);
345 } else if (x.__eq__(*Catalan)) {
346 mpfr_const_catalan(result_, rnd_);
347 } else if (x.__eq__(*GoldenRatio)) {
348 mpfr_sqrt_ui(result_, 5, rnd_);
349 mpfr_add_ui(result_, result_, 1, rnd_);
350 mpfr_div_ui(result_, result_, 2, rnd_);
351 } else {
352 throw NotImplementedError("Constant " + x.get_name()
353 + " is not implemented.");
354 }
355 }
356
357 void bvisit(const Abs &x)
358 {
359 apply(result_, *(x.get_arg()));
360 mpfr_abs(result_, result_, rnd_);
361 };
362
363 void bvisit(const NumberWrapper &x)
364 {
365 x.eval(mpfr_get_prec(result_))->accept(*this);
366 }
367
368 void bvisit(const FunctionWrapper &x)
369 {
370 x.eval(mpfr_get_prec(result_))->accept(*this);
371 }
372 void bvisit(const Erf &x)
373 {
374 apply(result_, *(x.get_args()[0]));
375 mpfr_erf(result_, result_, rnd_);
376 }
377
378 void bvisit(const Erfc &x)
379 {
380 apply(result_, *(x.get_args()[0]));
381 mpfr_erfc(result_, result_, rnd_);
382 }
383
384 void bvisit(const Max &x)
385 {
386 mpfr_class t(mpfr_get_prec(result_));
387 auto d = x.get_args();
388 auto p = d.begin();
389 apply(result_, *(*p));
390 p++;
391 for (; p != d.end(); p++) {
392 apply(t.get_mpfr_t(), *(*p));
393 mpfr_max(result_, result_, t.get_mpfr_t(), rnd_);
394 }
395 }
396
397 void bvisit(const Min &x)
398 {
399 mpfr_class t(mpfr_get_prec(result_));
400 auto d = x.get_args();
401 auto p = d.begin();
402 apply(result_, *(*p));
403 p++;
404 for (; p != d.end(); p++) {
405 apply(t.get_mpfr_t(), *(*p));
406 mpfr_min(result_, result_, t.get_mpfr_t(), rnd_);
407 }
408 }
409
410 void bvisit(const UnevaluatedExpr &x)
411 {
412 apply(result_, *x.get_arg());
413 }
414
415 // Classes not implemented are
416 // Subs, Dirichlet_eta, Zeta
417 // LeviCivita, KroneckerDelta, LambertW
418 // Derivative, Complex, ComplexDouble, ComplexMPC
419 void bvisit(const Basic &)
420 {
421 throw NotImplementedError("Not Implemented");
422 };
423};
424
425void eval_mpfr(mpfr_ptr result, const Basic &b, mpfr_rnd_t rnd)
426{
427 EvalMPFRVisitor v(rnd);
428 v.apply(result, b);
429}
430
431} // namespace SymEngine
432
433#endif // HAVE_SYMENGINE_MPFR
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21