Program Listing for File lambda_double.h

Return to documentation for file (symengine/symengine/lambda_double.h)

#ifndef SYMENGINE_LAMBDA_DOUBLE_H
#define SYMENGINE_LAMBDA_DOUBLE_H

#include <cmath>
#include <limits>
#include <symengine/eval_double.h>
#include <symengine/symengine_exception.h>
#include <symengine/visitor.h>

namespace SymEngine
{

template <typename T>
class LambdaDoubleVisitor : public BaseVisitor<LambdaDoubleVisitor<T>>
{
protected:
    /*
       The 'result_' variable is assigned into at the very end of each visit()
       methods below. The only place where these methods are called from is the
       line 'b.accept(*this)' in apply() and the 'result_' is immediately
       returned. Thus no corruption can happen and apply() can be safely called
       recursively.
    */

    typedef std::function<T(const T *x)> fn;
    std::vector<fn> results;
    std::vector<T> cse_intermediate_results;

    std::map<RCP<const Basic>, size_t, RCPBasicKeyLess>
        cse_intermediate_fns_map;
    std::vector<fn> cse_intermediate_fns;
    fn result_;
    vec_basic symbols;

public:
    void init(const vec_basic &x, const Basic &b, bool cse = false)
    {
        vec_basic outputs = {b.rcp_from_this()};
        init(x, outputs, cse);
    }

    void init(const vec_basic &inputs, const vec_basic &outputs,
              bool cse = false)
    {
        results.clear();
        cse_intermediate_fns.clear();
        symbols = inputs;
        if (not cse) {
            for (auto &p : outputs) {
                results.push_back(apply(*p));
            }
        } else {
            vec_basic reduced_exprs;
            vec_pair replacements;
            // cse the outputs
            SymEngine::cse(replacements, reduced_exprs, outputs);
            for (auto &rep : replacements) {
                auto res = apply(*(rep.second));
                // Store the replacement symbol values in a dictionary for
                // faster
                // lookup for initialization
                cse_intermediate_fns_map[rep.first]
                    = cse_intermediate_fns.size();
                // Store it in a vector for faster use in call
                cse_intermediate_fns.push_back(res);
            }
            cse_intermediate_results.resize(cse_intermediate_fns.size());
            // Generate functions for all the reduced exprs and save it
            for (unsigned i = 0; i < outputs.size(); i++) {
                results.push_back(apply(*reduced_exprs[i]));
            }
            // We don't need the cse_intermediate_fns_map anymore
            cse_intermediate_fns_map.clear();
            symbols.clear();
        }
    }

    fn apply(const Basic &b)
    {
        b.accept(*this);
        return result_;
    }

    T call(const std::vector<T> &vec)
    {
        T res;
        call(&res, vec.data());
        return res;
    }

    void call(T *outs, const T *inps)
    {
        if (cse_intermediate_fns.size() > 0) {
            for (unsigned i = 0; i < cse_intermediate_fns.size(); ++i) {
                cse_intermediate_results[i] = cse_intermediate_fns[i](inps);
            }
        }
        for (unsigned i = 0; i < results.size(); ++i) {
            outs[i] = results[i](inps);
        }
        return;
    }

    void bvisit(const Symbol &x)
    {
        for (unsigned i = 0; i < symbols.size(); ++i) {
            if (eq(x, *symbols[i])) {
                result_ = [=](const T *x) { return x[i]; };
                return;
            }
        }
        auto it = cse_intermediate_fns_map.find(x.rcp_from_this());
        if (it != cse_intermediate_fns_map.end()) {
            auto index = it->second;
            result_
                = [=](const T *x) { return cse_intermediate_results[index]; };
            return;
        }
        throw SymEngineException("Symbol not in the symbols vector.");
    };

    void bvisit(const Integer &x)
    {
        T tmp = mp_get_d(x.as_integer_class());
        result_ = [=](const T *x_) { return tmp; };
    }

    void bvisit(const Rational &x)
    {
        T tmp = mp_get_d(x.as_rational_class());
        result_ = [=](const T *x) { return tmp; };
    }

    void bvisit(const RealDouble &x)
    {
        T tmp = x.i;
        result_ = [=](const T *x) { return tmp; };
    }

#ifdef HAVE_SYMENGINE_MPFR
    void bvisit(const RealMPFR &x)
    {
        T tmp = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
        result_ = [=](const T *x) { return tmp; };
    }
#endif

    void bvisit(const Add &x)
    {
        fn tmp = apply(*x.get_coef());
        fn tmp1, tmp2;
        for (const auto &p : x.get_dict()) {
            tmp1 = apply(*(p.first));
            tmp2 = apply(*(p.second));
            tmp = [=](const T *x) { return tmp(x) + tmp1(x) * tmp2(x); };
        }
        result_ = tmp;
    }

    void bvisit(const Mul &x)
    {
        fn tmp = apply(*x.get_coef());
        fn tmp1, tmp2;
        for (const auto &p : x.get_dict()) {
            tmp1 = apply(*(p.first));
            tmp2 = apply(*(p.second));
            tmp = [=](const T *x) {
                return tmp(x) * std::pow(tmp1(x), tmp2(x));
            };
        }
        result_ = tmp;
    }

    void bvisit(const Pow &x)
    {
        fn exp_ = apply(*(x.get_exp()));
        if (eq(*(x.get_base()), *E)) {
            result_ = [=](const T *x) { return std::exp(exp_(x)); };
        } else {
            fn base_ = apply(*(x.get_base()));
            result_ = [=](const T *x) { return std::pow(base_(x), exp_(x)); };
        }
    }

    void bvisit(const Sin &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::sin(tmp(x)); };
    }

    void bvisit(const Cos &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::cos(tmp(x)); };
    }

    void bvisit(const Tan &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::tan(tmp(x)); };
    }

    void bvisit(const Log &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::log(tmp(x)); };
    };

    void bvisit(const Cot &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::tan(tmp(x)); };
    };

    void bvisit(const Csc &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::sin(tmp(x)); };
    };

    void bvisit(const Sec &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::cos(tmp(x)); };
    };

    void bvisit(const ASin &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::asin(tmp(x)); };
    };

    void bvisit(const ACos &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::acos(tmp(x)); };
    };

    void bvisit(const ASec &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::acos(1.0 / tmp(x)); };
    };

    void bvisit(const ACsc &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::asin(1.0 / tmp(x)); };
    };

    void bvisit(const ATan &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::atan(tmp(x)); };
    };

    void bvisit(const ACot &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::atan(1.0 / tmp(x)); };
    };

    void bvisit(const Sinh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::sinh(tmp(x)); };
    };

    void bvisit(const Csch &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::sinh(tmp(x)); };
    };

    void bvisit(const Cosh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::cosh(tmp(x)); };
    };

    void bvisit(const Sech &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::cosh(tmp(x)); };
    };

    void bvisit(const Tanh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::tanh(tmp(x)); };
    };

    void bvisit(const Coth &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return 1.0 / std::tanh(tmp(x)); };
    };

    void bvisit(const ASinh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::asinh(tmp(x)); };
    };

    void bvisit(const ACsch &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::asinh(1.0 / tmp(x)); };
    };

    void bvisit(const ACosh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::acosh(tmp(x)); };
    };

    void bvisit(const ATanh &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::atanh(tmp(x)); };
    };

    void bvisit(const ACoth &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::atanh(1.0 / tmp(x)); };
    };

    void bvisit(const ASech &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::acosh(1.0 / tmp(x)); };
    };

    void bvisit(const Constant &x)
    {
        T tmp = eval_double(x);
        result_ = [=](const T *x) { return tmp; };
    };

    void bvisit(const Abs &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const T *x) { return std::abs(tmp(x)); };
    };

    void bvisit(const Basic &)
    {
        throw NotImplementedError("Not Implemented");
    };

    void bvisit(const UnevaluatedExpr &x)
    {
        apply(*x.get_arg());
    };
};

class LambdaRealDoubleVisitor
    : public BaseVisitor<LambdaRealDoubleVisitor, LambdaDoubleVisitor<double>>
{
public:
    // Classes not implemented are
    // Subs, UpperGamma, LowerGamma, Dirichlet_eta, Zeta
    // LeviCivita, KroneckerDelta, FunctionSymbol, LambertW
    // Derivative, Complex, ComplexDouble, ComplexMPC

    using LambdaDoubleVisitor::bvisit;

    void bvisit(const ATan2 &x)
    {
        fn num = apply(*(x.get_num()));
        fn den = apply(*(x.get_den()));
        result_ = [=](const double *x) { return std::atan2(num(x), den(x)); };
    };

    void bvisit(const Gamma &x)
    {
        fn tmp = apply(*(x.get_args()[0]));
        result_ = [=](const double *x) { return std::tgamma(tmp(x)); };
    };

    void bvisit(const LogGamma &x)
    {
        fn tmp = apply(*(x.get_args()[0]));
        result_ = [=](const double *x) { return std::lgamma(tmp(x)); };
    };

    void bvisit(const Erf &x)
    {
        fn tmp = apply(*(x.get_args()[0]));
        result_ = [=](const double *x) { return std::erf(tmp(x)); };
    }

    void bvisit(const Erfc &x)
    {
        fn tmp = apply(*(x.get_args()[0]));
        result_ = [=](const double *x) { return std::erfc(tmp(x)); };
    }

    void bvisit(const Equality &x)
    {
        fn lhs_ = apply(*(x.get_arg1()));
        fn rhs_ = apply(*(x.get_arg2()));
        result_ = [=](const double *x) { return (lhs_(x) == rhs_(x)); };
    }

    void bvisit(const Unequality &x)
    {
        fn lhs_ = apply(*(x.get_arg1()));
        fn rhs_ = apply(*(x.get_arg2()));
        result_ = [=](const double *x) { return (lhs_(x) != rhs_(x)); };
    }

    void bvisit(const LessThan &x)
    {
        fn lhs_ = apply(*(x.get_arg1()));
        fn rhs_ = apply(*(x.get_arg2()));
        result_ = [=](const double *x) { return (lhs_(x) <= rhs_(x)); };
    }

    void bvisit(const StrictLessThan &x)
    {
        fn lhs_ = apply(*(x.get_arg1()));
        fn rhs_ = apply(*(x.get_arg2()));
        result_ = [=](const double *x) { return (lhs_(x) < rhs_(x)); };
    }

    void bvisit(const And &x)
    {
        std::vector<fn> applys;
        for (const auto &p : x.get_args()) {
            applys.push_back(apply(*p));
        }

        result_ = [=](const double *x) {

            bool result = bool(applys[0](x));
            for (unsigned int i = 0; i < applys.size(); i++) {
                result = result && bool(applys[i](x));
            }
            return double(result);
        };
    }

    void bvisit(const Or &x)
    {
        std::vector<fn> applys;
        for (const auto &p : x.get_args()) {
            applys.push_back(apply(*p));
        }

        result_ = [=](const double *x) {

            bool result = bool(applys[0](x));
            for (unsigned int i = 0; i < applys.size(); i++) {
                result = result || bool(applys[i](x));
            }
            return double(result);
        };
    }

    void bvisit(const Xor &x)
    {
        std::vector<fn> applys;
        for (const auto &p : x.get_args()) {
            applys.push_back(apply(*p));
        }

        result_ = [=](const double *x) {

            bool result = bool(applys[0](x));
            for (unsigned int i = 0; i < applys.size(); i++) {
                result = result != bool(applys[i](x));
            }
            return double(result);
        };
    }

    void bvisit(const Not &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const double *x) { return double(not bool(tmp(x))); };
    }

    void bvisit(const Max &x)
    {
        std::vector<fn> applys;
        for (const auto &p : x.get_args()) {
            applys.push_back(apply(*p));
        }

        result_ = [=](const double *x) {

            double result = applys[0](x);
            for (unsigned int i = 0; i < applys.size(); i++) {
                result = std::max(result, applys[i](x));
            }
            return result;
        };
    };

    void bvisit(const Min &x)
    {
        std::vector<fn> applys;
        for (const auto &p : x.get_args()) {
            applys.push_back(apply(*p));
        }

        result_ = [=](const double *x) {

            double result = applys[0](x);
            for (unsigned int i = 0; i < applys.size(); i++) {
                result = std::min(result, applys[i](x));
            }
            return result;
        };
    };

    void bvisit(const Sign &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const double *x) {
            return tmp(x) == 0.0 ? 0.0 : (tmp(x) < 0.0 ? -1.0 : 1.0);
        };
    };

    void bvisit(const Floor &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const double *x) { return std::floor(tmp(x)); };
    };

    void bvisit(const Ceiling &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const double *x) { return std::ceil(tmp(x)); };
    };

    void bvisit(const Truncate &x)
    {
        fn tmp = apply(*(x.get_arg()));
        result_ = [=](const double *x) { return std::trunc(tmp(x)); };
    };

    void bvisit(const Infty &x)
    {
        if (x.is_negative_infinity()) {
            result_ = [=](const double * /* x */) {
                return -std::numeric_limits<double>::infinity();
            };
        } else if (x.is_positive_infinity()) {
            result_ = [=](const double * /* x */) {
                return std::numeric_limits<double>::infinity();
            };
        } else {
            throw SymEngineException(
                "LambdaDouble can only represent real valued infinity");
        }
    }

    void bvisit(const Contains &cts)
    {
        const auto fn_expr = apply(*cts.get_expr());
        const auto set = cts.get_set();
        if (is_a<Interval>(*set)) {
            const auto &interv = down_cast<const Interval &>(*set);
            const auto fn_start = apply(*interv.get_start());
            const auto fn_end = apply(*interv.get_end());
            const bool left_open = interv.get_left_open();
            const bool right_open = interv.get_right_open();
            result_ = [=](const double *x) {
                const auto val_expr = fn_expr(x);
                const auto val_start = fn_start(x);
                const auto val_end = fn_end(x);
                bool left_ok, right_ok;
                if (val_start == -std::numeric_limits<double>::infinity()) {
                    left_ok = !std::isnan(val_expr);
                } else {
                    left_ok = (left_open) ? (val_start < val_expr)
                                          : (val_start <= val_expr);
                }
                if (val_end == std::numeric_limits<double>::infinity()) {
                    right_ok = !std::isnan(val_expr);
                } else {
                    right_ok = (right_open) ? (val_expr < val_end)
                                            : (val_expr <= val_end);
                }
                return (left_ok && right_ok) ? 1.0 : 0.0;
            };
        } else {
            throw SymEngineException("LambdaDoubleVisitor: only ``Interval`` "
                                     "implemented for ``Contains``.");
        }
    }

    void bvisit(const BooleanAtom &ba)
    {
        const bool val = ba.get_val();
        result_ = [=](const double * /* x */) { return (val) ? 1.0 : 0.0; };
    }

    void bvisit(const Piecewise &pw)
    {
        SYMENGINE_ASSERT_MSG(
            eq(*pw.get_vec().back().second, *boolTrue),
            "LambdaDouble requires a (Expr, True) at the end of Piecewise");

        std::vector<fn> applys;
        std::vector<fn> preds;
        for (const auto &expr_pred : pw.get_vec()) {
            applys.push_back(apply(*expr_pred.first));
            preds.push_back(apply(*expr_pred.second));
        }
        result_ = [=](const double *x) {
            for (size_t i = 0;; ++i) {
                if (preds[i](x) == 1.0) {
                    return applys[i](x);
                }
            }
            throw SymEngineException(
                "Unexpectedly reached end of Piecewise function.");
        };
    }
};

class LambdaComplexDoubleVisitor
    : public BaseVisitor<LambdaComplexDoubleVisitor,
                         LambdaDoubleVisitor<std::complex<double>>>
{
public:
    // Classes not implemented are
    // Subs, UpperGamma, LowerGamma, Dirichlet_eta, Zeta
    // LeviCivita, KroneckerDelta, FunctionSymbol, LambertW
    // Derivative, ATan2, Gamma

    using LambdaDoubleVisitor::bvisit;

    void bvisit(const Complex &x)
    {
        double t1 = mp_get_d(x.real_), t2 = mp_get_d(x.imaginary_);
        result_ = [=](const std::complex<double> *x) {
            return std::complex<double>(t1, t2);
        };
    };

    void bvisit(const ComplexDouble &x)
    {
        std::complex<double> tmp = x.i;
        result_ = [=](const std::complex<double> *x) { return tmp; };
    };
#ifdef HAVE_SYMENGINE_MPC
    void bvisit(const ComplexMPC &x)
    {
        mpfr_class t(x.get_prec());
        double real, imag;
        mpc_real(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
        real = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
        mpc_imag(t.get_mpfr_t(), x.as_mpc().get_mpc_t(), MPFR_RNDN);
        imag = mpfr_get_d(t.get_mpfr_t(), MPFR_RNDN);
        std::complex<double> tmp(real, imag);
        result_ = [=](const std::complex<double> *x) { return tmp; };
    }
#endif
};
}
#endif // SYMENGINE_LAMBDA_DOUBLE_H