Loading...
Searching...
No Matches
lambda_double.h
1#ifndef SYMENGINE_LAMBDA_DOUBLE_H
2#define SYMENGINE_LAMBDA_DOUBLE_H
3
4#include <cmath>
5#include <limits>
7#include <symengine/symengine_exception.h>
8#include <symengine/visitor.h>
9
10namespace SymEngine
11{
12
13template <typename T>
14class LambdaDoubleVisitor : public BaseVisitor<LambdaDoubleVisitor<T>>
15{
16protected:
17 /*
18 The 'result_' variable is assigned into at the very end of each visit()
19 methods below. The only place where these methods are called from is the
20 line 'b.accept(*this)' in apply() and the 'result_' is immediately
21 returned. Thus no corruption can happen and apply() can be safely called
22 recursively.
23 */
24
25 typedef std::function<T(const T *x)> fn;
26 std::vector<fn> results;
27 std::vector<T> cse_intermediate_results;
28
30 cse_intermediate_fns_map;
31 std::vector<fn> cse_intermediate_fns;
32 fn result_;
33 vec_basic symbols;
34
35public:
36 LambdaDoubleVisitor() = default;
38 LambdaDoubleVisitor &operator=(LambdaDoubleVisitor &&) = default;
39 // delete copy constructor:
40 // https://github.com/symengine/symengine/issues/1674
42 LambdaDoubleVisitor &operator=(const LambdaDoubleVisitor &) = delete;
43
44 void init(const vec_basic &x, const Basic &b, bool cse = false)
45 {
46 vec_basic outputs = {b.rcp_from_this()};
47 init(x, outputs, cse);
48 }
49
50 void init(const vec_basic &inputs, const vec_basic &outputs,
51 bool cse = false)
52 {
53 results.clear();
54 cse_intermediate_fns.clear();
55 symbols = inputs;
56 if (not cse) {
57 for (auto &p : outputs) {
58 results.push_back(apply(*p));
59 }
60 } else {
61 vec_basic reduced_exprs;
62 vec_pair replacements;
63 // cse the outputs
64 SymEngine::cse(replacements, reduced_exprs, outputs);
65 cse_intermediate_results.resize(replacements.size());
66 for (auto &rep : replacements) {
67 auto res = apply(*(rep.second));
68 // Store the replacement symbol values in a dictionary for
69 // faster
70 // lookup for initialization
71 cse_intermediate_fns_map[rep.first]
72 = cse_intermediate_fns.size();
73 // Store it in a vector for faster use in call
74 cse_intermediate_fns.push_back(res);
75 }
76 // Generate functions for all the reduced exprs and save it
77 for (unsigned i = 0; i < outputs.size(); i++) {
78 results.push_back(apply(*reduced_exprs[i]));
79 }
80 // We don't need the cse_intermediate_fns_map anymore
81 cse_intermediate_fns_map.clear();
82 symbols.clear();
83 }
84 }
85
86 fn apply(const Basic &b)
87 {
88 b.accept(*this);
89 return result_;
90 }
91
92 T call(const std::vector<T> &vec)
93 {
94 T res;
95 call(&res, vec.data());
96 return res;
97 }
98
99 void call(T *outs, const T *inps)
100 {
101 if (cse_intermediate_fns.size() > 0) {
102 for (unsigned i = 0; i < cse_intermediate_fns.size(); ++i) {
103 cse_intermediate_results[i] = cse_intermediate_fns[i](inps);
104 }
105 }
106 for (unsigned i = 0; i < results.size(); ++i) {
107 outs[i] = results[i](inps);
108 }
109 return;
110 }
111
112 void bvisit(const Symbol &x)
113 {
114 for (unsigned i = 0; i < symbols.size(); ++i) {
115 if (eq(x, *symbols[i])) {
116 result_ = [=](const T *x) { return x[i]; };
117 return;
118 }
119 }
120 auto it = cse_intermediate_fns_map.find(x.rcp_from_this());
121 if (it != cse_intermediate_fns_map.end()) {
122 auto index = it->second;
123 T *cse_intermediate_result = &(cse_intermediate_results[index]);
124 result_ = [=](const T *x) { return *cse_intermediate_result; };
125 return;
126 }
127 throw SymEngineException("Symbol not in the symbols vector.");
128 };
129
130 void bvisit(const Integer &x)
131 {
132 T tmp = mp_get_d(x.as_integer_class());
133 result_ = [=](const T *x_) { return tmp; };
134 }
135
136 void bvisit(const Rational &x)
137 {
138 T tmp = mp_get_d(x.as_rational_class());
139 result_ = [=](const T *x) { return tmp; };
140 }
141
142 void bvisit(const RealDouble &x)
143 {
144 T tmp = x.i;
145 result_ = [=](const T *x) { return tmp; };
146 }
147
148#ifdef HAVE_SYMENGINE_MPFR
149 void bvisit(const RealMPFR &x)
150 {
151 T tmp = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
152 result_ = [=](const T *x) { return tmp; };
153 }
154#endif
155
156 void bvisit(const Add &x)
157 {
158 fn tmp = apply(*x.get_coef());
159 fn tmp1, tmp2;
160 for (const auto &p : x.get_dict()) {
161 tmp1 = apply(*(p.first));
162 tmp2 = apply(*(p.second));
163 tmp = [=](const T *x) { return tmp(x) + tmp1(x) * tmp2(x); };
164 }
165 result_ = tmp;
166 }
167
168 void bvisit(const Mul &x)
169 {
170 fn tmp = apply(*x.get_coef());
171 fn tmp1, tmp2;
172 for (const auto &p : x.get_dict()) {
173 tmp1 = apply(*(p.first));
174 tmp2 = apply(*(p.second));
175 tmp = [=](const T *x) {
176 return tmp(x) * std::pow(tmp1(x), tmp2(x));
177 };
178 }
179 result_ = tmp;
180 }
181
182 void bvisit(const Pow &x)
183 {
184 fn exp_ = apply(*(x.get_exp()));
185 if (eq(*(x.get_base()), *E)) {
186 result_ = [=](const T *x) { return std::exp(exp_(x)); };
187 } else {
188 fn base_ = apply(*(x.get_base()));
189 result_ = [=](const T *x) { return std::pow(base_(x), exp_(x)); };
190 }
191 }
192
193 void bvisit(const Sin &x)
194 {
195 fn tmp = apply(*(x.get_arg()));
196 result_ = [=](const T *x) { return std::sin(tmp(x)); };
197 }
198
199 void bvisit(const Cos &x)
200 {
201 fn tmp = apply(*(x.get_arg()));
202 result_ = [=](const T *x) { return std::cos(tmp(x)); };
203 }
204
205 void bvisit(const Tan &x)
206 {
207 fn tmp = apply(*(x.get_arg()));
208 result_ = [=](const T *x) { return std::tan(tmp(x)); };
209 }
210
211 void bvisit(const Log &x)
212 {
213 fn tmp = apply(*(x.get_arg()));
214 result_ = [=](const T *x) { return std::log(tmp(x)); };
215 };
216
217 void bvisit(const Cot &x)
218 {
219 fn tmp = apply(*(x.get_arg()));
220 result_ = [=](const T *x) { return 1.0 / std::tan(tmp(x)); };
221 };
222
223 void bvisit(const Csc &x)
224 {
225 fn tmp = apply(*(x.get_arg()));
226 result_ = [=](const T *x) { return 1.0 / std::sin(tmp(x)); };
227 };
228
229 void bvisit(const Sec &x)
230 {
231 fn tmp = apply(*(x.get_arg()));
232 result_ = [=](const T *x) { return 1.0 / std::cos(tmp(x)); };
233 };
234
235 void bvisit(const ASin &x)
236 {
237 fn tmp = apply(*(x.get_arg()));
238 result_ = [=](const T *x) { return std::asin(tmp(x)); };
239 };
240
241 void bvisit(const ACos &x)
242 {
243 fn tmp = apply(*(x.get_arg()));
244 result_ = [=](const T *x) { return std::acos(tmp(x)); };
245 };
246
247 void bvisit(const ASec &x)
248 {
249 fn tmp = apply(*(x.get_arg()));
250 result_ = [=](const T *x) { return std::acos(1.0 / tmp(x)); };
251 };
252
253 void bvisit(const ACsc &x)
254 {
255 fn tmp = apply(*(x.get_arg()));
256 result_ = [=](const T *x) { return std::asin(1.0 / tmp(x)); };
257 };
258
259 void bvisit(const ATan &x)
260 {
261 fn tmp = apply(*(x.get_arg()));
262 result_ = [=](const T *x) { return std::atan(tmp(x)); };
263 };
264
265 void bvisit(const ACot &x)
266 {
267 fn tmp = apply(*(x.get_arg()));
268 result_ = [=](const T *x) { return std::atan(1.0 / tmp(x)); };
269 };
270
271 void bvisit(const Sinh &x)
272 {
273 fn tmp = apply(*(x.get_arg()));
274 result_ = [=](const T *x) { return std::sinh(tmp(x)); };
275 };
276
277 void bvisit(const Csch &x)
278 {
279 fn tmp = apply(*(x.get_arg()));
280 result_ = [=](const T *x) { return 1.0 / std::sinh(tmp(x)); };
281 };
282
283 void bvisit(const Cosh &x)
284 {
285 fn tmp = apply(*(x.get_arg()));
286 result_ = [=](const T *x) { return std::cosh(tmp(x)); };
287 };
288
289 void bvisit(const Sech &x)
290 {
291 fn tmp = apply(*(x.get_arg()));
292 result_ = [=](const T *x) { return 1.0 / std::cosh(tmp(x)); };
293 };
294
295 void bvisit(const Tanh &x)
296 {
297 fn tmp = apply(*(x.get_arg()));
298 result_ = [=](const T *x) { return std::tanh(tmp(x)); };
299 };
300
301 void bvisit(const Coth &x)
302 {
303 fn tmp = apply(*(x.get_arg()));
304 result_ = [=](const T *x) { return 1.0 / std::tanh(tmp(x)); };
305 };
306
307 void bvisit(const ASinh &x)
308 {
309 fn tmp = apply(*(x.get_arg()));
310 result_ = [=](const T *x) { return std::asinh(tmp(x)); };
311 };
312
313 void bvisit(const ACsch &x)
314 {
315 fn tmp = apply(*(x.get_arg()));
316 result_ = [=](const T *x) { return std::asinh(1.0 / tmp(x)); };
317 };
318
319 void bvisit(const ACosh &x)
320 {
321 fn tmp = apply(*(x.get_arg()));
322 result_ = [=](const T *x) { return std::acosh(tmp(x)); };
323 };
324
325 void bvisit(const ATanh &x)
326 {
327 fn tmp = apply(*(x.get_arg()));
328 result_ = [=](const T *x) { return std::atanh(tmp(x)); };
329 };
330
331 void bvisit(const ACoth &x)
332 {
333 fn tmp = apply(*(x.get_arg()));
334 result_ = [=](const T *x) { return std::atanh(1.0 / tmp(x)); };
335 };
336
337 void bvisit(const ASech &x)
338 {
339 fn tmp = apply(*(x.get_arg()));
340 result_ = [=](const T *x) { return std::acosh(1.0 / tmp(x)); };
341 };
342
343 void bvisit(const Constant &x)
344 {
345 T tmp = eval_double(x);
346 result_ = [=](const T *x) { return tmp; };
347 };
348
349 void bvisit(const Abs &x)
350 {
351 fn tmp = apply(*(x.get_arg()));
352 result_ = [=](const T *x) { return std::abs(tmp(x)); };
353 };
354
355 void bvisit(const Basic &)
356 {
357 throw NotImplementedError("Not Implemented");
358 };
359
360 void bvisit(const UnevaluatedExpr &x)
361 {
362 apply(*x.get_arg());
363 };
364};
365
367 : public BaseVisitor<LambdaRealDoubleVisitor, LambdaDoubleVisitor<double>>
368{
369public:
370 // Classes not implemented are
371 // Subs, UpperGamma, LowerGamma, Dirichlet_eta, Zeta
372 // LeviCivita, KroneckerDelta, FunctionSymbol, LambertW
373 // Derivative, Complex, ComplexDouble, ComplexMPC
374
375 using LambdaDoubleVisitor::bvisit;
376
377 void bvisit(const ATan2 &x)
378 {
379 fn num = apply(*(x.get_num()));
380 fn den = apply(*(x.get_den()));
381 result_ = [=](const double *x) { return std::atan2(num(x), den(x)); };
382 };
383
384 void bvisit(const Gamma &x)
385 {
386 fn tmp = apply(*(x.get_args()[0]));
387 result_ = [=](const double *x) { return std::tgamma(tmp(x)); };
388 };
389
390 void bvisit(const LogGamma &x)
391 {
392 fn tmp = apply(*(x.get_args()[0]));
393 result_ = [=](const double *x) { return std::lgamma(tmp(x)); };
394 };
395
396 void bvisit(const Erf &x)
397 {
398 fn tmp = apply(*(x.get_args()[0]));
399 result_ = [=](const double *x) { return std::erf(tmp(x)); };
400 }
401
402 void bvisit(const Erfc &x)
403 {
404 fn tmp = apply(*(x.get_args()[0]));
405 result_ = [=](const double *x) { return std::erfc(tmp(x)); };
406 }
407
408 void bvisit(const Equality &x)
409 {
410 fn lhs_ = apply(*(x.get_arg1()));
411 fn rhs_ = apply(*(x.get_arg2()));
412 result_ = [=](const double *x) { return (lhs_(x) == rhs_(x)); };
413 }
414
415 void bvisit(const Unequality &x)
416 {
417 fn lhs_ = apply(*(x.get_arg1()));
418 fn rhs_ = apply(*(x.get_arg2()));
419 result_ = [=](const double *x) { return (lhs_(x) != rhs_(x)); };
420 }
421
422 void bvisit(const LessThan &x)
423 {
424 fn lhs_ = apply(*(x.get_arg1()));
425 fn rhs_ = apply(*(x.get_arg2()));
426 result_ = [=](const double *x) { return (lhs_(x) <= rhs_(x)); };
427 }
428
429 void bvisit(const StrictLessThan &x)
430 {
431 fn lhs_ = apply(*(x.get_arg1()));
432 fn rhs_ = apply(*(x.get_arg2()));
433 result_ = [=](const double *x) { return (lhs_(x) < rhs_(x)); };
434 }
435
436 void bvisit(const And &x)
437 {
438 std::vector<fn> applys;
439 for (const auto &p : x.get_args()) {
440 applys.push_back(apply(*p));
441 }
442
443 result_ = [=](const double *x) {
444 bool result = bool(applys[0](x));
445 for (unsigned int i = 0; i < applys.size(); i++) {
446 result = result && bool(applys[i](x));
447 }
448 return double(result);
449 };
450 }
451
452 void bvisit(const Or &x)
453 {
454 std::vector<fn> applys;
455 for (const auto &p : x.get_args()) {
456 applys.push_back(apply(*p));
457 }
458
459 result_ = [=](const double *x) {
460 bool result = bool(applys[0](x));
461 for (unsigned int i = 0; i < applys.size(); i++) {
462 result = result || bool(applys[i](x));
463 }
464 return double(result);
465 };
466 }
467
468 void bvisit(const Xor &x)
469 {
470 std::vector<fn> applys;
471 for (const auto &p : x.get_args()) {
472 applys.push_back(apply(*p));
473 }
474
475 result_ = [=](const double *x) {
476 bool result = bool(applys[0](x));
477 for (unsigned int i = 0; i < applys.size(); i++) {
478 result = result != bool(applys[i](x));
479 }
480 return double(result);
481 };
482 }
483
484 void bvisit(const Not &x)
485 {
486 fn tmp = apply(*(x.get_arg()));
487 result_ = [=](const double *x) { return double(not bool(tmp(x))); };
488 }
489
490 void bvisit(const Max &x)
491 {
492 std::vector<fn> applys;
493 for (const auto &p : x.get_args()) {
494 applys.push_back(apply(*p));
495 }
496
497 result_ = [=](const double *x) {
498 double result = applys[0](x);
499 for (unsigned int i = 0; i < applys.size(); i++) {
500 result = std::max(result, applys[i](x));
501 }
502 return result;
503 };
504 };
505
506 void bvisit(const Min &x)
507 {
508 std::vector<fn> applys;
509 for (const auto &p : x.get_args()) {
510 applys.push_back(apply(*p));
511 }
512
513 result_ = [=](const double *x) {
514 double result = applys[0](x);
515 for (unsigned int i = 0; i < applys.size(); i++) {
516 result = std::min(result, applys[i](x));
517 }
518 return result;
519 };
520 };
521
522 void bvisit(const Sign &x)
523 {
524 fn tmp = apply(*(x.get_arg()));
525 result_ = [=](const double *x) {
526 return tmp(x) == 0.0 ? 0.0 : (tmp(x) < 0.0 ? -1.0 : 1.0);
527 };
528 };
529
530 void bvisit(const Floor &x)
531 {
532 fn tmp = apply(*(x.get_arg()));
533 result_ = [=](const double *x) { return std::floor(tmp(x)); };
534 };
535
536 void bvisit(const Ceiling &x)
537 {
538 fn tmp = apply(*(x.get_arg()));
539 result_ = [=](const double *x) { return std::ceil(tmp(x)); };
540 };
541
542 void bvisit(const Truncate &x)
543 {
544 fn tmp = apply(*(x.get_arg()));
545 result_ = [=](const double *x) { return std::trunc(tmp(x)); };
546 };
547
548 void bvisit(const Infty &x)
549 {
550 if (x.is_negative_infinity()) {
551 result_ = [=](const double * /* x */) {
553 };
554 } else if (x.is_positive_infinity()) {
555 result_ = [=](const double * /* x */) {
557 };
558 } else {
559 throw SymEngineException(
560 "LambdaDouble can only represent real valued infinity");
561 }
562 }
563 void bvisit(const NaN &nan)
564 {
565 assert(&nan == &(*Nan) /* singleton, or do we support NaN quiet/singaling nan with payload? */);
566 result_ = [](const double * /* x */) {
568 };
569 }
570 void bvisit(const Contains &cts)
571 {
572 const auto fn_expr = apply(*cts.get_expr());
573 const auto set = cts.get_set();
574 if (is_a<Interval>(*set)) {
575 const auto &interv = down_cast<const Interval &>(*set);
576 const auto fn_start = apply(*interv.get_start());
577 const auto fn_end = apply(*interv.get_end());
578 const bool left_open = interv.get_left_open();
579 const bool right_open = interv.get_right_open();
580 result_ = [=](const double *x) {
581 const auto val_expr = fn_expr(x);
582 const auto val_start = fn_start(x);
583 const auto val_end = fn_end(x);
584 bool left_ok, right_ok;
585 if (val_start == -std::numeric_limits<double>::infinity()) {
586 left_ok = !std::isnan(val_expr);
587 } else {
588 left_ok = (left_open) ? (val_start < val_expr)
589 : (val_start <= val_expr);
590 }
591 if (val_end == std::numeric_limits<double>::infinity()) {
592 right_ok = !std::isnan(val_expr);
593 } else {
594 right_ok = (right_open) ? (val_expr < val_end)
595 : (val_expr <= val_end);
596 }
597 return (left_ok && right_ok) ? 1.0 : 0.0;
598 };
599 } else {
600 throw SymEngineException("LambdaDoubleVisitor: only ``Interval`` "
601 "implemented for ``Contains``.");
602 }
603 }
604
605 void bvisit(const BooleanAtom &ba)
606 {
607 const bool val = ba.get_val();
608 result_ = [=](const double * /* x */) { return (val) ? 1.0 : 0.0; };
609 }
610
611 void bvisit(const Piecewise &pw)
612 {
613 SYMENGINE_ASSERT_MSG(
614 eq(*pw.get_vec().back().second, *boolTrue),
615 "LambdaDouble requires a (Expr, True) at the end of Piecewise");
616
617 std::vector<fn> applys;
618 std::vector<fn> preds;
619 for (const auto &expr_pred : pw.get_vec()) {
620 applys.push_back(apply(*expr_pred.first));
621 preds.push_back(apply(*expr_pred.second));
622 }
623 result_ = [=](const double *x) {
624 for (size_t i = 0;; ++i) {
625 if (preds[i](x) == 1.0) {
626 return applys[i](x);
627 }
628 }
629 throw SymEngineException(
630 "Unexpectedly reached end of Piecewise function.");
631 };
632 }
633};
634
636 : public BaseVisitor<LambdaComplexDoubleVisitor,
637 LambdaDoubleVisitor<std::complex<double>>>
638{
639public:
640 // Classes not implemented are
641 // Subs, UpperGamma, LowerGamma, Dirichlet_eta, Zeta
642 // LeviCivita, KroneckerDelta, FunctionSymbol, LambertW
643 // Derivative, ATan2, Gamma
644
645 using LambdaDoubleVisitor::bvisit;
646
647 void bvisit(const Complex &x)
648 {
649 double t1 = mp_get_d(x.real_), t2 = mp_get_d(x.imaginary_);
650 result_ = [=](const std::complex<double> *x) {
651 return std::complex<double>(t1, t2);
652 };
653 };
654
655 void bvisit(const ComplexDouble &x)
656 {
657 std::complex<double> tmp = x.i;
658 result_ = [=](const std::complex<double> *x) { return tmp; };
659 };
660#ifdef HAVE_SYMENGINE_MPC
661 void bvisit(const ComplexMPC &x)
662 {
663 mpfr_class t(x.get_prec());
664 double real, imag;
665 mpc_real(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
666 real = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
667 mpc_imag(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
668 imag = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
669 std::complex<double> tmp(real, imag);
670 result_ = [=](const std::complex<double> *x) { return tmp; };
671 }
672#endif
673};
674} // namespace SymEngine
675#endif // SYMENGINE_LAMBDA_DOUBLE_H
T acos(T... args)
T acosh(T... args)
T asin(T... args)
T asinh(T... args)
T atan2(T... args)
T atan(T... args)
T atanh(T... args)
T back(T... args)
T ceil(T... args)
RCP< const Basic > get_num() const
Definition: functions.h:511
RCP< const Basic > get_den() const
Definition: functions.h:516
The base class for representing addition in symbolic expressions.
Definition: add.h:27
const RCP< const Number > & get_coef() const
Definition: add.h:142
vec_basic get_args() const override
Returns the list of arguments.
Definition: logic.cpp:234
The lowest unit of symbolic representation.
Definition: basic.h:97
Complex Double Class to hold std::complex<double> values.
Complex Class.
Definition: complex.h:33
rational_class real_
Definition: complex.h:38
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
Integer Class.
Definition: integer.h:19
const integer_class & as_integer_class() const
Convert to integer_class.
Definition: integer.h:45
vec_basic get_args() const override
Returns the list of arguments.
Definition: functions.h:159
RCP< const Basic > get_arg() const
Definition: functions.h:36
vec_basic get_args() const override
Returns the list of arguments.
Definition: functions.h:40
vec_basic get_args() const override
Returns the list of arguments.
Definition: logic.cpp:302
RCP< const Basic > get_base() const
Definition: pow.h:37
RCP< const Basic > get_exp() const
Definition: pow.h:42
Rational Class.
Definition: rational.h:16
const rational_class & as_rational_class() const
Convert to rational_class.
Definition: rational.h:50
RealDouble Class to hold double values.
Definition: real_double.h:20
RCP< const Basic > get_arg2() const
Definition: functions.h:96
RCP< const Basic > get_arg1() const
Definition: functions.h:91
vec_basic get_args() const override
Returns the list of arguments.
Definition: logic.cpp:412
T clear(T... args)
T cos(T... args)
T cosh(T... args)
T data(T... args)
T end(T... args)
T erf(T... args)
T erfc(T... args)
T exp(T... args)
T find(T... args)
T floor(T... args)
T infinity(T... args)
T isnan(T... args)
T lgamma(T... args)
T log(T... args)
T max(T... args)
T min(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
T pow(T... args)
T push_back(T... args)
T resize(T... args)
T signaling_NaN(T... args)
T sin(T... args)
T sinh(T... args)
T size(T... args)
Our less operator (<):
Definition: basic.h:228
T tan(T... args)
T tanh(T... args)
T tgamma(T... args)
T trunc(T... args)