Program Listing for File eval_mpfr.cpp

Return to documentation for file (symengine/symengine/eval_mpfr.cpp)

#include <symengine/visitor.h>
#include <symengine/eval_mpfr.h>
#include <symengine/symengine_exception.h>

#ifdef HAVE_SYMENGINE_MPFR

namespace SymEngine
{

class EvalMPFRVisitor : public BaseVisitor<EvalMPFRVisitor>
{
protected:
    mpfr_rnd_t rnd_;
    mpfr_ptr result_;

public:
    EvalMPFRVisitor(mpfr_rnd_t rnd) : rnd_{rnd}
    {
    }

    void apply(mpfr_ptr result, const Basic &b)
    {
        mpfr_ptr tmp = result_;
        result_ = result;
        b.accept(*this);
        result_ = tmp;
    }

    void bvisit(const Integer &x)
    {
        mpfr_set_z(result_, get_mpz_t(x.as_integer_class()), rnd_);
    }

    void bvisit(const Rational &x)
    {
        mpfr_set_q(result_, get_mpq_t(x.as_rational_class()), rnd_);
    }

    void bvisit(const RealDouble &x)
    {
        mpfr_set_d(result_, x.i, rnd_);
    }

    void bvisit(const RealMPFR &x)
    {
        mpfr_set(result_, x.i.get_mpfr_t(), rnd_);
    }

    void bvisit(const Add &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        auto d = x.get_args();
        auto p = d.begin();
        apply(result_, *(*p));
        p++;
        for (; p != d.end(); p++) {
            apply(t.get_mpfr_t(), *(*p));
            mpfr_add(result_, result_, t.get_mpfr_t(), rnd_);
        }
    }

    void bvisit(const Mul &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        auto d = x.get_args();
        auto p = d.begin();
        apply(result_, *(*p));
        p++;
        for (; p != d.end(); p++) {
            apply(t.get_mpfr_t(), *(*p));
            mpfr_mul(result_, result_, t.get_mpfr_t(), rnd_);
        }
    }

    void bvisit(const Pow &x)
    {
        if (eq(*x.get_base(), *E)) {
            apply(result_, *(x.get_exp()));
            mpfr_exp(result_, result_, rnd_);
        } else {
            mpfr_class b(mpfr_get_prec(result_));
            apply(b.get_mpfr_t(), *(x.get_base()));
            apply(result_, *(x.get_exp()));
            mpfr_pow(result_, b.get_mpfr_t(), result_, rnd_);
        }
    }

    void bvisit(const Equality &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(t.get_mpfr_t(), *(x.get_arg1()));
        apply(result_, *(x.get_arg2()));
        if (mpfr_equal_p(t.get_mpfr_t(), result_)) {
            mpfr_set_ui(result_, 1, rnd_);
        } else {
            mpfr_set_ui(result_, 0, rnd_);
        }
    }

    void bvisit(const Unequality &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(t.get_mpfr_t(), *(x.get_arg1()));
        apply(result_, *(x.get_arg2()));
        if (mpfr_lessgreater_p(t.get_mpfr_t(), result_)) {
            mpfr_set_ui(result_, 1, rnd_);
        } else {
            mpfr_set_ui(result_, 0, rnd_);
        }
    }

    void bvisit(const LessThan &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(t.get_mpfr_t(), *(x.get_arg1()));
        apply(result_, *(x.get_arg2()));
        if (mpfr_lessequal_p(t.get_mpfr_t(), result_)) {
            mpfr_set_ui(result_, 1, rnd_);
        } else {
            mpfr_set_ui(result_, 0, rnd_);
        }
    }

    void bvisit(const StrictLessThan &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(t.get_mpfr_t(), *(x.get_arg1()));
        apply(result_, *(x.get_arg2()));
        if (mpfr_less_p(t.get_mpfr_t(), result_)) {
            mpfr_set_ui(result_, 1, rnd_);
        } else {
            mpfr_set_ui(result_, 0, rnd_);
        }
    }

    void bvisit(const Sin &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_sin(result_, result_, rnd_);
    }

    void bvisit(const Cos &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_cos(result_, result_, rnd_);
    }

    void bvisit(const Tan &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_tan(result_, result_, rnd_);
    }

    void bvisit(const Log &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_log(result_, result_, rnd_);
    }

    void bvisit(const Cot &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_cot(result_, result_, rnd_);
    }

    void bvisit(const Csc &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_csc(result_, result_, rnd_);
    }

    void bvisit(const Sec &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_sec(result_, result_, rnd_);
    }

    void bvisit(const ASin &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_asin(result_, result_, rnd_);
    }

    void bvisit(const ACos &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_acos(result_, result_, rnd_);
    }

    void bvisit(const ASec &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_asin(result_, result_, rnd_);
    }

    void bvisit(const ACsc &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_acos(result_, result_, rnd_);
    }

    void bvisit(const ATan &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_atan(result_, result_, rnd_);
    }

    void bvisit(const ACot &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_atan(result_, result_, rnd_);
    }

    void bvisit(const ATan2 &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(t.get_mpfr_t(), *(x.get_num()));
        apply(result_, *(x.get_den()));
        mpfr_atan2(result_, t.get_mpfr_t(), result_, rnd_);
    }

    void bvisit(const Sinh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_sinh(result_, result_, rnd_);
    }

    void bvisit(const Csch &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_csch(result_, result_, rnd_);
    }

    void bvisit(const Cosh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_cosh(result_, result_, rnd_);
    }

    void bvisit(const Sech &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_sech(result_, result_, rnd_);
    }

    void bvisit(const Tanh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_tanh(result_, result_, rnd_);
    }

    void bvisit(const Coth &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_coth(result_, result_, rnd_);
    }

    void bvisit(const ASinh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_asinh(result_, result_, rnd_);
    }

    void bvisit(const ACsch &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_asinh(result_, result_, rnd_);
    };

    void bvisit(const ACosh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_acosh(result_, result_, rnd_);
    }

    void bvisit(const ATanh &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_atanh(result_, result_, rnd_);
    }

    void bvisit(const ACoth &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_atanh(result_, result_, rnd_);
    }

    void bvisit(const ASech &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_ui_div(result_, 1, result_, rnd_);
        mpfr_acosh(result_, result_, rnd_);
    };

    void bvisit(const Gamma &x)
    {
        apply(result_, *(x.get_args()[0]));
        mpfr_gamma(result_, result_, rnd_);
    };
#if MPFR_VERSION_MAJOR > 3
    void bvisit(const UpperGamma &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(result_, *(x.get_args()[1]));
        apply(t.get_mpfr_t(), *(x.get_args()[0]));
        mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
    };

    void bvisit(const LowerGamma &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        apply(result_, *(x.get_args()[1]));
        apply(t.get_mpfr_t(), *(x.get_args()[0]));
        mpfr_gamma_inc(result_, t.get_mpfr_t(), result_, rnd_);
        mpfr_gamma(t.get_mpfr_t(), t.get_mpfr_t(), rnd_);
        mpfr_sub(result_, t.get_mpfr_t(), result_, rnd_);
    };
#endif
    void bvisit(const LogGamma &x)
    {
        apply(result_, *(x.get_args()[0]));
        mpfr_lngamma(result_, result_, rnd_);
    }

    void bvisit(const Beta &x)
    {
        apply(result_, *(x.rewrite_as_gamma()));
    };

    void bvisit(const Constant &x)
    {
        if (x.__eq__(*pi)) {
            mpfr_const_pi(result_, rnd_);
        } else if (x.__eq__(*E)) {
            mpfr_t one_;
            mpfr_init2(one_, mpfr_get_prec(result_));
            mpfr_set_ui(one_, 1, rnd_);
            mpfr_exp(result_, one_, rnd_);
            mpfr_clear(one_);
        } else if (x.__eq__(*EulerGamma)) {
            mpfr_const_euler(result_, rnd_);
        } else if (x.__eq__(*Catalan)) {
            mpfr_const_catalan(result_, rnd_);
        } else if (x.__eq__(*GoldenRatio)) {
            mpfr_sqrt_ui(result_, 5, rnd_);
            mpfr_add_ui(result_, result_, 1, rnd_);
            mpfr_div_ui(result_, result_, 2, rnd_);
        } else {
            throw NotImplementedError("Constant " + x.get_name()
                                      + " is not implemented.");
        }
    }

    void bvisit(const Abs &x)
    {
        apply(result_, *(x.get_arg()));
        mpfr_abs(result_, result_, rnd_);
    };

    void bvisit(const NumberWrapper &x)
    {
        x.eval(mpfr_get_prec(result_))->accept(*this);
    }

    void bvisit(const FunctionWrapper &x)
    {
        x.eval(mpfr_get_prec(result_))->accept(*this);
    }
    void bvisit(const Erf &x)
    {
        apply(result_, *(x.get_args()[0]));
        mpfr_erf(result_, result_, rnd_);
    }

    void bvisit(const Erfc &x)
    {
        apply(result_, *(x.get_args()[0]));
        mpfr_erfc(result_, result_, rnd_);
    }

    void bvisit(const Max &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        auto d = x.get_args();
        auto p = d.begin();
        apply(result_, *(*p));
        p++;
        for (; p != d.end(); p++) {
            apply(t.get_mpfr_t(), *(*p));
            mpfr_max(result_, result_, t.get_mpfr_t(), rnd_);
        }
    }

    void bvisit(const Min &x)
    {
        mpfr_class t(mpfr_get_prec(result_));
        auto d = x.get_args();
        auto p = d.begin();
        apply(result_, *(*p));
        p++;
        for (; p != d.end(); p++) {
            apply(t.get_mpfr_t(), *(*p));
            mpfr_min(result_, result_, t.get_mpfr_t(), rnd_);
        }
    }

    void bvisit(const UnevaluatedExpr &x)
    {
        apply(result_, *x.get_arg());
    }

    // Classes not implemented are
    // Subs, Dirichlet_eta, Zeta
    // LeviCivita, KroneckerDelta, LambertW
    // Derivative, Complex, ComplexDouble, ComplexMPC
    void bvisit(const Basic &)
    {
        throw NotImplementedError("Not Implemented");
    };
};

void eval_mpfr(mpfr_ptr result, const Basic &b, mpfr_rnd_t rnd)
{
    EvalMPFRVisitor v(rnd);
    v.apply(result, b);
}

} // SymEngine

#endif // HAVE_SYMENGINE_MPFR