1 #ifndef SYMENGINE_LAMBDA_DOUBLE_H
2 #define SYMENGINE_LAMBDA_DOUBLE_H
7 #include <symengine/symengine_exception.h>
30 cse_intermediate_fns_map;
47 init(x, outputs, cse);
54 cse_intermediate_fns.
clear();
57 for (
auto &p : outputs) {
64 SymEngine::cse(replacements, reduced_exprs, outputs);
65 cse_intermediate_results.
resize(replacements.
size());
66 for (
auto &rep : replacements) {
67 auto res = apply(*(rep.second));
71 cse_intermediate_fns_map[rep.first]
72 = cse_intermediate_fns.
size();
77 for (
unsigned i = 0; i < outputs.
size(); i++) {
78 results.
push_back(apply(*reduced_exprs[i]));
81 cse_intermediate_fns_map.
clear();
95 call(&res, vec.
data());
99 void call(T *outs,
const T *inps)
101 if (cse_intermediate_fns.
size() > 0) {
102 for (
unsigned i = 0; i < cse_intermediate_fns.
size(); ++i) {
103 cse_intermediate_results[i] = cse_intermediate_fns[i](inps);
106 for (
unsigned i = 0; i < results.
size(); ++i) {
107 outs[i] = results[i](inps);
112 void bvisit(
const Symbol &x)
114 for (
unsigned i = 0; i < symbols.
size(); ++i) {
115 if (
eq(x, *symbols[i])) {
116 result_ = [=](
const T *x) {
return x[i]; };
121 if (it != cse_intermediate_fns_map.
end()) {
122 auto index = it->second;
123 T *cse_intermediate_result = &(cse_intermediate_results[index]);
124 result_ = [=](
const T *x) {
return *cse_intermediate_result; };
127 throw SymEngineException(
"Symbol not in the symbols vector.");
133 result_ = [=](
const T *x_) {
return tmp; };
139 result_ = [=](
const T *x) {
return tmp; };
145 result_ = [=](
const T *x) {
return tmp; };
148 #ifdef HAVE_SYMENGINE_MPFR
151 T tmp = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
152 result_ = [=](
const T *x) {
return tmp; };
156 void bvisit(
const Add &x)
160 for (
const auto &p : x.get_dict()) {
161 tmp1 = apply(*(p.first));
162 tmp2 = apply(*(p.second));
163 tmp = [=](
const T *x) {
return tmp(x) + tmp1(x) * tmp2(x); };
168 void bvisit(
const Mul &x)
170 fn tmp = apply(*x.get_coef());
172 for (
const auto &p : x.get_dict()) {
173 tmp1 = apply(*(p.first));
174 tmp2 = apply(*(p.second));
175 tmp = [=](
const T *x) {
176 return tmp(x) *
std::pow(tmp1(x), tmp2(x));
182 void bvisit(
const Pow &x)
186 result_ = [=](
const T *x) {
return std::exp(exp_(x)); };
189 result_ = [=](
const T *x) {
return std::pow(base_(x), exp_(x)); };
193 void bvisit(
const Sin &x)
196 result_ = [=](
const T *x) {
return std::sin(tmp(x)); };
199 void bvisit(
const Cos &x)
202 result_ = [=](
const T *x) {
return std::cos(tmp(x)); };
205 void bvisit(
const Tan &x)
208 result_ = [=](
const T *x) {
return std::tan(tmp(x)); };
211 void bvisit(
const Log &x)
214 result_ = [=](
const T *x) {
return std::log(tmp(x)); };
217 void bvisit(
const Cot &x)
220 result_ = [=](
const T *x) {
return 1.0 /
std::tan(tmp(x)); };
223 void bvisit(
const Csc &x)
226 result_ = [=](
const T *x) {
return 1.0 /
std::sin(tmp(x)); };
229 void bvisit(
const Sec &x)
232 result_ = [=](
const T *x) {
return 1.0 /
std::cos(tmp(x)); };
235 void bvisit(
const ASin &x)
238 result_ = [=](
const T *x) {
return std::asin(tmp(x)); };
241 void bvisit(
const ACos &x)
244 result_ = [=](
const T *x) {
return std::acos(tmp(x)); };
247 void bvisit(
const ASec &x)
250 result_ = [=](
const T *x) {
return std::acos(1.0 / tmp(x)); };
253 void bvisit(
const ACsc &x)
256 result_ = [=](
const T *x) {
return std::asin(1.0 / tmp(x)); };
259 void bvisit(
const ATan &x)
262 result_ = [=](
const T *x) {
return std::atan(tmp(x)); };
265 void bvisit(
const ACot &x)
268 result_ = [=](
const T *x) {
return std::atan(1.0 / tmp(x)); };
271 void bvisit(
const Sinh &x)
274 result_ = [=](
const T *x) {
return std::sinh(tmp(x)); };
277 void bvisit(
const Csch &x)
280 result_ = [=](
const T *x) {
return 1.0 /
std::sinh(tmp(x)); };
283 void bvisit(
const Cosh &x)
286 result_ = [=](
const T *x) {
return std::cosh(tmp(x)); };
289 void bvisit(
const Sech &x)
292 result_ = [=](
const T *x) {
return 1.0 /
std::cosh(tmp(x)); };
295 void bvisit(
const Tanh &x)
298 result_ = [=](
const T *x) {
return std::tanh(tmp(x)); };
301 void bvisit(
const Coth &x)
304 result_ = [=](
const T *x) {
return 1.0 /
std::tanh(tmp(x)); };
307 void bvisit(
const ASinh &x)
310 result_ = [=](
const T *x) {
return std::asinh(tmp(x)); };
313 void bvisit(
const ACsch &x)
316 result_ = [=](
const T *x) {
return std::asinh(1.0 / tmp(x)); };
319 void bvisit(
const ACosh &x)
322 result_ = [=](
const T *x) {
return std::acosh(tmp(x)); };
325 void bvisit(
const ATanh &x)
328 result_ = [=](
const T *x) {
return std::atanh(tmp(x)); };
331 void bvisit(
const ACoth &x)
334 result_ = [=](
const T *x) {
return std::atanh(1.0 / tmp(x)); };
337 void bvisit(
const ASech &x)
340 result_ = [=](
const T *x) {
return std::acosh(1.0 / tmp(x)); };
345 T tmp = eval_double(x);
346 result_ = [=](
const T *x) {
return tmp; };
349 void bvisit(
const Abs &x)
352 result_ = [=](
const T *x) {
return std::abs(tmp(x)); };
355 void bvisit(
const Basic &)
357 throw NotImplementedError(
"Not Implemented");
367 :
public BaseVisitor<LambdaRealDoubleVisitor, LambdaDoubleVisitor<double>>
375 using LambdaDoubleVisitor::bvisit;
377 void bvisit(
const ATan2 &x)
381 result_ = [=](
const double *x) {
return std::atan2(num(x), den(x)); };
384 void bvisit(
const Gamma &x)
387 result_ = [=](
const double *x) {
return std::tgamma(tmp(x)); };
393 result_ = [=](
const double *x) {
return std::lgamma(tmp(x)); };
396 void bvisit(
const Erf &x)
399 result_ = [=](
const double *x) {
return std::erf(tmp(x)); };
402 void bvisit(
const Erfc &x)
405 result_ = [=](
const double *x) {
return std::erfc(tmp(x)); };
412 result_ = [=](
const double *x) {
return (lhs_(x) == rhs_(x)); };
419 result_ = [=](
const double *x) {
return (lhs_(x) != rhs_(x)); };
426 result_ = [=](
const double *x) {
return (lhs_(x) <= rhs_(x)); };
433 result_ = [=](
const double *x) {
return (lhs_(x) < rhs_(x)); };
436 void bvisit(
const And &x)
439 for (
const auto &p : x.
get_args()) {
443 result_ = [=](
const double *x) {
444 bool result = bool(applys[0](x));
445 for (
unsigned int i = 0; i < applys.
size(); i++) {
446 result = result && bool(applys[i](x));
448 return double(result);
452 void bvisit(
const Or &x)
455 for (
const auto &p : x.
get_args()) {
459 result_ = [=](
const double *x) {
460 bool result = bool(applys[0](x));
461 for (
unsigned int i = 0; i < applys.
size(); i++) {
462 result = result || bool(applys[i](x));
464 return double(result);
468 void bvisit(
const Xor &x)
471 for (
const auto &p : x.
get_args()) {
475 result_ = [=](
const double *x) {
476 bool result = bool(applys[0](x));
477 for (
unsigned int i = 0; i < applys.
size(); i++) {
478 result = result != bool(applys[i](x));
480 return double(result);
484 void bvisit(
const Not &x)
486 fn tmp = apply(*(x.get_arg()));
487 result_ = [=](
const double *x) {
return double(not
bool(tmp(x))); };
490 void bvisit(
const Max &x)
493 for (
const auto &p : x.
get_args()) {
497 result_ = [=](
const double *x) {
498 double result = applys[0](x);
499 for (
unsigned int i = 0; i < applys.
size(); i++) {
500 result =
std::max(result, applys[i](x));
506 void bvisit(
const Min &x)
509 for (
const auto &p : x.
get_args()) {
513 result_ = [=](
const double *x) {
514 double result = applys[0](x);
515 for (
unsigned int i = 0; i < applys.
size(); i++) {
516 result =
std::min(result, applys[i](x));
522 void bvisit(
const Sign &x)
525 result_ = [=](
const double *x) {
526 return tmp(x) == 0.0 ? 0.0 : (tmp(x) < 0.0 ? -1.0 : 1.0);
530 void bvisit(
const Floor &x)
533 result_ = [=](
const double *x) {
return std::floor(tmp(x)); };
539 result_ = [=](
const double *x) {
return std::ceil(tmp(x)); };
545 result_ = [=](
const double *x) {
return std::trunc(tmp(x)); };
548 void bvisit(
const Infty &x)
550 if (x.is_negative_infinity()) {
551 result_ = [=](
const double * ) {
554 }
else if (x.is_positive_infinity()) {
555 result_ = [=](
const double * ) {
559 throw SymEngineException(
560 "LambdaDouble can only represent real valued infinity");
563 void bvisit(
const NaN &nan)
565 assert(&nan == &(*Nan) );
566 result_ = [](
const double * ) {
572 const auto fn_expr = apply(*cts.get_expr());
573 const auto set = cts.get_set();
574 if (is_a<Interval>(*set)) {
575 const auto &interv = down_cast<const Interval &>(*set);
576 const auto fn_start = apply(*interv.get_start());
577 const auto fn_end = apply(*interv.get_end());
578 const bool left_open = interv.get_left_open();
579 const bool right_open = interv.get_right_open();
580 result_ = [=](
const double *x) {
581 const auto val_expr = fn_expr(x);
582 const auto val_start = fn_start(x);
583 const auto val_end = fn_end(x);
584 bool left_ok, right_ok;
588 left_ok = (left_open) ? (val_start < val_expr)
589 : (val_start <= val_expr);
594 right_ok = (right_open) ? (val_expr < val_end)
595 : (val_expr <= val_end);
597 return (left_ok && right_ok) ? 1.0 : 0.0;
600 throw SymEngineException(
"LambdaDoubleVisitor: only ``Interval`` "
601 "implemented for ``Contains``.");
607 const bool val = ba.get_val();
608 result_ = [=](
const double * ) {
return (val) ? 1.0 : 0.0; };
613 SYMENGINE_ASSERT_MSG(
614 eq(*pw.get_vec().
back().second, *boolTrue),
615 "LambdaDouble requires a (Expr, True) at the end of Piecewise");
619 for (
const auto &expr_pred : pw.get_vec()) {
620 applys.
push_back(apply(*expr_pred.first));
621 preds.
push_back(apply(*expr_pred.second));
623 result_ = [=](
const double *x) {
624 for (
size_t i = 0;; ++i) {
625 if (preds[i](x) == 1.0) {
629 throw SymEngineException(
630 "Unexpectedly reached end of Piecewise function.");
637 LambdaDoubleVisitor<std::complex<double>>>
645 using LambdaDoubleVisitor::bvisit;
649 double t1 = mp_get_d(x.
real_), t2 = mp_get_d(x.imaginary_);
660 #ifdef HAVE_SYMENGINE_MPC
663 mpfr_class t(x.get_prec());
665 mpc_real(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
666 real = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
667 mpc_imag(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
668 imag = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
RCP< const Basic > get_den() const
RCP< const Basic > get_num() const
The base class for representing addition in symbolic expressions.
const RCP< const Number > & get_coef() const
vec_basic get_args() const override
Returns the list of arguments.
The lowest unit of symbolic representation.
Complex Double Class to hold std::complex<double> values.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
const integer_class & as_integer_class() const
Convert to integer_class.
vec_basic get_args() const override
Returns the list of arguments.
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_arg() const
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
RCP< const Basic > get_base() const
const rational_class & as_rational_class() const
Convert to rational_class.
RealDouble Class to hold double values.
RCP< const Basic > get_arg1() const
RCP< const Basic > get_arg2() const
vec_basic get_args() const override
Returns the list of arguments.
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
T signaling_NaN(T... args)