eval_arb.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/eval_arb.h>
3 #include <symengine/symengine_exception.h>
4 
5 #ifdef HAVE_SYMENGINE_ARB
6 
7 namespace SymEngine
8 {
9 
10 class EvalArbVisitor : public BaseVisitor<EvalArbVisitor>
11 {
12 protected:
13  long prec_;
14  arb_ptr result_;
15 
16 public:
17  EvalArbVisitor(long precision) : prec_{precision} {}
18 
19  void apply(arb_ptr result, const Basic &b)
20  {
21  arb_ptr tmp = result_;
22  result_ = result;
23  b.accept(*this);
24  result_ = tmp;
25  }
26 
27  void bvisit(const Integer &x)
28  {
29  fmpz_t z_;
30  fmpz_init(z_);
31  fmpz_set_mpz(z_, get_mpz_t(x.as_integer_class()));
32  arb_set_fmpz(result_, z_);
33  fmpz_clear(z_);
34  }
35 
36  void bvisit(const Rational &x)
37  {
38  fmpq_t q_;
39  fmpq_init(q_);
40  fmpq_set_mpq(q_, get_mpq_t(x.as_rational_class()));
41  arb_set_fmpq(result_, q_, prec_);
42  fmpq_clear(q_);
43  }
44 
45  void bvisit(const RealDouble &x)
46  {
47  arf_t f_;
48  arf_init(f_);
49  arf_set_d(f_, x.i);
50  arb_set_arf(result_, f_);
51  arf_clear(f_);
52  }
53 
54  void bvisit(const Add &x)
55  {
56  arb_t t;
57  arb_init(t);
58 
59  auto d = x.get_args();
60  for (auto p = d.begin(); p != d.end(); p++) {
61 
62  if (p == d.begin()) {
63  apply(result_, *(*p));
64  } else {
65  apply(t, *(*p));
66  arb_add(result_, result_, t, prec_);
67  }
68  }
69 
70  arb_clear(t);
71  }
72 
73  void bvisit(const Mul &x)
74  {
75  arb_t t;
76  arb_init(t);
77 
78  auto d = x.get_args();
79  for (auto p = d.begin(); p != d.end(); p++) {
80 
81  if (p == d.begin()) {
82  apply(result_, *(*p));
83  } else {
84  apply(t, *(*p));
85  arb_mul(result_, result_, t, prec_);
86  }
87  }
88 
89  arb_clear(t);
90  }
91 
92  void bvisit(const Pow &x)
93  {
94  if (eq(*x.get_base(), *E)) {
95  apply(result_, *(x.get_exp()));
96  arb_exp(result_, result_, prec_);
97  } else {
98  arb_t b;
99  arb_init(b);
100 
101  apply(b, *(x.get_base()));
102  apply(result_, *(x.get_exp()));
103  arb_pow(result_, b, result_, prec_);
104 
105  arb_clear(b);
106  }
107  }
108 
109  void bvisit(const Sin &x)
110  {
111  apply(result_, *(x.get_arg()));
112  arb_sin(result_, result_, prec_);
113  }
114 
115  void bvisit(const Cos &x)
116  {
117  apply(result_, *(x.get_arg()));
118  arb_cos(result_, result_, prec_);
119  }
120 
121  void bvisit(const Tan &x)
122  {
123  apply(result_, *(x.get_arg()));
124  arb_tan(result_, result_, prec_);
125  }
126 
127  void bvisit(const Symbol &)
128  {
129  throw SymEngineException("Symbol cannot be evaluated as an arb type.");
130  }
131 
132  void bvisit(const UIntPoly &)
133  {
134  throw NotImplementedError("Not Implemented");
135  }
136 
137  void bvisit(const Complex &)
138  {
139  throw NotImplementedError("Not Implemented");
140  }
141 
142  void bvisit(const ComplexDouble &)
143  {
144  throw NotImplementedError("Not Implemented");
145  }
146 
147  void bvisit(const RealMPFR &)
148  {
149  throw NotImplementedError("Not Implemented");
150  }
151 #ifdef HAVE_SYMENGINE_MPC
152  void bvisit(const ComplexMPC &)
153  {
154  throw NotImplementedError("Not Implemented");
155  }
156 #endif
157  void bvisit(const Log &x)
158  {
159  apply(result_, *(x.get_arg()));
160  arb_log(result_, result_, prec_);
161  }
162 
163  void bvisit(const Derivative &)
164  {
165  throw NotImplementedError("Not Implemented");
166  }
167 
168  void bvisit(const Cot &x)
169  {
170  apply(result_, *(x.get_arg()));
171  arb_cot(result_, result_, prec_);
172  }
173 
174  void bvisit(const Csc &x)
175  {
176  apply(result_, *(x.get_arg()));
177  arb_sin(result_, result_, prec_);
178  arb_inv(result_, result_, prec_);
179  }
180 
181  void bvisit(const Sec &x)
182  {
183  apply(result_, *(x.get_arg()));
184  arb_cos(result_, result_, prec_);
185  arb_inv(result_, result_, prec_);
186  }
187 
188  void bvisit(const ASin &x)
189  {
190  apply(result_, *(x.get_arg()));
191  arb_asin(result_, result_, prec_);
192  }
193 
194  void bvisit(const ACos &x)
195  {
196  apply(result_, *(x.get_arg()));
197  arb_acos(result_, result_, prec_);
198  }
199 
200  void bvisit(const ASec &x)
201  {
202  apply(result_, *(x.get_arg()));
203  arb_inv(result_, result_, prec_);
204  arb_acos(result_, result_, prec_);
205  }
206 
207  void bvisit(const ACsc &x)
208  {
209  apply(result_, *(x.get_arg()));
210  arb_inv(result_, result_, prec_);
211  arb_asin(result_, result_, prec_);
212  }
213 
214  void bvisit(const ATan &x)
215  {
216  apply(result_, *(x.get_arg()));
217  arb_atan(result_, result_, prec_);
218  }
219 
220  void bvisit(const ACot &x)
221  {
222  apply(result_, *(x.get_arg()));
223  arb_inv(result_, result_, prec_);
224  arb_atan(result_, result_, prec_);
225  }
226 
227  void bvisit(const ATan2 &x)
228  {
229  arb_t t;
230  arb_init(t);
231 
232  apply(t, *(x.get_num()));
233  apply(result_, *(x.get_den()));
234  arb_atan2(result_, t, result_, prec_);
235 
236  arb_clear(t);
237  }
238 
239  void bvisit(const LambertW &)
240  {
241  throw NotImplementedError("Not Implemented");
242  }
243 
244  void bvisit(const FunctionWrapper &x)
245  {
246  x.eval(prec_)->accept(*this);
247  }
248 
249  void bvisit(const Sinh &x)
250  {
251  apply(result_, *(x.get_arg()));
252  arb_sinh(result_, result_, prec_);
253  }
254 
255  void bvisit(const Csch &)
256  {
257  throw NotImplementedError("Not Implemented");
258  }
259 
260  void bvisit(const Cosh &x)
261  {
262  apply(result_, *(x.get_arg()));
263  arb_cosh(result_, result_, prec_);
264  }
265 
266  void bvisit(const Sech &)
267  {
268  throw NotImplementedError("Not Implemented");
269  }
270 
271  void bvisit(const Tanh &x)
272  {
273  apply(result_, *(x.get_arg()));
274  arb_tanh(result_, result_, prec_);
275  }
276 
277  void bvisit(const Coth &x)
278  {
279  apply(result_, *(x.get_arg()));
280  arb_coth(result_, result_, prec_);
281  }
282 
283  void bvisit(const Max &x)
284  {
285  arb_t t;
286  arb_init(t);
287 
288  auto d = x.get_args();
289  auto p = d.begin();
290  apply(result_, *(*p));
291  p++;
292 
293  for (; p != d.end(); p++) {
294 
295  apply(t, *(*p));
296  if (arb_gt(t, result_))
297  arb_set(result_, t);
298  }
299 
300  arb_clear(t);
301  }
302 
303  void bvisit(const Min &x)
304  {
305  arb_t t;
306  arb_init(t);
307 
308  auto d = x.get_args();
309  auto p = d.begin();
310  apply(result_, *(*p));
311  p++;
312 
313  for (; p != d.end(); p++) {
314 
315  apply(t, *(*p));
316  if (arb_lt(t, result_))
317  arb_set(result_, t);
318  }
319 
320  arb_clear(t);
321  }
322 
323  void bvisit(const ACsch &)
324  {
325  throw NotImplementedError("Not Implemented");
326  }
327 
328  void bvisit(const ASinh &x)
329  {
330  apply(result_, *(x.get_arg()));
331  arb_asinh(result_, result_, prec_);
332  }
333 
334  void bvisit(const ACosh &x)
335  {
336  apply(result_, *(x.get_arg()));
337  arb_acosh(result_, result_, prec_);
338  }
339 
340  void bvisit(const ATanh &x)
341  {
342  apply(result_, *(x.get_arg()));
343  arb_atanh(result_, result_, prec_);
344  }
345 
346  void bvisit(const ACoth &x)
347  {
348  apply(result_, *(x.get_arg()));
349  arb_inv(result_, result_, prec_);
350  arb_atanh(result_, result_, prec_);
351  }
352 
353  void bvisit(const ASech &x)
354  {
355  apply(result_, *(x.get_arg()));
356  arb_inv(result_, result_, prec_);
357  arb_acosh(result_, result_, prec_);
358  }
359 
360  void bvisit(const KroneckerDelta &)
361  {
362  throw NotImplementedError("Not Implemented");
363  }
364 
365  void bvisit(const LeviCivita &)
366  {
367  throw NotImplementedError("Not Implemented");
368  }
369 
370  void bvisit(const Zeta &x)
371  {
372  arb_t t_;
373  arb_init(t_);
374 
375  apply(t_, *(x.get_arg1()));
376  apply(result_, *(x.get_arg2()));
377  arb_hurwitz_zeta(result_, t_, result_, prec_);
378 
379  arb_clear(t_);
380  }
381 
382  void bvisit(const Dirichlet_eta &)
383  {
384  throw NotImplementedError("Not Implemented");
385  }
386 
387  void bvisit(const Gamma &x)
388  {
389  apply(result_, *(x.get_args())[0]);
390  arb_gamma(result_, result_, prec_);
391  }
392 
393  void bvisit(const LogGamma &x)
394  {
395  apply(result_, *(x.get_args())[0]);
396  arb_lgamma(result_, result_, prec_);
397  }
398 
399  void bvisit(const LowerGamma &)
400  {
401  throw NotImplementedError("Not Implemented");
402  }
403 
404  void bvisit(const UpperGamma &)
405  {
406  throw NotImplementedError("Not Implemented");
407  }
408 
409  void bvisit(const Constant &x)
410  {
411  if (x.__eq__(*pi)) {
412  arb_const_pi(result_, prec_);
413  } else if (x.__eq__(*E)) {
414  arb_const_e(result_, prec_);
415  } else if (x.__eq__(*EulerGamma)) {
416  arb_const_euler(result_, prec_);
417  } else if (x.__eq__(*Catalan)) {
418  arb_const_catalan(result_, prec_);
419  } else if (x.__eq__(*GoldenRatio)) {
420  arb_sqrt_ui(result_, 5, prec_);
421  arb_add_ui(result_, result_, 1, prec_);
422  arb_div_ui(result_, result_, 2, prec_);
423  } else {
424  throw NotImplementedError("Constant " + x.get_name()
425  + " is not implemented.");
426  }
427  }
428 
429  void bvisit(const Abs &x)
430  {
431  apply(result_, *(x.get_arg()));
432  arb_abs(result_, result_);
433  }
434 
435  void bvisit(const Basic &)
436  {
437  throw NotImplementedError("Not Implemented");
438  }
439 
440  void bvisit(const NumberWrapper &x)
441  {
442  x.eval(prec_)->accept(*this);
443  }
444 
445  void bvisit(const UnevaluatedExpr &x)
446  {
447  apply(result_, *x.get_arg());
448  }
449 };
450 
451 void eval_arb(arb_t result, const Basic &b, long precision)
452 {
453  EvalArbVisitor v(precision);
454  v.apply(result, b);
455 }
456 
457 } // namespace SymEngine
458 
459 #endif // HAVE_SYMENGINE_ARB
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21