1 #ifndef SYMENGINE_LAMBDA_DOUBLE_H
2 #define SYMENGINE_LAMBDA_DOUBLE_H
7 #include <symengine/symengine_exception.h>
// NOTE(review): this text is a lossy extract of the header. Each line carries
// its original line number, and many lines (including the class head and all
// brace-only lines) are missing. Comments describe only what is visible.
//
// fn: a compiled expression. It receives a pointer to the input values and
// returns a value of type T.
25 typedef std::function<T(
const T *x)> fn;
// One compiled function per output expression; filled by init().
26 std::vector<fn> results;
// Scratch storage for common-subexpression (CSE) values. call() fills it
// before evaluating `results`, and compiled closures read it by pointer.
27 std::vector<T> cse_intermediate_results;
// Maps a CSE replacement symbol to its index in the CSE vectors (its type is
// declared on lines not visible here). Cleared at the end of init().
30 cse_intermediate_fns_map;
// Compiled functions computing each CSE intermediate, in replacement order.
31 std::vector<fn> cse_intermediate_fns;
// Convenience overload: compile the single expression `b`, written in terms
// of the input symbols `x`, by forwarding to the vector overload of init().
// NOTE(review): the line defining `outputs` (original lines 45-46) is missing
// from this extract. Presumably it wraps `b` in a one-element vector; confirm.
44 void init(
const vec_basic &x,
const Basic &b,
bool cse =
false)
47 init(x, outputs, cse);
// Compile `outputs` (expressions in the symbols `inputs`) into `results`.
// Without CSE, each output is compiled directly. With CSE, SymEngine::cse()
// factors out common subexpressions, each replacement is compiled into
// cse_intermediate_fns, and the reduced outputs are compiled afterwards.
// NOTE(review): several original lines (51-53, 55-56, 59-60, 63, 68-70, 73,
// 75-76, 79-80) are missing from this extract. They include the `cse`
// parameter and the branch that chooses between the two paths.
50 void init(
const vec_basic &inputs,
const vec_basic &outputs,
54 cse_intermediate_fns.clear();
// Presumably the non-CSE path (its guard is not visible): one compiled
// function per output.
57 for (
auto &p : outputs) {
58 results.push_back(apply(*p));
61 vec_basic reduced_exprs;
62 vec_pair replacements;
64 SymEngine::cse(replacements, reduced_exprs, outputs);
// Size the scratch vector once, before compiling anything.
// bvisit(const Symbol &) captures raw pointers into it, so it must not be
// resized afterwards.
65 cse_intermediate_results.resize(replacements.size());
66 for (
auto &rep : replacements) {
// Compile the replacement's right-hand side first, then register its symbol.
// Later replacements and the reduced outputs that mention that symbol then
// resolve to this slot (see the map lookup in bvisit(const Symbol &)).
67 auto res = apply(*(rep.second));
71 cse_intermediate_fns_map[rep.first]
72 = cse_intermediate_fns.size();
74 cse_intermediate_fns.push_back(res);
// Assumes cse() yields exactly one reduced expression per output.
77 for (
unsigned i = 0; i < outputs.size(); i++) {
78 results.push_back(apply(*reduced_exprs[i]));
// The symbol-to-index map is only needed while compiling.
81 cse_intermediate_fns_map.clear();
// Compile a single expression into an `fn` closure.
// NOTE(review): the body is not visible in this extract. Presumably it
// dispatches through this visitor and returns the resulting `result_`;
// confirm.
86 fn apply(
const Basic &b)
// Evaluate at the input values in `vec` (one value per input symbol, in the
// order of the symbols passed to init()) and return the result.
// NOTE(review): the declaration of `res` is not visible. If it is a single T,
// this overload is only safe when exactly one output was compiled, because
// call(T *, const T *) writes results.size() values starting at &res.
92 T call(
const std::vector<T> &vec)
95 call(&res, vec.data());
// Evaluate every compiled output.
// `inps` must point to one value per input symbol, because compiled Symbol
// closures index it by position. `outs` must have room for results.size()
// values. CSE intermediates are computed first, in order, and stored in
// cse_intermediate_results, where the output closures read them.
// NOTE(review): this writes member state (cse_intermediate_results), so
// concurrent call()s on the same instance race when CSE is in use.
99 void call(T *outs,
const T *inps)
// NOTE(review): this guard is redundant with the loop bound below.
101 if (cse_intermediate_fns.size() > 0) {
102 for (
unsigned i = 0; i < cse_intermediate_fns.size(); ++i) {
103 cse_intermediate_results[i] = cse_intermediate_fns[i](inps);
106 for (
unsigned i = 0; i < results.size(); ++i) {
107 outs[i] = results[i](inps);
// Compile a symbol reference:
// - an input symbol compiles to a read of its position in the input array;
// - a CSE replacement symbol compiles to a read of its slot in
//   cse_intermediate_results;
// - any other symbol throws SymEngineException.
// NOTE(review): original lines 113 and 117-120 (including the map lookup that
// defines `it`) are missing from this extract.
112 void bvisit(
const Symbol &x)
114 for (
unsigned i = 0; i < symbols.size(); ++i) {
115 if (
eq(x, *symbols[i])) {
// `i` is captured by value: the closure returns the i-th input.
116 result_ = [=](
const T *x) {
return x[i]; };
121 if (it != cse_intermediate_fns_map.end()) {
122 auto index = it->second;
// NOTE(review): the closure captures a raw pointer into this object's
// cse_intermediate_results. It stays valid only while that vector is not
// resized and the visitor is not copied: a copied visitor's `results` would
// still read the original object's storage.
123 T *cse_intermediate_result = &(cse_intermediate_results[index]);
124 result_ = [=](
const T *x) {
return *cse_intermediate_result; };
127 throw SymEngineException(
"Symbol not in the symbols vector.");
// Constant-valued visitors. Their signatures and the lines computing `tmp`
// (around original lines 128-132, 134-138, 140-144, 146-150) are not visible
// in this extract. Each captures a precomputed `tmp`, and the returned
// closure ignores its input.
133 result_ = [=](
const T *x_) {
return tmp; };
139 result_ = [=](
const T *x) {
return tmp; };
145 result_ = [=](
const T *x) {
return tmp; };
// Under HAVE_SYMENGINE_MPFR: x.i is converted once, at compile time, with
// mpfr_get_d using round-to-nearest (MPFR_RNDN).
148 #ifdef HAVE_SYMENGINE_MPFR
151 T tmp = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
152 result_ = [=](
const T *x) {
return tmp; };
// Compile a sum. Each dictionary entry (p.first, p.second) contributes
// apply(p.first)(x) * apply(p.second)(x), accumulated onto `tmp`.
// NOTE(review): the initialisation of `tmp` (presumably the constant
// coefficient) and of `tmp1`/`tmp2`, and the final assignment to result_,
// are on lines 157-159 and 164-167, which are missing from this extract.
156 void bvisit(
const Add &x)
160 for (
const auto &p : x.get_dict()) {
161 tmp1 = apply(*(p.first));
162 tmp2 = apply(*(p.second));
// `[=]` copies the current `tmp` into the new closure before `tmp` is
// reassigned, so each term wraps the previous partial sum. Evaluating an
// n-term sum therefore goes through n nested std::function calls.
163 tmp = [=](
const T *x) {
return tmp(x) + tmp1(x) * tmp2(x); };
// Compile a product: the coefficient times pow(p.first, p.second) for every
// dictionary entry, built as a chain of nested closures (each new closure
// copies the previous `tmp`).
// NOTE(review): the declaration of tmp1/tmp2 and the final assignment to
// result_ are on lines missing from this extract.
// NOTE(review): std::pow is used even for small integer exponents; a
// specialised square/cube path could be faster if this is hot.
168 void bvisit(
const Mul &x)
170 fn tmp = apply(*x.get_coef());
172 for (
const auto &p : x.get_dict()) {
173 tmp1 = apply(*(p.first));
174 tmp2 = apply(*(p.second));
175 tmp = [=](
const T *x) {
176 return tmp(x) * std::pow(tmp1(x), tmp2(x));
// Compile a power. The exponent is compiled once. One branch evaluates
// exp(exponent); the other evaluates pow(base, exponent).
// NOTE(review): the condition selecting the exp() branch (presumably "base is
// E") and the line compiling `base_` are missing from this extract.
182 void bvisit(
const Pow &x)
184 fn exp_ = apply(*(x.
get_exp()));
186 result_ = [=](
const T *x) {
return std::exp(exp_(x)); };
189 result_ = [=](
const T *x) {
return std::pow(base_(x), exp_(x)); };
193 void bvisit(
const Sin &x)
195 fn tmp = apply(*(x.
get_arg()));
196 result_ = [=](
const T *x) {
return std::sin(tmp(x)); };
199 void bvisit(
const Cos &x)
201 fn tmp = apply(*(x.
get_arg()));
202 result_ = [=](
const T *x) {
return std::cos(tmp(x)); };
205 void bvisit(
const Tan &x)
207 fn tmp = apply(*(x.
get_arg()));
208 result_ = [=](
const T *x) {
return std::tan(tmp(x)); };
211 void bvisit(
const Log &x)
213 fn tmp = apply(*(x.
get_arg()));
214 result_ = [=](
const T *x) {
return std::log(tmp(x)); };
217 void bvisit(
const Cot &x)
219 fn tmp = apply(*(x.
get_arg()));
220 result_ = [=](
const T *x) {
return 1.0 / std::tan(tmp(x)); };
223 void bvisit(
const Csc &x)
225 fn tmp = apply(*(x.
get_arg()));
226 result_ = [=](
const T *x) {
return 1.0 / std::sin(tmp(x)); };
229 void bvisit(
const Sec &x)
231 fn tmp = apply(*(x.
get_arg()));
232 result_ = [=](
const T *x) {
return 1.0 / std::cos(tmp(x)); };
235 void bvisit(
const ASin &x)
237 fn tmp = apply(*(x.
get_arg()));
238 result_ = [=](
const T *x) {
return std::asin(tmp(x)); };
241 void bvisit(
const ACos &x)
243 fn tmp = apply(*(x.
get_arg()));
244 result_ = [=](
const T *x) {
return std::acos(tmp(x)); };
247 void bvisit(
const ASec &x)
249 fn tmp = apply(*(x.
get_arg()));
250 result_ = [=](
const T *x) {
return std::acos(1.0 / tmp(x)); };
253 void bvisit(
const ACsc &x)
255 fn tmp = apply(*(x.
get_arg()));
256 result_ = [=](
const T *x) {
return std::asin(1.0 / tmp(x)); };
259 void bvisit(
const ATan &x)
261 fn tmp = apply(*(x.
get_arg()));
262 result_ = [=](
const T *x) {
return std::atan(tmp(x)); };
265 void bvisit(
const ACot &x)
267 fn tmp = apply(*(x.
get_arg()));
268 result_ = [=](
const T *x) {
return std::atan(1.0 / tmp(x)); };
271 void bvisit(
const Sinh &x)
273 fn tmp = apply(*(x.
get_arg()));
274 result_ = [=](
const T *x) {
return std::sinh(tmp(x)); };
277 void bvisit(
const Csch &x)
279 fn tmp = apply(*(x.
get_arg()));
280 result_ = [=](
const T *x) {
return 1.0 / std::sinh(tmp(x)); };
283 void bvisit(
const Cosh &x)
285 fn tmp = apply(*(x.
get_arg()));
286 result_ = [=](
const T *x) {
return std::cosh(tmp(x)); };
289 void bvisit(
const Sech &x)
291 fn tmp = apply(*(x.
get_arg()));
292 result_ = [=](
const T *x) {
return 1.0 / std::cosh(tmp(x)); };
295 void bvisit(
const Tanh &x)
297 fn tmp = apply(*(x.
get_arg()));
298 result_ = [=](
const T *x) {
return std::tanh(tmp(x)); };
301 void bvisit(
const Coth &x)
303 fn tmp = apply(*(x.
get_arg()));
304 result_ = [=](
const T *x) {
return 1.0 / std::tanh(tmp(x)); };
307 void bvisit(
const ASinh &x)
309 fn tmp = apply(*(x.
get_arg()));
310 result_ = [=](
const T *x) {
return std::asinh(tmp(x)); };
313 void bvisit(
const ACsch &x)
315 fn tmp = apply(*(x.
get_arg()));
316 result_ = [=](
const T *x) {
return std::asinh(1.0 / tmp(x)); };
319 void bvisit(
const ACosh &x)
321 fn tmp = apply(*(x.
get_arg()));
322 result_ = [=](
const T *x) {
return std::acosh(tmp(x)); };
325 void bvisit(
const ATanh &x)
327 fn tmp = apply(*(x.
get_arg()));
328 result_ = [=](
const T *x) {
return std::atanh(tmp(x)); };
331 void bvisit(
const ACoth &x)
333 fn tmp = apply(*(x.
get_arg()));
334 result_ = [=](
const T *x) {
return std::atanh(1.0 / tmp(x)); };
337 void bvisit(
const ASech &x)
339 fn tmp = apply(*(x.
get_arg()));
340 result_ = [=](
const T *x) {
return std::acosh(1.0 / tmp(x)); };
// Constant visitor. Its signature (original lines 341-344) is not visible;
// presumably it handles a symbolic constant. The value is computed once,
// at compile time, with eval_double(), and the closure ignores its input.
345 T tmp = eval_double(x);
346 result_ = [=](
const T *x) {
return tmp; };
349 void bvisit(
const Abs &x)
351 fn tmp = apply(*(x.
get_arg()));
352 result_ = [=](
const T *x) {
return std::abs(tmp(x)); };
// Catch-all overload for expression types without a dedicated bvisit
// (presumably reached through the visitor's dispatch): compiling such an
// expression throws NotImplementedError.
355 void bvisit(
const Basic &)
357 throw NotImplementedError(
"Not Implemented");
// Double-only visitor: builds on LambdaDoubleVisitor<double> through
// BaseVisitor and adds overloads that rely on double-only operations
// (std::atan2, std::tgamma, std::erf, comparisons, logic, ...).
// The using-declaration keeps the base-class bvisit overloads visible so the
// new ones do not hide them.
367 :
public BaseVisitor<LambdaRealDoubleVisitor, LambdaDoubleVisitor<double>>
375 using LambdaDoubleVisitor::bvisit;
377 void bvisit(
const ATan2 &x)
379 fn num = apply(*(x.
get_num()));
380 fn den = apply(*(x.
get_den()));
381 result_ = [=](
const double *x) {
return std::atan2(num(x), den(x)); };
// Gamma, log-gamma, erf and erfc map directly onto <cmath>.
// NOTE(review): the lines defining `tmp` (presumably the compiled argument)
// and the signature of the log-gamma visitor are missing from this extract.
384 void bvisit(
const Gamma &x)
387 result_ = [=](
const double *x) {
return std::tgamma(tmp(x)); };
// NOTE(review): POSIX lgamma stores the sign of Gamma in the global `signgam`,
// so it may not be thread-safe on some platforms.
393 result_ = [=](
const double *x) {
return std::lgamma(tmp(x)); };
396 void bvisit(
const Erf &x)
399 result_ = [=](
const double *x) {
return std::erf(tmp(x)); };
402 void bvisit(
const Erfc &x)
405 result_ = [=](
const double *x) {
return std::erfc(tmp(x)); };
// Relational visitors. Their signatures and the lines compiling lhs_/rhs_
// are not visible in this extract. Each closure returns a bool, which the
// double-valued `fn` converts to 1.0 (holds) or 0.0 (does not hold).
// A NaN operand makes ==, <= and < false, and != true.
412 result_ = [=](
const double *x) {
return (lhs_(x) == rhs_(x)); };
419 result_ = [=](
const double *x) {
return (lhs_(x) != rhs_(x)); };
426 result_ = [=](
const double *x) {
return (lhs_(x) <= rhs_(x)); };
433 result_ = [=](
const double *x) {
return (lhs_(x) < rhs_(x)); };
436 void bvisit(
const And &x)
438 std::vector<fn> applys;
439 for (
const auto &p : x.
get_args()) {
440 applys.push_back(apply(*p));
443 result_ = [=](
const double *x) {
444 bool result = bool(applys[0](x));
445 for (
unsigned int i = 0; i < applys.size(); i++) {
446 result = result && bool(applys[i](x));
448 return double(result);
452 void bvisit(
const Or &x)
454 std::vector<fn> applys;
455 for (
const auto &p : x.
get_args()) {
456 applys.push_back(apply(*p));
459 result_ = [=](
const double *x) {
460 bool result = bool(applys[0](x));
461 for (
unsigned int i = 0; i < applys.size(); i++) {
462 result = result || bool(applys[i](x));
464 return double(result);
468 void bvisit(
const Xor &x)
470 std::vector<fn> applys;
471 for (
const auto &p : x.
get_args()) {
472 applys.push_back(apply(*p));
475 result_ = [=](
const double *x) {
476 bool result = bool(applys[0](x));
477 for (
unsigned int i = 0; i < applys.size(); i++) {
478 result = result != bool(applys[i](x));
480 return double(result);
484 void bvisit(
const Not &x)
486 fn tmp = apply(*(x.get_arg()));
487 result_ = [=](
const double *x) {
return double(not
bool(tmp(x))); };
490 void bvisit(
const Max &x)
492 std::vector<fn> applys;
493 for (
const auto &p : x.
get_args()) {
494 applys.push_back(apply(*p));
497 result_ = [=](
const double *x) {
498 double result = applys[0](x);
499 for (
unsigned int i = 0; i < applys.size(); i++) {
500 result = std::max(result, applys[i](x));
506 void bvisit(
const Min &x)
508 std::vector<fn> applys;
509 for (
const auto &p : x.
get_args()) {
510 applys.push_back(apply(*p));
513 result_ = [=](
const double *x) {
514 double result = applys[0](x);
515 for (
unsigned int i = 0; i < applys.size(); i++) {
516 result = std::min(result, applys[i](x));
522 void bvisit(
const Sign &x)
524 fn tmp = apply(*(x.
get_arg()));
525 result_ = [=](
const double *x) {
526 return tmp(x) == 0.0 ? 0.0 : (tmp(x) < 0.0 ? -1.0 : 1.0);
530 void bvisit(
const Floor &x)
532 fn tmp = apply(*(x.
get_arg()));
533 result_ = [=](
const double *x) {
return std::floor(tmp(x)); };
// Rounding visitors. Their signatures are not visible in this extract
// (presumably Ceiling and Truncate). Each compiles the argument once and
// applies std::ceil or std::trunc to its value.
538 fn tmp = apply(*(x.
get_arg()));
539 result_ = [=](
const double *x) {
return std::ceil(tmp(x)); };
544 fn tmp = apply(*(x.
get_arg()));
545 result_ = [=](
const double *x) {
return std::trunc(tmp(x)); };
548 void bvisit(
const Infty &x)
550 if (x.is_negative_infinity()) {
551 result_ = [=](
const double * ) {
552 return -std::numeric_limits<double>::infinity();
554 }
else if (x.is_positive_infinity()) {
555 result_ = [=](
const double * ) {
556 return std::numeric_limits<double>::infinity();
559 throw SymEngineException(
560 "LambdaDouble can only represent real valued infinity");
// Compile NaN into a constant closure. The assert checks that `nan` is the
// same object as the global `Nan`.
// NOTE(review): this returns a *signaling* NaN. Arithmetic on it raises
// FE_INVALID (and traps if floating-point exceptions are enabled);
// std::numeric_limits<double>::quiet_NaN() is the conventional choice.
// Confirm whether signaling was intended.
563 void bvisit(
const NaN &nan)
565 assert(&nan == &(*Nan) );
566 result_ = [](
const double * ) {
567 return std::numeric_limits<double>::signaling_NaN();
// Compile Contains(expr, set) for an Interval set. The closure returns 1.0
// when expr's value lies between start and end (respecting open/closed
// ends), else 0.0. A NaN value of expr always yields 0.0. Any non-Interval
// set throws SymEngineException.
// NOTE(review): the signature (around original lines 570-571) is not visible.
572 const auto fn_expr = apply(*cts.get_expr());
573 const auto set = cts.get_set();
574 if (is_a<Interval>(*set)) {
575 const auto &interv = down_cast<const Interval &>(*set);
576 const auto fn_start = apply(*interv.get_start());
577 const auto fn_end = apply(*interv.get_end());
578 const bool left_open = interv.get_left_open();
579 const bool right_open = interv.get_right_open();
580 result_ = [=](
const double *x) {
581 const auto val_expr = fn_expr(x);
582 const auto val_start = fn_start(x);
583 const auto val_end = fn_end(x);
584 bool left_ok, right_ok;
// An infinite bound accepts any non-NaN value, regardless of left_open or
// right_open.
// NOTE(review): so a value of -inf counts as inside (-oo, b) even though that
// end is open. Confirm this is intended.
585 if (val_start == -std::numeric_limits<double>::infinity()) {
586 left_ok = !std::isnan(val_expr);
588 left_ok = (left_open) ? (val_start < val_expr)
589 : (val_start <= val_expr);
591 if (val_end == std::numeric_limits<double>::infinity()) {
592 right_ok = !std::isnan(val_expr);
594 right_ok = (right_open) ? (val_expr < val_end)
595 : (val_expr <= val_end);
597 return (left_ok && right_ok) ? 1.0 : 0.0;
600 throw SymEngineException(
"LambdaDoubleVisitor: only ``Interval`` "
601 "implemented for ``Contains``.");
// Boolean constant visitor (presumably BooleanAtom; signature not visible):
// compiles to a closure returning 1.0 for true and 0.0 for false.
607 const bool val = ba.get_val();
608 result_ = [=](
const double * ) {
return (val) ? 1.0 : 0.0; };
// Compile a Piecewise: at evaluation time, the first expression whose
// condition evaluates to exactly 1.0 (true) is selected.
// NOTE(review): the signature and original lines 626-628 (presumably
// `return applys[i](x);`) are missing from this extract.
613 SYMENGINE_ASSERT_MSG(
614 eq(*pw.get_vec().back().second, *boolTrue),
615 "LambdaDouble requires a (Expr, True) at the end of Piecewise");
617 std::vector<fn> applys;
618 std::vector<fn> preds;
619 for (
const auto &expr_pred : pw.get_vec()) {
620 applys.push_back(apply(*expr_pred.first));
621 preds.push_back(apply(*expr_pred.second));
623 result_ = [=](
const double *x) {
// NOTE(review): this loop has no bound, so the throw below is unreachable.
// If no condition evaluates to 1.0 (e.g. when SYMENGINE_ASSERT_MSG is
// disabled and the last condition is not True), preds[i] is read past the
// end, which is undefined behaviour. Suggested fix: loop while
// `i < preds.size()` so the throw is actually reached.
624 for (
size_t i = 0;; ++i) {
625 if (preds[i](x) == 1.0) {
629 throw SymEngineException(
630 "Unexpectedly reached end of Piecewise function.");
// Complex-valued visitor. Its base is
// LambdaDoubleVisitor<std::complex<double>> (the rest of the class head is
// not visible). The using-declaration keeps the base-class bvisit overloads
// visible.
637 LambdaDoubleVisitor<std::complex<double>>>
645 using LambdaDoubleVisitor::bvisit;
// Constant visitors of the complex visitor (signatures not visible). Each
// converts its value to std::complex<double> once, at compile time, and the
// closure ignores its input.
// Here the real_ and imaginary_ parts are converted separately with mp_get_d.
649 double t1 = mp_get_d(x.
real_), t2 = mp_get_d(x.imaginary_);
650 result_ = [=](
const std::complex<double> *x) {
651 return std::complex<double>(t1, t2);
657 std::complex<double> tmp = x.i;
658 result_ = [=](
const std::complex<double> *x) {
return tmp; };
// Under HAVE_SYMENGINE_MPC: the real and imaginary parts are extracted into
// an MPFR temporary of the value's own precision, then each is rounded to
// the nearest double (MPFR_RNDN).
// NOTE(review): the declarations of `real`/`imag` are not visible.
660 #ifdef HAVE_SYMENGINE_MPC
663 mpfr_class t(x.get_prec());
665 mpc_real(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
666 real = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
667 mpc_imag(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
668 imag = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
669 std::complex<double> tmp(real, imag);
670 result_ = [=](
const std::complex<double> *x) {
return tmp; };
RCP< const Basic > get_den() const
RCP< const Basic > get_num() const
The base class for representing addition in symbolic expressions.
const RCP< const Number > & get_coef() const
vec_basic get_args() const override
Returns the list of arguments.
The lowest unit of symbolic representation.
Complex Double Class to hold std::complex<double> values.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
const integer_class & as_integer_class() const
Convert to integer_class.
vec_basic get_args() const override
Returns the list of arguments.
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_arg() const
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
RCP< const Basic > get_base() const
const rational_class & as_rational_class() const
Convert to rational_class.
RealDouble Class to hold double values.
RCP< const Basic > get_arg1() const
RCP< const Basic > get_arg2() const
vec_basic get_args() const override
Returns the list of arguments.
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b