Program Listing for File llvm_double.cpp

Return to documentation for file (symengine/symengine/llvm_double.cpp)

#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/GenericValue.h"
#include "llvm/ExecutionEngine/MCJIT.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Vectorize.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>
#include <fstream>

#if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9)                       \
    || (LLVM_VERSION_MAJOR > 3)
#include <llvm/Transforms/Scalar/GVN.h>
#endif

#if (LLVM_VERSION_MAJOR >= 7)
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar/InstSimplifyPass.h>
#include <llvm/Transforms/Utils.h>
#endif

#include <symengine/llvm_double.h>
#include <symengine/eval_double.h>
#include <symengine/eval.h>

namespace SymEngine
{

#if (LLVM_VERSION_MAJOR >= 10)
using std::make_unique;
#else
using llvm::make_unique;
#endif

// Named subclass of llvm::IRBuilder<> that adds no members of its own.
// init() reinterprets a stack-allocated llvm::IRBuilder<> as this type, which
// relies on the two having an identical layout -- keep this class empty.
class IRBuilder : public llvm::IRBuilder<> {};

llvm::Value *LLVMVisitor::apply(const Basic &b)
{
    // Double dispatch: `accept` calls back into the visit/bvisit overload for
    // b's concrete type, which leaves the generated IR value in result_.
    b.accept(*this);
    llvm::Value *const generated = result_;
    return generated;
}

void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
                       unsigned opt_level)
{
    init(x, b, symbolic_cse, LLVMVisitor::create_default_passes(opt_level),
         opt_level);
}

void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
                       const std::vector<llvm::Pass *> &passes,
                       unsigned opt_level)
{
    // Wrap the single output expression in a vec_basic and defer to the
    // multi-output overload.
    const vec_basic single_output{b.rcp_from_this()};
    init(x, single_output, symbolic_cse, passes, opt_level);
}

llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
{
    // Despite its name, this creates and returns the function that will hold
    // the generated code:
    //     void <anon>(T *inputs, T *outputs)      (T = get_float_type())
    // with internal linkage and the C calling convention, inserted into `mod`.
    // Parameter 0 is marked readonly/nocapture, parameter 1 nocapture, and the
    // function nounwind/uwtable.
    llvm::Type *const float_ptr_type
        = llvm::PointerType::get(get_float_type(context), 0);
    const std::vector<llvm::Type *> param_types(2, float_ptr_type);
    llvm::FunctionType *const signature = llvm::FunctionType::get(
        llvm::Type::getVoidTy(*context), param_types, /*isVarArgs=*/false);
    llvm::Function *fn = llvm::Function::Create(
        signature, llvm::Function::InternalLinkage, "", mod);
    fn->setCallingConv(llvm::CallingConv::C);
#if (LLVM_VERSION_MAJOR < 5)
    {
        // Pre-5.0 attribute API: index 1/2 are the parameters, ~0U the
        // function itself.
        llvm::SmallVector<llvm::AttributeSet, 4> attrs;
        llvm::AttributeSet PAS;
        {
            llvm::AttrBuilder B;
            B.addAttribute(llvm::Attribute::ReadOnly);
            B.addAttribute(llvm::Attribute::NoCapture);
            PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
        }
        attrs.push_back(PAS);
        {
            llvm::AttrBuilder B;
            B.addAttribute(llvm::Attribute::NoCapture);
            PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
        }
        attrs.push_back(PAS);
        {
            llvm::AttrBuilder B;
            B.addAttribute(llvm::Attribute::NoUnwind);
            B.addAttribute(llvm::Attribute::UWTable);
            PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
        }
        attrs.push_back(PAS);

        fn->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
    }
#else
    // Inputs are only read, and neither pointer escapes the function.
    fn->addParamAttr(0, llvm::Attribute::ReadOnly);
    fn->addParamAttr(0, llvm::Attribute::NoCapture);
    fn->addParamAttr(1, llvm::Attribute::NoCapture);
    fn->addFnAttr(llvm::Attribute::NoUnwind);
    fn->addFnAttr(llvm::Attribute::UWTable);
#endif
    return fn;
}

std::vector<llvm::Pass *> LLVMVisitor::create_default_passes(int optlevel)
{
    std::vector<llvm::Pass *> passes;
    if (optlevel == 0) {
        return passes;
    }
#if (LLVM_VERSION_MAJOR < 4)
    passes.push_back(llvm::createInstructionCombiningPass());
#else
    passes.push_back(llvm::createInstructionCombiningPass(optlevel > 1));
#endif
    passes.push_back(llvm::createDeadCodeEliminationPass());
    passes.push_back(llvm::createPromoteMemoryToRegisterPass());
    passes.push_back(llvm::createReassociatePass());
    passes.push_back(llvm::createGVNPass());
    passes.push_back(llvm::createCFGSimplificationPass());
    passes.push_back(llvm::createPartiallyInlineLibCallsPass());
#if (LLVM_VERSION_MAJOR < 5)
    passes.push_back(llvm::createLoadCombinePass());
#endif
#if LLVM_VERSION_MAJOR >= 7
    passes.push_back(llvm::createInstSimplifyLegacyPass());
#else
    passes.push_back(llvm::createInstructionSimplifierPass());
#endif
    passes.push_back(llvm::createMemCpyOptPass());
    passes.push_back(llvm::createSROAPass());
    passes.push_back(llvm::createMergedLoadStoreMotionPass());
    passes.push_back(llvm::createBitTrackingDCEPass());
    passes.push_back(llvm::createAggressiveDCEPass());
    if (optlevel > 2) {
        passes.push_back(llvm::createSLPVectorizerPass());
#if LLVM_VERSION_MAJOR >= 7
        passes.push_back(llvm::createInstSimplifyLegacyPass());
#else
        passes.push_back(llvm::createInstructionSimplifierPass());
#endif
    }
    return passes;
}

void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
                       const bool symbolic_cse, unsigned opt_level)
{
    init(inputs, outputs, symbolic_cse,
         LLVMVisitor::create_default_passes(opt_level), opt_level);
}

void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
                       const bool symbolic_cse,
                       const std::vector<llvm::Pass *> &passes,
                       unsigned opt_level)
{
    // JIT-compile `outputs` as functions of the Symbols in `inputs`.
    //
    // The generated function has the shape
    //     void f(T *inputs, T *outputs)          (T = get_float_type())
    // It loads inputs[i] for every input symbol and stores outputs[j] for
    // every output expression. Its address is stored in `func`, and the
    // object code emitted by the JIT is copied into `membuffer`.
    //
    // If `symbolic_cse` is true, SymEngine::cse is applied to the outputs
    // before code generation. `passes` are added to a function pass manager
    // and run on the function; `opt_level` is forwarded to the JIT.
    //
    // Throws SymEngineException if an input is not a Symbol or if the JIT
    // engine cannot be created.
    executionengine.reset();
    // Drop per-compilation state an earlier init() may have left behind if it
    // threw part-way: those llvm::Values belong to the context replaced
    // below, and stale entries would shift the symbol -> value mapping.
    symbol_ptrs.clear();
    replacement_symbol_ptrs.clear();

    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
    context = std::make_shared<llvm::LLVMContext>();
    symbols = inputs;

    // Create a module to hold the generated function.
    std::unique_ptr<llvm::Module> module
        = make_unique<llvm::Module>("SymEngine", *context.get());
    module->setDataLayout("");
    mod = module.get();

    // Create a function pass manager for the module and register the passes.
    fpm = std::make_shared<llvm::legacy::FunctionPassManager>(mod);
    for (auto pass : passes) {
        fpm->add(pass);
    }
    fpm->doInitialization();

    auto F = get_function_type(context.get());

    // Entry block; passing F as the last argument appends the block to it.
    llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);

    // `builder` points at this stack-allocated builder only while init() runs;
    // it is reset to nullptr before a successful return.
    llvm::IRBuilder<> _builder(BB);
    builder = reinterpret_cast<IRBuilder *>(&_builder);
    builder->SetInsertPoint(BB);
    auto fmf = llvm::FastMathFlags();
    // fmf.setUnsafeAlgebra();
    builder->setFastMathFlags(fmf);

    // Load every input symbol from the first argument:
    // symbol_ptrs[i] = inputs_array[i].
    auto input_arg = &(*(F->args().begin()));
    for (unsigned i = 0; i < inputs.size(); i++) {
        if (not is_a<Symbol>(*inputs[i])) {
            throw SymEngineException("Input contains a non-symbol.");
        }
        auto index
            = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
        auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
                                      index);
        result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
        symbol_ptrs.push_back(result_);
    }

    // The second argument is the output array.
    auto it = F->args().begin();
#if (LLVM_VERSION_MAJOR < 5)
    auto out = &(*(++it));
#else
    auto out = &(*(it + 1));
#endif
    std::vector<llvm::Value *> output_vals;
    output_vals.reserve(outputs.size());

    if (symbolic_cse) {
        vec_basic reduced_exprs;
        vec_pair replacements;
        // cse the outputs
        SymEngine::cse(replacements, reduced_exprs, outputs);
        for (auto &rep : replacements) {
            // Store the replacement symbol values in a dictionary
            replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
        }
        // Generate IR for all the reduced exprs and save references
        for (unsigned i = 0; i < outputs.size(); i++) {
            output_vals.push_back(apply(*reduced_exprs[i]));
        }
    } else {
        // Generate IR for all the output exprs and save references
        for (unsigned i = 0; i < outputs.size(); i++) {
            output_vals.push_back(apply(*outputs[i]));
        }
    }

    // Store all the output values at the end.
    for (unsigned i = 0; i < outputs.size(); i++) {
        auto index
            = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
        auto ptr
            = builder->CreateGEP(get_float_type(context.get()), out, index);
        builder->CreateStore(output_vals[i], ptr);
    }

    builder->CreateRetVoid();

    // Validate the generated code. NOTE(review): verifyFunction returns true
    // when the function is malformed, but that result is currently ignored.
    llvm::verifyFunction(*F);

    // Optimize the function.
    fpm->run(*F);

    // Create the JIT. EngineBuilder::create() returns nullptr on failure and
    // writes a description into `error`.
    std::string error;
    llvm::ExecutionEngine *engine
        = llvm::EngineBuilder(std::move(module))
              .setEngineKind(llvm::EngineKind::Kind::JIT)
              .setOptLevel(static_cast<llvm::CodeGenOpt::Level>(opt_level))
              .setErrorStr(&error)
              .create();
    if (engine == nullptr) {
        throw SymEngineException("LLVMVisitor: failed to create the JIT: "
                                 + error);
    }
    executionengine = std::shared_ptr<llvm::ExecutionEngine>(engine);

    // Object-cache hook used only to capture the compiled object code into a
    // std::string; it never provides previously cached objects.
    class MemoryBufferRefCallback : public llvm::ObjectCache
    {
    public:
        std::string &ss_;
        explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss)
        {
        }

        void notifyObjectCompiled(const llvm::Module *,
                                  llvm::MemoryBufferRef obj) override
        {
            // Saving the object code in a std::string
            ss_.assign(obj.getBufferStart(), obj.getBufferSize());
        }

        std::unique_ptr<llvm::MemoryBuffer>
        getObject(const llvm::Module *) override
        {
            return nullptr;
        }
    };

    MemoryBufferRefCallback callback(membuffer);
    executionengine->setObjectCache(&callback);
    executionengine->finalizeObject();

    // Get the function's address.
    func = reinterpret_cast<intptr_t>(executionengine->getPointerToFunction(F));

    // `callback` goes out of scope when init() returns; detach it so the
    // engine does not keep a dangling ObjectCache pointer.
    executionengine->setObjectCache(nullptr);

    builder = nullptr;
    symbol_ptrs.clear();
    replacement_symbol_ptrs.clear();
    symbols.clear();
}

double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
{
    // Evaluate the compiled function at `vec` and return its output.
    // Only one output slot is provided, so this overload requires exactly one
    // compiled output expression.
    // The JIT-compiled function returns void; it must be called through a
    // pointer of exactly that type (a mismatched return type is UB).
    double ret;
    reinterpret_cast<void (*)(const double *, double *)>(func)(vec.data(),
                                                               &ret);
    return ret;
}

void LLVMDoubleVisitor::call(double *outs, const double *inps) const
{
    // Evaluate the compiled function: read one value per input symbol from
    // `inps` and write one value per output expression to `outs`.
    // The JIT-compiled function returns void; call it through a pointer of
    // exactly that type (a mismatched return type is UB).
    reinterpret_cast<void (*)(const double *, double *)>(func)(inps, outs);
}

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    // Evaluate the compiled function at `vec` and return its output.
    // Only one output slot is provided, so this overload requires exactly one
    // compiled output expression.
    // The JIT-compiled function returns void; it must be called through a
    // pointer of exactly that type (a mismatched return type is UB).
    long double ret;
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        vec.data(), &ret);
    return ret;
}

void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    // Evaluate the compiled function: one value per input symbol is read from
    // `inps`, one value per output expression is written to `outs`.
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        inps, outs);
}
#endif

float LLVMFloatVisitor::call(const std::vector<float> &vec) const
{
    // Evaluate the compiled function at `vec` and return its output.
    // Only one output slot is provided, so this overload requires exactly one
    // compiled output expression.
    // The JIT-compiled function returns void; it must be called through a
    // pointer of exactly that type (a mismatched return type is UB).
    float ret;
    reinterpret_cast<void (*)(const float *, float *)>(func)(vec.data(), &ret);
    return ret;
}

void LLVMFloatVisitor::call(float *outs, const float *inps) const
{
    // Evaluate the compiled function: one value per input symbol is read from
    // `inps`, one value per output expression is written to `outs`.
    // The JIT-compiled function returns void; call it through a pointer of
    // exactly that type (a mismatched return type is UB).
    reinterpret_cast<void (*)(const float *, float *)>(func)(inps, outs);
}

void LLVMVisitor::set_double(double d)
{
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
}

void LLVMVisitor::bvisit(const Integer &x)
{
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
                                    mp_get_d(x.as_integer_class()));
}

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
    // Build a long double constant from a 128-bit MPFR evaluation of `x`.
    // The value is handed to LLVM as a decimal string rather than a double,
    // presumably so it is not first rounded to double precision.
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> approx = evalf(x, 128, EvalfDomain::Real);
    const std::string digits = approx->__str__();
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    // Parse the exact decimal representation of the integer.
    const std::string digits = x.__str__();
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
}
#endif

void LLVMVisitor::bvisit(const Rational &x)
{
    // Convert the exact rational to a double via mp_get_d.
    const double approx = mp_get_d(x.as_rational_class());
    set_double(approx);
}

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
void LLVMLongDoubleVisitor::visit(const Rational &x)
{
    // Rationals go through the high-precision MPFR path instead of mp_get_d.
    this->convert_from_mpfr(x);
}
#endif

void LLVMVisitor::bvisit(const RealDouble &x)
{
    // The stored double is emitted unchanged.
    const double value = x.i;
    set_double(value);
}

#ifdef HAVE_SYMENGINE_MPFR
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    // Round the MPFR value to the nearest double (MPFR_RNDN).
    const double rounded = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(rounded);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    // Long double: go through MPFR instead of rounding to double first.
    this->convert_from_mpfr(x);
}
#endif
#endif

void LLVMVisitor::bvisit(const Add &x)
{
    llvm::Value *tmp, *tmp1, *tmp2;
    auto it = x.get_dict().begin();

    if (eq(*x.get_coef(), *zero)) {
        // `x + 0.0` is not optimized out
        if (eq(*one, *(it->second))) {
            tmp = apply(*(it->first));
        } else {
            tmp1 = apply(*(it->first));
            tmp2 = apply(*(it->second));
            tmp = builder->CreateFMul(tmp1, tmp2);
        }
        ++it;
    } else {
        tmp = apply(*x.get_coef());
    }

    for (; it != x.get_dict().end(); ++it) {
        if (eq(*one, *(it->second))) {
            tmp1 = apply(*(it->first));
            tmp = builder->CreateFAdd(tmp, tmp1);
        } else {
            //    std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
            //    tmp =
            //    builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
            //    llvm::Intrinsic::fma,
            //    3, context), args);
            tmp1 = apply(*(it->first));
            tmp2 = apply(*(it->second));
            tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
        }
    }
    result_ = tmp;
}

void LLVMVisitor::bvisit(const Mul &x)
{
    llvm::Value *tmp = nullptr;
    bool first = true;
    for (const auto &p : x.get_args()) {
        if (first) {
            tmp = apply(*p);
        } else {
            tmp = builder->CreateFMul(tmp, apply(*p));
        }
        first = false;
    }
    result_ = tmp;
}

llvm::Function *LLVMVisitor::get_powi()
{
    // Declaration of llvm.powi for the visitor's float type.
    // Up to LLVM 12 the intrinsic is overloaded only on the float type
    // (llvm.powi.f64); LLVM 13 made the i32 exponent type overloaded too
    // (llvm.powi.f64.i32). Passing the exponent type to an older LLVM yields a
    // wrongly mangled declaration.
    std::vector<llvm::Type *> arg_type;
    arg_type.push_back(get_float_type(&mod->getContext()));
#if (LLVM_VERSION_MAJOR >= 13)
    arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
#endif
    return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
                                           arg_type);
}

llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
                                    unsigned n, llvm::Module *mod)
{
    // Look up (declaring on first use) intrinsic `id` in `mod`, passing `n`
    // copies of `type` as its overload types. Every overload type is mangled
    // into the declared name, so `n` should equal the number of overloaded
    // types of `id` (one for llvm.sin/cos/exp/log/pow/fabs/minnum/maxnum).
    const std::vector<llvm::Type *> overload_types(n, type);
    return llvm::Intrinsic::getDeclaration(mod, id, overload_types);
}

void LLVMVisitor::bvisit(const Pow &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    if (eq(*(x.get_base()), *E)) {
        args.push_back(apply(*x.get_exp()));
        fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                                  llvm::Intrinsic::exp, 1, mod);

    } else if (eq(*(x.get_base()), *integer(2))) {
        args.push_back(apply(*x.get_exp()));
        fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                                  llvm::Intrinsic::exp2, 1, mod);

    } else {
        if (is_a<Integer>(*x.get_exp())) {
            if (eq(*x.get_exp(), *integer(2))) {
                llvm::Value *tmp = apply(*x.get_base());
                result_ = builder->CreateFMul(tmp, tmp);
                return;
            } else {
                args.push_back(apply(*x.get_base()));
                int d = numeric_cast<int>(
                    mp_get_si(static_cast<const Integer &>(*x.get_exp())
                                  .as_integer_class()));
                result_ = llvm::ConstantInt::get(
                    llvm::Type::getInt32Ty(mod->getContext()), d, true);
                args.push_back(result_);
                fun = get_powi();
            }
        } else {
            args.push_back(apply(*x.get_base()));
            args.push_back(apply(*x.get_exp()));
            fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                                      llvm::Intrinsic::pow, 2, mod);
        }
    }
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const Sin &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::sin, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const Cos &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::cos, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const Piecewise &x)
{
    std::vector<llvm::BasicBlock> blocks;

    RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();

    if (neq(*pw->get_vec().back().second, *boolTrue)) {
        throw SymEngineException(
            "LLVMDouble requires a (Expr, True) at the end of Piecewise");
    }

    if (pw->get_vec().size() > 2) {
        PiecewiseVec rest = pw->get_vec();
        rest.erase(rest.begin());
        auto rest_pw = piecewise(std::move(rest));
        PiecewiseVec new_pw;
        new_pw.push_back(*pw->get_vec().begin());
        new_pw.push_back({rest_pw, pw->get_vec().back().second});
        pw = piecewise(std::move(new_pw))
                 ->rcp_from_this_cast<const Piecewise>();
    } else if (pw->get_vec().size() < 2) {
        throw SymEngineException("Invalid Piecewise object");
    }

    auto cond_basic = pw->get_vec().front().second;
    llvm::Value *cond = apply(*cond_basic);
    // check if cond != 0.0
    cond = builder->CreateFCmpONE(
        cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
        "ifcond");
    llvm::Function *function = builder->GetInsertBlock()->getParent();
    // Create blocks for the then and else cases.  Insert the 'then' block at
    // the
    // end of the function.
    llvm::BasicBlock *then_bb
        = llvm::BasicBlock::Create(mod->getContext(), "then", function);
    llvm::BasicBlock *else_bb
        = llvm::BasicBlock::Create(mod->getContext(), "else");
    llvm::BasicBlock *merge_bb
        = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
    builder->CreateCondBr(cond, then_bb, else_bb);

    // Emit then value.
    builder->SetInsertPoint(then_bb);
    llvm::Value *then_value = apply(*pw->get_vec().front().first);
    builder->CreateBr(merge_bb);

    // Codegen of 'then_value' can change the current block, update then_bb for
    // the PHI.
    then_bb = builder->GetInsertBlock();

    // Emit else block.
    function->getBasicBlockList().push_back(else_bb);
    builder->SetInsertPoint(else_bb);
    llvm::Value *else_value = apply(*pw->get_vec().back().first);
    builder->CreateBr(merge_bb);

    // Codegen of 'else_value' can change the current block, update else_bb for
    // the PHI.
    else_bb = builder->GetInsertBlock();

    // Emit merge block.
    function->getBasicBlockList().push_back(merge_bb);
    builder->SetInsertPoint(merge_bb);
    llvm::PHINode *phi_node
        = builder->CreatePHI(get_float_type(&mod->getContext()), 2);

    phi_node->addIncoming(then_value, then_bb);
    phi_node->addIncoming(else_value, else_bb);
    result_ = phi_node;
}

void LLVMVisitor::bvisit(const Sign &x)
{
    const auto x2 = x.get_arg();
    PiecewiseVec new_pw;
    new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
    new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
    new_pw.push_back({real_double(1.0), boolTrue});
    auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
    bvisit(*pw);
}

void LLVMVisitor::bvisit(const Contains &cts)
{
    llvm::Value *expr = apply(*cts.get_expr());
    const auto set = cts.get_set();
    if (is_a<Interval>(*set)) {
        const auto &interv = down_cast<const Interval &>(*set);
        llvm::Value *start = apply(*interv.get_start());
        llvm::Value *end = apply(*interv.get_end());
        const bool left_open = interv.get_left_open();
        const bool right_open = interv.get_right_open();
        llvm::Value *left_ok;
        llvm::Value *right_ok;
        left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
                              : builder->CreateFCmpOLE(start, expr);
        right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
                                : builder->CreateFCmpOLE(expr, end);
        result_ = builder->CreateAnd(left_ok, right_ok);
        result_ = builder->CreateUIToFP(result_,
                                        get_float_type(&mod->getContext()));
    } else {
        throw SymEngineException("LLVMVisitor: only ``Interval`` "
                                 "implemented for ``Contains``.");
    }
}

void LLVMVisitor::bvisit(const Infty &x)
{
    // Only the real infinities (+oo / -oo) have a floating-point
    // representation; complex infinity is rejected.
    const bool negative = x.is_negative_infinity();
    if (not negative and not x.is_positive_infinity()) {
        throw SymEngineException(
            "LLVMDouble can only represent real valued infinity");
    }
    result_ = llvm::ConstantFP::getInfinity(
        get_float_type(&mod->getContext()), negative);
}

void LLVMVisitor::bvisit(const BooleanAtom &x)
{
    // Booleans are represented as floats: True -> 1.0, False -> 0.0.
    if (x.get_val()) {
        set_double(1.0);
    } else {
        set_double(0.0);
    }
}

void LLVMVisitor::bvisit(const Log &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::log, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

#define SYMENGINE_LOGIC_FUNCTION(Class, method)                                \
    void LLVMVisitor::bvisit(const Class &x)                                   \
    {                                                                          \
        llvm::Value *value = nullptr;                                          \
        llvm::Value *tmp;                                                      \
        set_double(0.0);                                                       \
        llvm::Value *zero_val = result_;                                       \
        for (auto &p : x.get_container()) {                                    \
            tmp = builder->CreateFCmpONE(apply(*p), zero_val);                 \
            if (value == nullptr) {                                            \
                value = tmp;                                                   \
            } else {                                                           \
                value = builder->method(value, tmp);                           \
            }                                                                  \
        }                                                                      \
        result_ = builder->CreateUIToFP(value,                                 \
                                        get_float_type(&mod->getContext()));   \
    }

SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);

void LLVMVisitor::bvisit(const Not &x)
{
    builder->CreateNot(apply(*x.get_arg()));
}

#define SYMENGINE_RELATIONAL_FUNCTION(Class, method)                           \
    void LLVMVisitor::bvisit(const Class &x)                                   \
    {                                                                          \
        llvm::Value *left = apply(*x.get_arg1());                              \
        llvm::Value *right = apply(*x.get_arg2());                             \
        result_ = builder->method(left, right);                                \
        result_ = builder->CreateUIToFP(result_,                               \
                                        get_float_type(&mod->getContext()));   \
    }

SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);

#define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)                         \
    void LLVMDoubleVisitor::visit(const Class &x)                              \
    {                                                                          \
        vec_basic basic_args = x.get_args();                                   \
        llvm::Function *func = get_external_function(#ext, basic_args.size()); \
        std::vector<llvm::Value *> args;                                       \
        for (const auto &arg : basic_args) {                                   \
            args.push_back(apply(*arg));                                       \
        }                                                                      \
        auto r = builder->CreateCall(func, args);                              \
        r->setTailCall(true);                                                  \
        result_ = r;                                                           \
    }                                                                          \
    void LLVMFloatVisitor::visit(const Class &x)                               \
    {                                                                          \
        vec_basic basic_args = x.get_args();                                   \
        llvm::Function *func = get_external_function(#ext + std::string("f"),  \
                                                     basic_args.size());       \
        std::vector<llvm::Value *> args;                                       \
        for (const auto &arg : basic_args) {                                   \
            args.push_back(apply(*arg));                                       \
        }                                                                      \
        auto r = builder->CreateCall(func, args);                              \
        r->setTailCall(true);                                                  \
        result_ = r;                                                           \
    }

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)                          \
    _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)                             \
    void LLVMLongDoubleVisitor::visit(const Class &x)                          \
    {                                                                          \
        vec_basic basic_args = x.get_args();                                   \
        llvm::Function *func = get_external_function(#ext + std::string("l"),  \
                                                     basic_args.size());       \
        std::vector<llvm::Value *> args;                                       \
        for (const auto &arg : basic_args) {                                   \
            args.push_back(apply(*arg));                                       \
        }                                                                      \
        auto r = builder->CreateCall(func, args);                              \
        r->setTailCall(true);                                                  \
        result_ = r;                                                           \
    }
#else
#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)                          \
    _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
#endif

SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)

void LLVMVisitor::bvisit(const Abs &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::fabs, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const Min &x)
{
    llvm::Value *value = nullptr;
    llvm::Function *fun;
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::minnum, 2, mod);
    for (auto &arg : x.get_vec()) {
        if (value != nullptr) {
            std::vector<llvm::Value *> args;
            args.push_back(value);
            args.push_back(apply(*arg));
            auto r = builder->CreateCall(fun, args);
            r->setTailCall(true);
            value = r;
        } else {
            value = apply(*arg);
        }
    }
    result_ = value;
}

void LLVMVisitor::bvisit(const Max &x)
{
    llvm::Value *value = nullptr;
    llvm::Function *fun;
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::maxnum, 2, mod);
    for (auto &arg : x.get_vec()) {
        if (value != nullptr) {
            std::vector<llvm::Value *> args;
            args.push_back(value);
            args.push_back(apply(*arg));
            auto r = builder->CreateCall(fun, args);
            r->setTailCall(true);
            value = r;
        } else {
            value = apply(*arg);
        }
    }
    result_ = value;
}

void LLVMVisitor::bvisit(const Symbol &x)
{
    unsigned i = 0;
    for (auto &symb : symbols) {
        if (eq(x, *symb)) {
            result_ = symbol_ptrs[i];
            return;
        }
        ++i;
    }
    auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
    if (it != replacement_symbol_ptrs.end()) {
        result_ = it->second;
        return;
    }
    throw std::runtime_error("Symbol " + x.__str__()
                             + " not in the symbols vector.");
}

// Returns the module-level declaration of the external function `name`, taking
// `nargs` arguments of the visitor's float type and returning that type.
// An existing function of that name is reused as-is; otherwise a new
// ExternalLinkage declaration with the C calling convention is added to `mod`.
// In both cases the function is marked nounwind.
llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
                                                   size_t nargs)
{
    llvm::LLVMContext &ctx = mod->getContext();
    llvm::Function *func = mod->getFunction(name);
    if (func == nullptr) {
        llvm::Type *float_type = get_float_type(&ctx);
        std::vector<llvm::Type *> param_types(nargs, float_type);
        llvm::FunctionType *signature = llvm::FunctionType::get(
            float_type, param_types, /*isVarArgs=*/false);
        func = llvm::Function::Create(
            signature, llvm::GlobalValue::ExternalLinkage, name, mod);
        func->setCallingConv(llvm::CallingConv::C);
    }
#if (LLVM_VERSION_MAJOR < 5)
    // Pre-5.0 attribute API: build an AttributeSet holding nounwind at index
    // ~0U (the function index) and install it, replacing existing attributes.
    llvm::AttrBuilder nounwind_builder;
    nounwind_builder.addAttribute(llvm::Attribute::NoUnwind);
    llvm::SmallVector<llvm::AttributeSet, 4> attr_sets;
    attr_sets.push_back(llvm::AttributeSet::get(ctx, ~0U, nounwind_builder));
    func->setAttributes(llvm::AttributeSet::get(ctx, attr_sets));
#else
    func->addFnAttr(llvm::Attribute::NoUnwind);
#endif
    return func;
}

void LLVMVisitor::bvisit(const Constant &x)
{
    // Symbolic constants are evaluated numerically (eval_double) at code
    // generation time and emitted as a literal via set_double.
    const double numeric_value = eval_double(x);
    set_double(numeric_value);
}

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Long-double override for Constant. Instead of the eval_double/set_double path
// used by LLVMVisitor::bvisit(const Constant &), it delegates to
// convert_from_mpfr.
// NOTE(review): presumably convert_from_mpfr evaluates x with MPFR at extended
// precision and emits a long-double constant — confirm in its definition.
void LLVMLongDoubleVisitor::visit(const Constant &x)
{
    convert_from_mpfr(x);
}
#endif

void LLVMVisitor::bvisit(const Basic &)
{
    // Catch-all for expression types without a dedicated LLVM lowering.
    const std::string message = "Not implemented.";
    throw std::runtime_error(message);
}

// Returns the serialized object code held in `membuffer`.
// loads() stores its input here, so dumps() after loads() round-trips that data.
// NOTE(review): presumably membuffer is also filled when a function is first
// compiled — confirm in the compilation path.
// The reference stays valid only while this visitor is alive and until the
// next loads() call.
const std::string &LLVMVisitor::dumps() const
{
    return membuffer;
}

// Restores a previously compiled function from its serialized object code.
//
// `s` is an object image (as returned by dumps()); it is copied into
// `membuffer`. A fresh context and an empty module are created. MCJIT is then
// given an ObjectCache that returns `s` for that module instead of compiling it,
// and `func` is set to the address of the loaded function.
//
// Throws std::runtime_error if the LLVM execution engine cannot be created.
void LLVMVisitor::loads(const std::string &s)
{
    membuffer = s;
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
    context = std::make_shared<llvm::LLVMContext>();

    // Create a module to hold the function.
    std::unique_ptr<llvm::Module> module
        = make_unique<llvm::Module>("SymEngine", *context);
    module->setDataLayout("");
    mod = module.get();

    // Only the prototype of the function is defined here. That is enough,
    // because its body comes from the cached object. llvm::ObjectCache is meant
    // for caching objects, but here it is used to load one specific object.
    auto F = get_function_type(context.get());

    std::string error;
    executionengine = std::shared_ptr<llvm::ExecutionEngine>(
        llvm::EngineBuilder(std::move(module))
            .setEngineKind(llvm::EngineKind::Kind::JIT)
            .setOptLevel(llvm::CodeGenOpt::Level::Aggressive)
            .setErrorStr(&error)
            .create());
    if (!executionengine) {
        // The module was handed to the (now destroyed) builder, so `mod` no
        // longer refers to a live module.
        mod = nullptr;
        throw std::runtime_error("Failed to create LLVM execution engine: "
                                 + error);
    }

    // Cache that never stores anything and answers every lookup with the
    // serialized object it was constructed with.
    class MCJITObjectLoader : public llvm::ObjectCache
    {
        const std::string &s_;

    public:
        explicit MCJITObjectLoader(const std::string &s) : s_(s)
        {
        }
        void notifyObjectCompiled(const llvm::Module * /*M*/,
                                  llvm::MemoryBufferRef /*obj*/) override
        {
        }

        // No need to check M: the module holds only one function.
        std::unique_ptr<llvm::MemoryBuffer>
        getObject(const llvm::Module * /*M*/) override
        {
            return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
        }
    };

    MCJITObjectLoader loader(membuffer);
    executionengine->setObjectCache(&loader);
    executionengine->finalizeObject();
    // Set func to the compiled function pointer.
    func = reinterpret_cast<intptr_t>(executionengine->getPointerToFunction(F));
    // `loader` lives on this stack frame. Detach it so the engine, which
    // outlives this call, never holds or uses a dangling cache pointer.
    executionengine->setObjectCache(nullptr);
}

void LLVMVisitor::bvisit(const Floor &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::floor, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const Ceiling &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::ceil, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
{
    apply(*x.get_arg());
}

void LLVMVisitor::bvisit(const Truncate &x)
{
    std::vector<llvm::Value *> args;
    llvm::Function *fun;
    args.push_back(apply(*x.get_arg()));
    fun = get_float_intrinsic(get_float_type(&mod->getContext()),
                              llvm::Intrinsic::trunc, 1, mod);
    auto r = builder->CreateCall(fun, args);
    r->setTailCall(true);
    result_ = r;
}

// Floating-point type used by LLVMDoubleVisitor: LLVM `double`.
llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
{
    llvm::LLVMContext &ctx = *context;
    llvm::Type *element_type = llvm::Type::getDoubleTy(ctx);
    return element_type;
}

// Floating-point type used by LLVMFloatVisitor: LLVM `float`.
llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
{
    llvm::LLVMContext &ctx = *context;
    llvm::Type *element_type = llvm::Type::getFloatTy(ctx);
    return element_type;
}

#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Floating-point type used by LLVMLongDoubleVisitor: LLVM `x86_fp80`.
// NOTE(review): the x87 80-bit type is hard-coded; it matches C `long double`
// only on x86 targets. Presumably SYMENGINE_HAVE_LLVM_LONG_DOUBLE is only
// defined there — confirm in the build configuration.
llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
{
    llvm::LLVMContext &ctx = *context;
    llvm::Type *element_type = llvm::Type::getX86_FP80Ty(ctx);
    return element_type;
}
#endif

} // namespace SymEngine