eval_mpc.cpp
1 #include <symengine/visitor.h>
2 #include <symengine/eval_mpc.h>
3 #include <symengine/symengine_exception.h>
4 
5 #ifdef HAVE_SYMENGINE_MPC
6 
7 namespace SymEngine
8 {
9 
10 class EvalMPCVisitor : public BaseVisitor<EvalMPCVisitor>
11 {
12 protected:
13  mpfr_rnd_t rnd_;
14  mpc_ptr result_;
15 
16 public:
17  EvalMPCVisitor(mpfr_rnd_t rnd) : rnd_{rnd} {}
18 
19  void apply(mpc_ptr result, const Basic &b)
20  {
21  mpc_ptr tmp = result_;
22  result_ = result;
23  b.accept(*this);
24  result_ = tmp;
25  }
26 
27  void bvisit(const Integer &x)
28  {
29  mpc_set_z(result_, get_mpz_t(x.as_integer_class()), rnd_);
30  }
31 
32  void bvisit(const Rational &x)
33  {
34  mpc_set_q(result_, get_mpq_t(x.as_rational_class()), rnd_);
35  }
36 
37  void bvisit(const RealDouble &x)
38  {
39  mpc_set_d(result_, x.i, rnd_);
40  }
41 
42  void bvisit(const Complex &x)
43  {
44  mpc_set_q_q(result_, get_mpq_t(x.real_), get_mpq_t(x.imaginary_), rnd_);
45  }
46 
47  void bvisit(const ComplexDouble &x)
48  {
49  mpc_set_d_d(result_, x.i.real(), x.i.imag(), rnd_);
50  }
51 
52  void bvisit(const RealMPFR &x)
53  {
54  mpc_set_fr(result_, x.i.get_mpfr_t(), rnd_);
55  }
56 
57  void bvisit(const ComplexMPC &x)
58  {
59  mpc_set(result_, x.as_mpc().get_mpc_t(), rnd_);
60  }
61 
62  void bvisit(const Add &x)
63  {
64  mpc_t t;
65  mpc_init2(t, mpc_get_prec(result_));
66 
67  auto d = x.get_args();
68  auto p = d.begin();
69  apply(result_, *(*p));
70  p++;
71  for (; p != d.end(); p++) {
72  apply(t, *(*p));
73  mpc_add(result_, result_, t, rnd_);
74  }
75  mpc_clear(t);
76  }
77 
78  void bvisit(const Mul &x)
79  {
80  mpc_t t;
81  mpc_init2(t, mpc_get_prec(result_));
82 
83  auto d = x.get_args();
84  auto p = d.begin();
85  apply(result_, *(*p));
86  p++;
87  for (; p != d.end(); p++) {
88  apply(t, *(*p));
89  mpc_mul(result_, result_, t, rnd_);
90  }
91  mpc_clear(t);
92  }
93 
94  void bvisit(const Pow &x)
95  {
96  if (eq(*x.get_base(), *E)) {
97  apply(result_, *(x.get_exp()));
98  mpc_exp(result_, result_, rnd_);
99  } else {
100  mpc_t t;
101  mpc_init2(t, mpc_get_prec(result_));
102 
103  apply(t, *(x.get_base()));
104  apply(result_, *(x.get_exp()));
105  mpc_pow(result_, t, result_, rnd_);
106 
107  mpc_clear(t);
108  }
109  }
110 
111  void bvisit(const Sin &x)
112  {
113  apply(result_, *(x.get_arg()));
114  mpc_sin(result_, result_, rnd_);
115  }
116 
117  void bvisit(const Cos &x)
118  {
119  apply(result_, *(x.get_arg()));
120  mpc_cos(result_, result_, rnd_);
121  }
122 
123  void bvisit(const Tan &x)
124  {
125  apply(result_, *(x.get_arg()));
126  mpc_tan(result_, result_, rnd_);
127  }
128 
129  void bvisit(const Log &x)
130  {
131  apply(result_, *(x.get_arg()));
132  mpc_log(result_, result_, rnd_);
133  }
134 
135  void bvisit(const Cot &x)
136  {
137  apply(result_, *(x.get_arg()));
138  mpc_tan(result_, result_, rnd_);
139  mpc_ui_div(result_, 1, result_, rnd_);
140  }
141 
142  void bvisit(const Csc &x)
143  {
144  apply(result_, *(x.get_arg()));
145  mpc_sin(result_, result_, rnd_);
146  mpc_ui_div(result_, 1, result_, rnd_);
147  }
148 
149  void bvisit(const Sec &x)
150  {
151  apply(result_, *(x.get_arg()));
152  mpc_cos(result_, result_, rnd_);
153  mpc_ui_div(result_, 1, result_, rnd_);
154  }
155 
156  void bvisit(const ASin &x)
157  {
158  apply(result_, *(x.get_arg()));
159  mpc_asin(result_, result_, rnd_);
160  }
161 
162  void bvisit(const ACos &x)
163  {
164  apply(result_, *(x.get_arg()));
165  mpc_acos(result_, result_, rnd_);
166  }
167 
168  void bvisit(const ASec &x)
169  {
170  apply(result_, *(x.get_arg()));
171  mpc_ui_div(result_, 1, result_, rnd_);
172  mpc_acos(result_, result_, rnd_);
173  }
174 
175  void bvisit(const ACsc &x)
176  {
177  apply(result_, *(x.get_arg()));
178  mpc_ui_div(result_, 1, result_, rnd_);
179  mpc_asin(result_, result_, rnd_);
180  }
181 
182  void bvisit(const ATan &x)
183  {
184  apply(result_, *(x.get_arg()));
185  mpc_atan(result_, result_, rnd_);
186  }
187 
188  void bvisit(const ACot &x)
189  {
190  apply(result_, *(x.get_arg()));
191  mpc_ui_div(result_, 1, result_, rnd_);
192  mpc_atan(result_, result_, rnd_);
193  }
194 
195  void bvisit(const Sinh &x)
196  {
197  apply(result_, *(x.get_arg()));
198  mpc_sinh(result_, result_, rnd_);
199  }
200 
201  void bvisit(const Csch &x)
202  {
203  apply(result_, *(x.get_arg()));
204  mpc_sinh(result_, result_, rnd_);
205  mpc_ui_div(result_, 1, result_, rnd_);
206  }
207 
208  void bvisit(const Cosh &x)
209  {
210  apply(result_, *(x.get_arg()));
211  mpc_cosh(result_, result_, rnd_);
212  }
213 
214  void bvisit(const Sech &x)
215  {
216  apply(result_, *(x.get_arg()));
217  mpc_cosh(result_, result_, rnd_);
218  mpc_ui_div(result_, 1, result_, rnd_);
219  }
220 
221  void bvisit(const Tanh &x)
222  {
223  apply(result_, *(x.get_arg()));
224  mpc_tanh(result_, result_, rnd_);
225  }
226 
227  void bvisit(const Coth &x)
228  {
229  apply(result_, *(x.get_arg()));
230  mpc_tanh(result_, result_, rnd_);
231  mpc_ui_div(result_, 1, result_, rnd_);
232  }
233 
234  void bvisit(const ASinh &x)
235  {
236  apply(result_, *(x.get_arg()));
237  mpc_asinh(result_, result_, rnd_);
238  }
239 
240  void bvisit(const ACsch &x)
241  {
242  apply(result_, *(x.get_arg()));
243  mpc_ui_div(result_, 1, result_, rnd_);
244  mpc_asinh(result_, result_, rnd_);
245  }
246 
247  void bvisit(const ACosh &x)
248  {
249  apply(result_, *(x.get_arg()));
250  mpc_acosh(result_, result_, rnd_);
251  }
252 
253  void bvisit(const ATanh &x)
254  {
255  apply(result_, *(x.get_arg()));
256  mpc_atanh(result_, result_, rnd_);
257  }
258 
259  void bvisit(const ACoth &x)
260  {
261  apply(result_, *(x.get_arg()));
262  mpc_ui_div(result_, 1, result_, rnd_);
263  mpc_atanh(result_, result_, rnd_);
264  }
265 
266  void bvisit(const ASech &x)
267  {
268  apply(result_, *(x.get_arg()));
269  mpc_ui_div(result_, 1, result_, rnd_);
270  mpc_acosh(result_, result_, rnd_);
271  };
272 
273  void bvisit(const Constant &x)
274  {
275  if (x.__eq__(*pi)) {
276  mpfr_t t;
277  mpfr_init2(t, mpc_get_prec(result_));
278  mpfr_const_pi(t, rnd_);
279  mpc_set_fr(result_, t, rnd_);
280  mpfr_clear(t);
281  } else if (x.__eq__(*E)) {
282  mpfr_t t;
283  mpfr_init2(t, mpc_get_prec(result_));
284  mpfr_set_ui(t, 1, rnd_);
285  mpfr_exp(t, t, rnd_);
286  mpc_set_fr(result_, t, rnd_);
287  mpfr_clear(t);
288  } else if (x.__eq__(*EulerGamma)) {
289  mpfr_t t;
290  mpfr_init2(t, mpc_get_prec(result_));
291  mpfr_const_euler(t, rnd_);
292  mpc_set_fr(result_, t, rnd_);
293  mpfr_clear(t);
294  } else if (x.__eq__(*Catalan)) {
295  mpfr_t t;
296  mpfr_init2(t, mpc_get_prec(result_));
297  mpfr_const_catalan(t, rnd_);
298  mpc_set_fr(result_, t, rnd_);
299  mpfr_clear(t);
300  } else if (x.__eq__(*GoldenRatio)) {
301  mpfr_t t;
302  mpfr_init2(t, mpc_get_prec(result_));
303  mpfr_sqrt_ui(t, 5, rnd_);
304  mpfr_add_ui(t, t, 1, rnd_);
305  mpfr_div_ui(t, t, 2, rnd_);
306  mpc_set_fr(result_, t, rnd_);
307  mpfr_clear(t);
308  } else {
309  throw NotImplementedError("Constant " + x.get_name()
310  + " is not implemented.");
311  }
312  }
313 
314  void bvisit(const Gamma &x)
315  {
316  throw NotImplementedError("Not implemented");
317  }
318 
319  void bvisit(const Abs &x)
320  {
321  mpfr_t t;
322  mpfr_init2(t, mpc_get_prec(result_));
323  apply(result_, *(x.get_arg()));
324  mpc_abs(t, result_, rnd_);
325  mpc_set_fr(result_, t, rnd_);
326  mpfr_clear(t);
327  };
328 
329  void bvisit(const NumberWrapper &x)
330  {
331  x.eval(mpc_get_prec(result_))->accept(*this);
332  }
333 
334  void bvisit(const FunctionWrapper &x)
335  {
336  x.eval(mpc_get_prec(result_))->accept(*this);
337  }
338 
339  void bvisit(const UnevaluatedExpr &x)
340  {
341  apply(result_, *x.get_arg());
342  }
343 
344  // Classes not implemented are
345  // Subs, UpperGamma, LowerGamma, Dirichlet_eta, Zeta
346  // LeviCivita, KroneckerDelta, FunctionSymbol, LambertW
347  // Derivative, ATan2, Gamma
348  void bvisit(const Basic &)
349  {
350  throw NotImplementedError("Not Implemented");
351  };
352 };
353 
354 void eval_mpc(mpc_ptr result, const Basic &b, mpfr_rnd_t rnd)
355 {
356  EvalMPCVisitor v(rnd);
357  v.apply(result, b);
358 }
359 
360 } // namespace SymEngine
361 
362 #endif // HAVE_SYMENGINE_MPFR
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21