1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
37 #include <symengine/llvm_double.h>
// LLVM API compatibility shims, selected on LLVM_VERSION_MAJOR.
// NOTE(review): the #else/#endif lines pairing with the #if directives below
// are absent from this listing (gaps in the embedded numbering); each pair of
// alternatives is presumably mutually exclusive.
// make_unique: std::make_unique on LLVM >= 10, llvm::make_unique otherwise.
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
47 using llvm::make_unique;
// CodeGenOptLevel: the codegen optimization-level enum is
// llvm::CodeGenOptLevel on LLVM >= 18 and llvm::CodeGenOpt::Level before.
50 #if (LLVM_VERSION_MAJOR >= 18)
51 using CodeGenOptLevel = llvm::CodeGenOptLevel;
53 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
// GetDeclaration(M, id, Tys): declaration of intrinsic `id` in module M for
// the overload types Tys. Spelled Intrinsic::getOrInsertDeclaration on
// LLVM >= 20 and Intrinsic::getDeclaration before.
56 #if (LLVM_VERSION_MAJOR >= 20)
57 auto GetDeclaration = [](llvm::Module *M, llvm::Intrinsic::ID id,
58 llvm::ArrayRef<llvm::Type *> Tys) {
59 return llvm::Intrinsic::getOrInsertDeclaration(M,
id, Tys);
62 const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
// AddNoCapture(A): adds a "captures nothing" attribute to A (presumably an
// llvm::AttrBuilder). LLVM >= 21 expresses this via CaptureInfo::none();
// older releases use Attribute::NoCapture.
65 #if (LLVM_VERSION_MAJOR >= 21)
66 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
68 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
75 LLVMVisitor::LLVMVisitor() =
default;
76 LLVMVisitor::~LLVMVisitor() =
default;
// Lowers expression `b` to IR and returns the resulting llvm::Value (callers
// feed the result straight into IRBuilder calls and output stores).
// NOTE(review): the body (original lines 79-82) is missing from this listing;
// presumably it dispatches through b.accept(*this) and returns result_.
78 llvm::Value *LLVMVisitor::apply(
const Basic &b)
// Convenience overload: compiles the single output expression `b` over the
// input symbols `x` by forwarding to the vector-of-outputs init().
// NOTE(review): the trailing `opt_level` parameter line (original line 86) and
// the braces are missing from this listing; `opt_level` is used below.
84 void LLVMVisitor::init(
const vec_basic &x,
const Basic &b,
bool symbolic_cse,
87 init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
// Despite its name, creates and returns the kernel *function* (not an
// llvm::FunctionType):
//     void symengine_func(<ptr> inputs, <ptr> outputs)
// with internal linkage and the C calling convention, added to `mod`.
// Attributes: parameter 0 is read-only, both parameters are non-capturing,
// and the function is nounwind with an unwind table.
// NOTE(review): braces, #else/#endif lines and the final `return F;` are
// missing from this listing (gaps in the embedded numbering).
90 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
92 std::vector<llvm::Type *> inp;
// Two pointer parameters: the input array and the output array.
93 for (
int i = 0; i < 2; i++) {
// LLVM >= 22: opaque pointer in address space 0; otherwise a pointer to the
// float type chosen by the concrete visitor.
94 #if (LLVM_VERSION_MAJOR >= 22)
95 inp.push_back(llvm::PointerType::get(*context, 0));
97 inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
100 llvm::FunctionType *function_type = llvm::FunctionType::get(
101 llvm::Type::getVoidTy(*context), inp,
false);
102 auto F = llvm::Function::Create(
103 function_type, llvm::Function::InternalLinkage,
"symengine_func",
mod);
104 F->setCallingConv(llvm::CallingConv::C);
// Parameter 0 holds the inputs and is only read; parameter 1 receives the
// stored outputs (see init()).
105 F->addParamAttr(0, llvm::Attribute::ReadOnly);
106 #if (LLVM_VERSION_MAJOR >= 21)
107 F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
108 *context, llvm::CaptureInfo::none()));
109 F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
110 *context, llvm::CaptureInfo::none()));
112 F->addParamAttr(0, llvm::Attribute::NoCapture);
113 F->addParamAttr(1, llvm::Attribute::NoCapture);
115 F->addFnAttr(llvm::Attribute::NoUnwind);
// UWTable became a kind-carrying attribute in LLVM 15.
116 #if (LLVM_VERSION_MAJOR < 15)
117 F->addFnAttr(llvm::Attribute::UWTable);
119 F->addFnAttr(llvm::Attribute::getWithUWTableKind(
120 *context, llvm::UWTableKind::Default));
// Builds and JIT-compiles the kernel computing `outputs` from `inputs`.
//
// inputs       - every entry must be a Symbol (otherwise SymEngineException);
//                input i is loaded from element i of the kernel's first
//                pointer argument.
// outputs      - output i is stored to element i of the second pointer
//                argument.
// symbolic_cse - when true, SymEngine::cse() is run over `outputs` first and
//                each replacement is lowered once and reused.
// opt_level    - 0 skips the function simplification pipeline; 1 and 2 select
//                O1/O2; any other value selects O3. The raw value is also cast
//                to the execution engine's codegen optimization level.
//
// On return, the ObjectCache hook below has copied the compiled object file
// into `membuffer`, and the entry point's address is stored in `func`.
// NOTE(review): several lines of this function are missing from this listing
// (gaps in the embedded numbering). These include braces, the
// `if (symbolic_cse)` test, the pass-manager run and the end of the
// EngineBuilder chain. The comments below describe only the visible code.
125 void LLVMVisitor::init(
const vec_basic &inputs,
const vec_basic &outputs,
126 const bool symbolic_cse,
unsigned opt_level)
// Drop any previous engine before the LLVMContext its module belongs to is
// replaced below.
128 executionengine.reset();
129 llvm::InitializeNativeTarget();
130 llvm::InitializeNativeTargetAsmPrinter();
131 llvm::InitializeNativeTargetAsmParser();
132 context = make_unique<llvm::LLVMContext>();
136 std::unique_ptr<llvm::Module> module
137 = make_unique<llvm::Module>(
"SymEngine", *context.get());
138 module->setDataLayout(
"");
// Declare `symengine_func` (see get_function_type) and give it an entry block.
141 auto F = get_function_type(context.get());
146 llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context,
"EntryBlock", F);
// NOTE(review): `builder` is pointed at the stack-local `_builder`, so it
// dangles once init() returns. The reinterpret_cast also assumes the project's
// IRBuilder type is layout-compatible with llvm::IRBuilder<> — confirm.
151 llvm::IRBuilder<> _builder(BB);
152 builder =
reinterpret_cast<IRBuilder *
>(&_builder);
153 builder->SetInsertPoint(BB);
// A default-constructed FastMathFlags has every flag cleared, so no fast-math
// relaxations are enabled.
154 auto fmf = llvm::FastMathFlags();
156 builder->setFastMathFlags(fmf);
// Load each input symbol's value from element i of the first argument and
// record it in symbol_ptrs, in input order (bvisit(Symbol) indexes it).
159 auto input_arg = &(*(F->args().begin()));
160 for (
unsigned i = 0; i < inputs.size(); i++) {
161 if (not is_a<Symbol>(*inputs[i])) {
162 throw SymEngineException(
"Input contains a non-symbol.");
165 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
166 auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
168 result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
169 symbol_ptrs.push_back(result_);
// The second argument is the output array.
172 auto it = F->args().begin();
173 auto out = &(*(it + 1));
174 std::vector<llvm::Value *> output_vals;
// CSE path: lower every replacement first (bvisit(Symbol) resolves replacement
// symbols through replacement_symbol_ptrs), then the reduced outputs.
177 vec_basic reduced_exprs;
178 vec_pair replacements;
180 SymEngine::cse(replacements, reduced_exprs, outputs);
181 for (
auto &rep : replacements) {
183 replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
186 for (
unsigned i = 0; i < outputs.size(); i++) {
187 output_vals.push_back(apply(*reduced_exprs[i]));
// Non-CSE path: lower each output expression directly.
191 for (
unsigned i = 0; i < outputs.size(); i++) {
192 output_vals.push_back(apply(*outputs[i]));
// Store output i to element i of the output array.
197 for (
unsigned i = 0; i < outputs.size(); i++) {
199 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
201 = builder->CreateGEP(get_float_type(context.get()), out, index);
202 builder->CreateStore(output_vals[i], ptr);
206 builder->CreateRetVoid();
// NOTE(review): verifyFunction returns true when the function is broken. The
// result is ignored here, so a malformed function is only reported on stdout
// and compilation continues.
209 llvm::verifyFunction(*F, &llvm::outs());
// Function-level optimization with the new pass manager.
216 #if (LLVM_VERSION_MAJOR < 14)
217 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
219 using OptimizationLevel = llvm::OptimizationLevel;
221 llvm::PassBuilder PB;
222 llvm::ModuleAnalysisManager MAM;
223 llvm::CGSCCAnalysisManager CGAM;
224 llvm::FunctionAnalysisManager FAM;
225 llvm::LoopAnalysisManager LAM;
226 PB.registerModuleAnalyses(MAM);
227 PB.registerCGSCCAnalyses(CGAM);
228 PB.registerFunctionAnalyses(FAM);
229 PB.registerLoopAnalyses(LAM);
230 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
231 llvm::FunctionPassManager FPM;
// Map opt_level to a pipeline level: 0/1/2 -> O0/O1/O2, anything else -> O3.
232 OptimizationLevel pb_opt_level{OptimizationLevel::O3};
233 if (opt_level == 0) {
234 pb_opt_level = OptimizationLevel::O0;
235 }
else if (opt_level == 1) {
236 pb_opt_level = OptimizationLevel::O1;
237 }
else if (opt_level == 2) {
238 pb_opt_level = OptimizationLevel::O2;
// opt_level 0 skips the simplification pipeline entirely.
241 if (opt_level != 0) {
242 #if (LLVM_VERSION_MAJOR < 6)
243 FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
244 #elif (LLVM_VERSION_MAJOR < 12)
245 FPM = PB.buildFunctionSimplificationPipeline(
246 pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
248 FPM = PB.buildFunctionSimplificationPipeline(
249 pb_opt_level, llvm::ThinOrFullLTOPhase::None);
// JIT engine. NOTE(review): unlike the pipeline mapping above, opt_level is
// cast straight to CodeGenOptLevel without clamping; confirm callers keep it
// in 0..3.
259 executionengine = std::unique_ptr<llvm::ExecutionEngine>(
260 llvm::EngineBuilder(std::move(module))
261 .setEngineKind(llvm::EngineKind::Kind::JIT)
262 .setOptLevel(
static_cast<CodeGenOptLevel
>(opt_level))
268 modify_execution_engine(executionengine.get());
// ObjectCache hook whose only purpose is to copy the freshly compiled object
// file into the referenced string (here `membuffer`).
271 class MemoryBufferRefCallback :
public llvm::ObjectCache
275 explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
277 void notifyObjectCompiled(
const llvm::Module *M,
278 llvm::MemoryBufferRef obj)
override
280 const char *c = obj.getBufferStart();
282 ss_.assign(c, obj.getBufferSize());
285 std::unique_ptr<llvm::MemoryBuffer>
286 getObject(
const llvm::Module *M)
override
// NOTE(review): `callback` is a local, but the engine keeps the pointer after
// init() returns. Presumably nothing is compiled after finalizeObject();
// consider setObjectCache(nullptr) once finalized.
292 MemoryBufferRefCallback callback(membuffer);
293 executionengine->setObjectCache(&callback);
295 executionengine->finalizeObject();
// Address of the compiled entry point, kept as an integer in `func`.
298 func = (intptr_t)executionengine->getPointerToFunction(F);
// The replacement values are IR values of this (now compiled) function and
// are not reused.
300 replacement_symbol_ptrs.clear();
304 LLVMDoubleVisitor::LLVMDoubleVisitor() =
default;
305 LLVMDoubleVisitor::~LLVMDoubleVisitor() =
default;
307 double LLVMDoubleVisitor::call(
const std::vector<double> &vec)
const
310 ((double (*)(
const double *,
double *))func)(vec.data(), &ret);
314 void LLVMDoubleVisitor::call(
double *outs,
const double *inps)
const
316 ((double (*)(
const double *,
double *))func)(inps, outs);
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluates a single-output kernel at the input values `vec` and returns the
// result.
//
// The JIT'd entry point is declared void-returning (see get_function_type()),
// so it must be called through a void-returning pointer. The former
// long-double-returning cast was undefined behaviour.
// NOTE(review): only one output slot (`ret`) is provided, so the visitor must
// have been initialised with exactly one output expression.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    using kernel_t = void (*)(const long double *, long double *);
    long double ret = 0.0L;
    reinterpret_cast<kernel_t>(func)(vec.data(), &ret);
    return ret;
}

// Evaluates the kernel on `inps`, writing one value per output expression
// into `outs` (both arrays sized by the caller).
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using kernel_t = void (*)(const long double *, long double *);
    reinterpret_cast<kernel_t>(func)(inps, outs);
}
#endif
336 LLVMFloatVisitor::LLVMFloatVisitor() =
default;
337 LLVMFloatVisitor::~LLVMFloatVisitor() =
default;
339 float LLVMFloatVisitor::call(
const std::vector<float> &vec)
const
342 ((float (*)(
const float *,
float *))func)(vec.data(), &ret);
346 void LLVMFloatVisitor::call(
float *outs,
const float *inps)
const
348 ((float (*)(
const float *,
float *))func)(inps, outs);
351 void LLVMVisitor::set_double(
double d)
353 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()), d);
356 void LLVMVisitor::bvisit(
const Integer &x)
358 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
359 mp_get_d(x.as_integer_class()));
// Long-double visitor members (compiled only with
// SYMENGINE_HAVE_LLVM_LONG_DOUBLE).
// NOTE(review): the matching #endif, the #else of the MPFR check and some call
// arguments are missing from this listing.
362 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
364 LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() =
default;
365 LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() =
default;
// Emits `x` as a constant of the long-double type. x is evaluated with evalf
// at 128 bits of precision in the real domain. Without MPFR support
// (HAVE_SYMENGINE_MPFR undefined) this throws NotImplementedError.
367 void LLVMLongDoubleVisitor::convert_from_mpfr(
const Basic &x)
369 #ifndef HAVE_SYMENGINE_MPFR
370 throw NotImplementedError(
"Cannot convert to long double without MPFR");
372 RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
373 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
// Integers for the long-double visitor. The value argument (original line 381)
// is missing from this listing.
378 void LLVMLongDoubleVisitor::visit(
const Integer &x)
380 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
385 void LLVMVisitor::bvisit(
const Rational &x)
387 set_double(mp_get_d(x.as_rational_class()));
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Long-double rationals go through the MPFR-based converter (128-bit evalf)
// instead of a double approximation.
void LLVMLongDoubleVisitor::visit(const Rational &x)
{
    convert_from_mpfr(x);
}
#endif
// Presumably emits the RealDouble's value as a constant of the visitor's
// float type.
// NOTE(review): the body (original lines 398-400) is missing from this
// listing.
397 void LLVMVisitor::bvisit(
const RealDouble &x)
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals are rounded to the nearest double (MPFR_RNDN) and
// emitted via set_double.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double rounded = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(rounded);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long-double visitor goes through the MPFR-based converter instead.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    convert_from_mpfr(x);
}
#endif
#endif
// Lowers an Add (numeric coefficient + sum of term*coefficient) as a chain of
// fadd/fmul.
// - If the numeric coefficient is zero, the first dictionary term seeds the
//   accumulator, skipping the fmul when that term's coefficient is one.
// - Otherwise the numeric coefficient seeds it.
// Each remaining term is then added, again skipping the fmul for unit
// coefficients.
// NOTE(review): the else-lines, the advance of `it` past the seed term and the
// final store into result_ are presumably on lines missing from this listing.
415 void LLVMVisitor::bvisit(
const Add &x)
417 llvm::Value *tmp, *tmp1, *tmp2;
418 auto it = x.get_dict().begin();
420 if (
eq(*x.get_coef(), *zero)) {
422 if (
eq(*one, *(it->second))) {
423 tmp = apply(*(it->first));
425 tmp1 = apply(*(it->first));
426 tmp2 = apply(*(it->second));
427 tmp = builder->CreateFMul(tmp1, tmp2);
431 tmp = apply(*x.get_coef());
434 for (; it != x.get_dict().end(); ++it) {
435 if (
eq(*one, *(it->second))) {
436 tmp1 = apply(*(it->first));
437 tmp = builder->CreateFAdd(tmp, tmp1);
444 tmp1 = apply(*(it->first));
445 tmp2 = apply(*(it->second));
446 tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
// Lowers a Mul as an fmul chain over its factors.
// NOTE(review): the lines that seed `tmp` with the first factor and store the
// product into result_ are missing from this listing.
452 void LLVMVisitor::bvisit(
const Mul &x)
454 llvm::Value *tmp =
nullptr;
456 for (
const auto &p : x.get_args()) {
460 tmp = builder->CreateFMul(tmp, apply(*p));
467 llvm::Function *LLVMVisitor::get_powi()
469 std::vector<llvm::Type *> arg_type;
470 arg_type.push_back(get_float_type(&
mod->getContext()));
471 #if (LLVM_VERSION_MAJOR > 12)
472 arg_type.push_back(llvm::Type::getInt32Ty(
mod->getContext()));
474 return GetDeclaration(
mod, llvm::Intrinsic::powi, arg_type);
477 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID
id,
478 unsigned n, llvm::Module *
mod)
480 std::vector<llvm::Type *> arg_type(n, type);
481 return GetDeclaration(
mod,
id, arg_type);
// Lowers base**exp:
// - base E               -> llvm.exp(exp)
// - base 2               -> llvm.exp2(exp)
// - integer exponent     -> base*base for what is presumably the exponent-2
//                           case (its test is on a missing line); otherwise
//                           llvm.powi(base, i32 exp). numeric_cast<int> guards
//                           the narrowing of the exponent.
// - anything else        -> llvm.pow(base, exp)
// The intrinsic is emitted as a tail call.
// NOTE(review): the `fun` declaration, the `else` lines, the `fun = get_powi()`
// assignment, the early return after base*base and the final store into
// result_ are missing from this listing. result_ is also used as scratch for
// the i32 exponent constant.
484 void LLVMVisitor::bvisit(
const Pow &x)
486 std::vector<llvm::Value *> args;
488 if (
eq(*(x.get_base()), *E)) {
489 args.push_back(apply(*x.get_exp()));
490 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
491 llvm::Intrinsic::exp, 1,
mod);
493 }
else if (
eq(*(x.get_base()), *
integer(2))) {
494 args.push_back(apply(*x.get_exp()));
495 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
496 llvm::Intrinsic::exp2, 1,
mod);
499 if (is_a<Integer>(*x.get_exp())) {
501 llvm::Value *tmp = apply(*x.get_base());
502 result_ = builder->CreateFMul(tmp, tmp);
505 args.push_back(apply(*x.get_base()));
506 int d = numeric_cast<int>(
507 mp_get_si(
static_cast<const Integer &
>(*x.get_exp())
508 .as_integer_class()));
509 result_ = llvm::ConstantInt::get(
510 llvm::Type::getInt32Ty(
mod->getContext()), d,
true);
511 args.push_back(result_);
515 args.push_back(apply(*x.get_base()));
516 args.push_back(apply(*x.get_exp()));
517 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
518 llvm::Intrinsic::pow, 1,
mod);
521 auto r = builder->CreateCall(fun, args);
522 r->setTailCall(
true);
// Lowers sin(arg) to a tail call of the llvm.sin intrinsic for the visitor's
// float type.
// NOTE(review): the declaration of `fun` and the final store into result_ are
// presumably on lines missing from this listing.
526 void LLVMVisitor::bvisit(
const Sin &x)
528 std::vector<llvm::Value *> args;
530 args.push_back(apply(*x.get_arg()));
531 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
532 llvm::Intrinsic::sin, 1,
mod);
533 auto r = builder->CreateCall(fun, args);
534 r->setTailCall(
true);
// Lowers cos(arg) to a tail call of the llvm.cos intrinsic for the visitor's
// float type.
// NOTE(review): the declaration of `fun` and the final store into result_ are
// presumably on lines missing from this listing.
538 void LLVMVisitor::bvisit(
const Cos &x)
540 std::vector<llvm::Value *> args;
542 args.push_back(apply(*x.get_arg()));
543 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
544 llvm::Intrinsic::cos, 1,
mod);
545 auto r = builder->CreateCall(fun, args);
546 r->setTailCall(
true);
// Lowers a Piecewise to if/else control flow joined by a PHI node.
//
// The last piece must have condition True; otherwise SymEngineException is
// thrown. More than two pieces are rewritten as (first piece, rest-as-nested-
// Piecewise), so only the two-piece case is emitted here and the nesting
// recurses through apply(). Fewer than two pieces is rejected.
//
// The first condition is lowered to a 0.0/1.0 value and tested with ordered
// `fcmp one` against 0.0, so a NaN condition takes the else branch.
// NOTE(review): `blocks` is an unused local and a candidate for removal;
// `new_pw`'s declaration and the tail of the fcmp call are on missing lines.
550 void LLVMVisitor::bvisit(
const Piecewise &x)
552 std::vector<llvm::BasicBlock> blocks;
554 RCP<const Piecewise> pw = x.rcp_from_this_cast<
const Piecewise>();
556 if (
neq(*pw->get_vec().back().second, *boolTrue)) {
557 throw SymEngineException(
558 "LLVMDouble requires a (Expr, True) at the end of Piecewise");
561 if (pw->get_vec().size() > 2) {
562 PiecewiseVec rest = pw->get_vec();
563 rest.erase(rest.begin());
564 auto rest_pw = piecewise(std::move(rest));
566 new_pw.push_back(*pw->get_vec().begin());
567 new_pw.push_back({rest_pw, pw->get_vec().back().second});
568 pw = piecewise(std::move(new_pw))
569 ->rcp_from_this_cast<
const Piecewise>();
570 }
else if (pw->get_vec().size() < 2) {
571 throw SymEngineException(
"Invalid Piecewise object");
574 auto cond_basic = pw->get_vec().front().second;
575 llvm::Value *cond = apply(*cond_basic);
577 cond = builder->CreateFCmpONE(
578 cond, llvm::ConstantFP::get(get_float_type(&
mod->getContext()), 0.0),
580 llvm::Function *
function = builder->GetInsertBlock()->getParent();
// Only "then" is attached to the function now; "else" and the merge block are
// appended later, after the code they follow has been emitted.
584 llvm::BasicBlock *then_bb
585 = llvm::BasicBlock::Create(
mod->getContext(),
"then",
function);
586 llvm::BasicBlock *else_bb
587 = llvm::BasicBlock::Create(
mod->getContext(),
"else");
588 llvm::BasicBlock *merge_bb
589 = llvm::BasicBlock::Create(
mod->getContext(),
"ifcont");
590 builder->CreateCondBr(cond, then_bb, else_bb);
593 builder->SetInsertPoint(then_bb);
594 llvm::Value *then_value = apply(*pw->get_vec().front().first);
595 builder->CreateBr(merge_bb);
// Lowering the branch value may have created new blocks (nested Piecewise), so
// the PHI's incoming block must be the current insert block.
599 then_bb = builder->GetInsertBlock();
602 #if (LLVM_VERSION_MAJOR < 16)
603 function->getBasicBlockList().push_back(else_bb);
605 function->insert(function->end(), else_bb);
607 builder->SetInsertPoint(else_bb);
608 llvm::Value *else_value = apply(*pw->get_vec().back().first);
609 builder->CreateBr(merge_bb);
// Same reason as for then_bb above.
613 else_bb = builder->GetInsertBlock();
616 #if (LLVM_VERSION_MAJOR < 16)
617 function->getBasicBlockList().push_back(merge_bb);
619 function->insert(function->end(), merge_bb);
621 builder->SetInsertPoint(merge_bb);
622 llvm::PHINode *phi_node
623 = builder->CreatePHI(get_float_type(&
mod->getContext()), 2);
625 phi_node->addIncoming(then_value, then_bb);
626 phi_node->addIncoming(else_value, else_bb);
// Lowers sign(x) through an equivalent Piecewise:
//     0 if x == 0, -1 if x < 0, 1 otherwise.
// NOTE(review):
// - Both comparisons are ordered (OEQ/OLT), so a NaN argument yields 1.0
//   rather than NaN. Confirm this is intended.
// - rcp_static_cast assumes piecewise() returns a Piecewise. If
//   canonicalization folds the conditions, the cast would be invalid;
//   consider checking is_a<Piecewise>.
// - `new_pw`'s declaration and the final lowering of `pw` are on lines missing
//   from this listing.
630 void LLVMVisitor::bvisit(
const Sign &x)
632 const auto x2 = x.get_arg();
634 new_pw.push_back({real_double(0.0),
Eq(x2, real_double(0.0))});
635 new_pw.push_back({real_double(-1.0),
Lt(x2, real_double(0.0))});
636 new_pw.push_back({real_double(1.0), boolTrue});
637 auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
// Lowers Contains(expr, set) to 0.0/1.0. Only Interval sets are supported:
// open ends use `<` (OLT) and closed ends use `<=` (OLE), combined with an
// `and`. Any other set throws SymEngineException.
// Ordered comparisons mean a NaN expr or endpoint yields 0.0.
// NOTE(review): `expr` is lowered before the set type is checked, so its IR is
// emitted even when the function then throws. The `} else {` line is missing
// from this listing.
641 void LLVMVisitor::bvisit(
const Contains &cts)
643 llvm::Value *expr = apply(*cts.get_expr());
644 const auto set = cts.get_set();
645 if (is_a<Interval>(*set)) {
646 const auto &interv = down_cast<const Interval &>(*set);
647 llvm::Value *start = apply(*interv.get_start());
648 llvm::Value *end = apply(*interv.get_end());
649 const bool left_open = interv.get_left_open();
650 const bool right_open = interv.get_right_open();
651 llvm::Value *left_ok;
652 llvm::Value *right_ok;
653 left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
654 : builder->CreateFCmpOLE(start, expr);
655 right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
656 : builder->CreateFCmpOLE(expr, end);
657 result_ = builder->CreateAnd(left_ok, right_ok);
658 result_ = builder->CreateUIToFP(result_,
659 get_float_type(&
mod->getContext()));
661 throw SymEngineException(
"LLVMVisitor: only ``Interval`` "
662 "implemented for ``Contains``.");
666 void LLVMVisitor::bvisit(
const Infty &x)
668 if (x.is_negative_infinity()) {
669 result_ = llvm::ConstantFP::getInfinity(
670 get_float_type(&
mod->getContext()),
true);
671 }
else if (x.is_positive_infinity()) {
672 result_ = llvm::ConstantFP::getInfinity(
673 get_float_type(&
mod->getContext()),
false);
675 throw SymEngineException(
676 "LLVMDouble can only represent real valued infinity");
// Emits a NaN constant of the visitor's float type.
// NOTE(review): the remaining getNaN arguments and the closing brace are on
// lines missing from this listing.
680 void LLVMVisitor::bvisit(
const NaN &x)
682 result_ = llvm::ConstantFP::getNaN(get_float_type(&
mod->getContext()),
686 void LLVMVisitor::bvisit(
const BooleanAtom &x)
688 const bool val = x.get_val();
689 set_double(val ? 1.0 : 0.0);
// Lowers log(arg) to a tail call of the llvm.log intrinsic for the visitor's
// float type.
// NOTE(review): the declaration of `fun` and the final store into result_ are
// presumably on lines missing from this listing.
692 void LLVMVisitor::bvisit(
const Log &x)
694 std::vector<llvm::Value *> args;
696 args.push_back(apply(*x.get_arg()));
697 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
698 llvm::Intrinsic::log, 1,
mod);
699 auto r = builder->CreateCall(fun, args);
700 r->setTailCall(
true);
// SYMENGINE_LOGIC_FUNCTION(Class, method) defines LLVMVisitor::bvisit for the
// n-ary boolean node `Class`, and is instantiated below for And, Or and Xor.
// - Each argument is lowered and turned into an i1 with ordered `fcmp one`
//   against `zero_val`, so NaN counts as false.
// - The i1 values are folded with the IRBuilder method `method`.
// - The folded i1 is widened back to 0.0/1.0.
// `zero_val` is copied out of result_ before the loop because each apply()
// reassigns result_.
// NOTE(review): the lines that set result_ to the 0.0 constant, declare `tmp`,
// seed `value` and close the braces are missing from this listing.
// Comments cannot be placed inside the backslash-continued macro.
704 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
705 void LLVMVisitor::bvisit(const Class &x) \
707 llvm::Value *value = nullptr; \
710 llvm::Value *zero_val = result_; \
711 for (auto &p : x.get_container()) { \
712 tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
713 if (value == nullptr) { \
716 value = builder->method(value, tmp); \
719 result_ = builder->CreateUIToFP(value, \
720 get_float_type(&mod->getContext())); \
723 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
724 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
725 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
// Lowers Not(arg) to 0.0/1.0. The argument is tested with ordered `fcmp one`
// against zero_val, negated and widened, so a NaN argument yields 1.0.
// NOTE(review): the line that sets result_ to the 0.0 constant before it is
// copied into zero_val is missing from this listing.
727 void LLVMVisitor::bvisit(
const Not &x)
730 llvm::Value *zero_val = result_;
731 llvm::Value *value = builder->CreateFCmpONE(apply(*x.get_arg()), zero_val);
732 result_ = builder->CreateUIToFP(builder->CreateNot(value),
733 get_float_type(&
mod->getContext()));
736 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
737 void LLVMVisitor::bvisit(const Class &x) \
739 llvm::Value *left = apply(*x.get_arg1()); \
740 llvm::Value *right = apply(*x.get_arg2()); \
741 result_ = builder->method(left, right); \
742 result_ = builder->CreateUIToFP(result_, \
743 get_float_type(&mod->getContext())); \
746 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
747 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
748 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
749 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
// Functions without an LLVM intrinsic are lowered to a tail call of the
// external routine declared by get_external_function():
//   - LLVMDoubleVisitor calls `ext`,
//   - LLVMFloatVisitor calls `ext` + "f",
//   - LLVMLongDoubleVisitor calls `ext` + "l" (only with
//     SYMENGINE_HAVE_LLVM_LONG_DOUBLE).
// These names presumably resolve to the C math library (libm) at JIT link time.
// _SYMENGINE_MACRO_EXTERNAL_FUNCTION defines the double and float visitors;
// SYMENGINE_MACRO_EXTERNAL_FUNCTION adds the long-double one when enabled.
// NOTE(review): the store of `r` into result_, the closing braces and the
// #else/#endif lines are missing from this listing. Comments cannot be placed
// inside the backslash-continued macros.
751 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
752 void LLVMDoubleVisitor::visit(const Class &x) \
754 vec_basic basic_args = x.get_args(); \
755 llvm::Function *func = get_external_function(#ext, basic_args.size()); \
756 std::vector<llvm::Value *> args; \
757 for (const auto &arg : basic_args) { \
758 args.push_back(apply(*arg)); \
760 auto r = builder->CreateCall(func, args); \
761 r->setTailCall(true); \
764 void LLVMFloatVisitor::visit(const Class &x) \
766 vec_basic basic_args = x.get_args(); \
767 llvm::Function *func = get_external_function(#ext + std::string("f"), \
768 basic_args.size()); \
769 std::vector<llvm::Value *> args; \
770 for (const auto &arg : basic_args) { \
771 args.push_back(apply(*arg)); \
773 auto r = builder->CreateCall(func, args); \
774 r->setTailCall(true); \
778 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
779 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
780 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
781 void LLVMLongDoubleVisitor::visit(const Class &x) \
783 vec_basic basic_args = x.get_args(); \
784 llvm::Function *func = get_external_function(#ext + std::string("l"), \
785 basic_args.size()); \
786 std::vector<llvm::Value *> args; \
787 for (const auto &arg : basic_args) { \
788 args.push_back(apply(*arg)); \
790 auto r = builder->CreateCall(func, args); \
791 r->setTailCall(true); \
795 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
796 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
// SymEngine node -> external routine base name.
799 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan,
tan)
800 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin,
asin)
801 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos,
acos)
802 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan,
atan)
803 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2,
atan2)
804 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh,
sinh)
805 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh,
cosh)
806 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh,
tanh)
807 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh,
asinh)
808 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh,
acosh)
809 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh,
atanh)
810 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
811 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
812 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf,
erf)
813 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc,
erfc)
// Lowers abs(arg) to a tail call of the llvm.fabs intrinsic for the visitor's
// float type.
// NOTE(review): the declaration of `fun` and the final store into result_ are
// presumably on lines missing from this listing.
815 void LLVMVisitor::bvisit(
const Abs &x)
817 std::vector<llvm::Value *> args;
819 args.push_back(apply(*x.get_arg()));
820 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
821 llvm::Intrinsic::fabs, 1,
mod);
822 auto r = builder->CreateCall(fun, args);
823 r->setTailCall(
true);
// Lowers Min(a, b, ...) as a left fold of tail calls to the llvm.minnum
// intrinsic. Per the LLVM LangRef, minnum returns the other operand when
// exactly one operand is NaN.
// NOTE(review): the declaration of `fun`, the branch seeding `value` with the
// first argument, the update of `value` from each call and the final store
// into result_ are on lines missing from this listing.
827 void LLVMVisitor::bvisit(
const Min &x)
829 llvm::Value *value =
nullptr;
831 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
832 llvm::Intrinsic::minnum, 1,
mod);
833 for (
auto &arg : x.get_vec()) {
834 if (value !=
nullptr) {
835 std::vector<llvm::Value *> args;
836 args.push_back(value);
837 args.push_back(apply(*arg));
838 auto r = builder->CreateCall(fun, args);
839 r->setTailCall(
true);
// Lowers Max(a, b, ...) as a left fold of tail calls to the llvm.maxnum
// intrinsic. Per the LLVM LangRef, maxnum returns the other operand when
// exactly one operand is NaN.
// NOTE(review): the declaration of `fun`, the branch seeding `value` with the
// first argument, the update of `value` from each call and the final store
// into result_ are on lines missing from this listing.
848 void LLVMVisitor::bvisit(
const Max &x)
850 llvm::Value *value =
nullptr;
852 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
853 llvm::Intrinsic::maxnum, 1,
mod);
854 for (
auto &arg : x.get_vec()) {
855 if (value !=
nullptr) {
856 std::vector<llvm::Value *> args;
857 args.push_back(value);
858 args.push_back(apply(*arg));
859 auto r = builder->CreateCall(fun, args);
860 r->setTailCall(
true);
// Resolves a Symbol to its IR value, trying in order:
//   1. its position among the kernel's input symbols (the value loaded in
//      init(), via symbol_ptrs);
//   2. the CSE replacement symbols (replacement_symbol_ptrs).
// If neither matches, throws SymEngineException.
// NOTE(review): the index counter, the equality test inside the loop and the
// early returns are on lines missing from this listing. The input lookup is a
// linear scan per Symbol occurrence.
869 void LLVMVisitor::bvisit(
const Symbol &x)
872 for (
auto &symb : symbols) {
874 result_ = symbol_ptrs[i];
879 auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
880 if (it != replacement_symbol_ptrs.end()) {
881 result_ = it->second;
885 throw SymEngineException(
"Symbol " + x.__str__()
886 +
" not in the symbols vector.");
// Returns the declaration of the external function `name` in `mod`. On first
// use it creates `fp name(fp, ..., fp)`, with `nargs` parameters of the
// visitor's float type, external linkage, the C calling convention and
// nounwind.
// NOTE(review): the `nargs` parameter line, the test guarding creation
// (presumably `if (!func)`) and the return statement are missing from this
// listing.
889 llvm::Function *LLVMVisitor::get_external_function(
const std::string &name,
892 std::vector<llvm::Type *> func_args(nargs,
893 get_float_type(&
mod->getContext()));
894 llvm::FunctionType *func_type = llvm::FunctionType::get(
895 get_float_type(&
mod->getContext()), func_args,
false);
897 llvm::Function *func =
mod->getFunction(name);
899 func = llvm::Function::Create(
900 func_type, llvm::GlobalValue::ExternalLinkage, name,
mod);
901 func->setCallingConv(llvm::CallingConv::C);
903 func->addFnAttr(llvm::Attribute::NoUnwind);
907 void LLVMVisitor::bvisit(
const Constant &x)
909 set_double(eval_double(x));
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Long-double constants go through the MPFR-based converter (128-bit evalf)
// instead of eval_double.
void LLVMLongDoubleVisitor::visit(const Constant &x)
{
    convert_from_mpfr(x);
}
#endif
919 void LLVMVisitor::bvisit(
const Basic &x)
921 throw NotImplementedError(x.__str__());
// Returns the compiled object file captured during init(), suitable for
// loads().
// NOTE(review): the body is missing from this listing; presumably it returns
// `membuffer`, which init()'s ObjectCache hook fills.
924 const std::string &LLVMVisitor::dumps()
const
// Restores a kernel from an object file previously produced by init()
// (presumably the string returned by dumps()), without regenerating IR.
// - A fresh context and module are created, and `symengine_func` is only
//   declared (get_function_type).
// - The MCJITObjectLoader cache below answers getObject() with a copy of `s`,
//   so finalizeObject() loads that machine code instead of compiling.
// - The entry point's address is stored in `func`.
// NOTE(review): several lines (braces, `mod = ...`, the end of the
// EngineBuilder chain, the body of notifyObjectCompiled, ...) are missing from
// this listing.
929 void LLVMVisitor::loads(
const std::string &s)
932 llvm::InitializeNativeTarget();
933 llvm::InitializeNativeTargetAsmPrinter();
934 llvm::InitializeNativeTargetAsmParser();
// NOTE(review): unlike init(), the previous executionengine is not reset
// before `context` is replaced (unless that happens on a line missing from
// this listing). If an engine exists, its module belongs to the old context,
// which is destroyed before the engine is reassigned below. Consider calling
// executionengine.reset() first, as init() does.
935 context = make_unique<llvm::LLVMContext>();
938 std::unique_ptr<llvm::Module> module
939 = make_unique<llvm::Module>(
"SymEngine", *context);
940 module->setDataLayout(
"");
947 auto F = get_function_type(context.get());
950 executionengine = std::unique_ptr<llvm::ExecutionEngine>(
951 llvm::EngineBuilder(std::move(module))
952 .setEngineKind(llvm::EngineKind::Kind::JIT)
953 .setOptLevel(CodeGenOptLevel::Aggressive)
959 modify_execution_engine(executionengine.get());
// ObjectCache that supplies the serialized object `s` instead of compiling.
961 class MCJITObjectLoader :
public llvm::ObjectCache
963 const std::string &s_;
966 MCJITObjectLoader(
const std::string &s) : s_(s) {}
967 void notifyObjectCompiled(
const llvm::Module *M,
968 llvm::MemoryBufferRef obj)
override
974 std::unique_ptr<llvm::MemoryBuffer>
975 getObject(
const llvm::Module *M)
override
977 return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
// NOTE(review): `loader` is a local whose pointer the engine keeps after
// loads() returns; presumably nothing is compiled after finalizeObject().
981 MCJITObjectLoader loader(s);
982 executionengine->setObjectCache(&loader);
983 executionengine->finalizeObject();
985 func = (intptr_t)executionengine->getPointerToFunction(F);
// Lowers floor(arg) to a tail call of the llvm.floor intrinsic for the
// visitor's float type.
// NOTE(review): the declaration of `fun` and the final store into result_ are
// presumably on lines missing from this listing.
988 void LLVMVisitor::bvisit(
const Floor &x)
990 std::vector<llvm::Value *> args;
992 args.push_back(apply(*x.get_arg()));
993 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
994 llvm::Intrinsic::floor, 1,
mod);
995 auto r = builder->CreateCall(fun, args);
996 r->setTailCall(
true);
// Lowers ceiling(arg) to a tail call of the llvm.ceil intrinsic for the
// visitor's float type.
// NOTE(review): the final store of `r` into result_ is presumably on a line
// missing from this listing.
1000 void LLVMVisitor::bvisit(
const Ceiling &x)
1002 std::vector<llvm::Value *> args;
1003 llvm::Function *fun;
1004 args.push_back(apply(*x.get_arg()));
1005 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1006 llvm::Intrinsic::ceil, 1,
mod);
1007 auto r = builder->CreateCall(fun, args);
1008 r->setTailCall(
true);
1012 void LLVMVisitor::bvisit(
const UnevaluatedExpr &x)
1014 apply(*x.get_arg());
// Lowers truncate(arg) to a tail call of the llvm.trunc intrinsic
// (round toward zero) for the visitor's float type.
// NOTE(review): the final store of `r` into result_ is presumably on a line
// missing from this listing.
1017 void LLVMVisitor::bvisit(
const Truncate &x)
1019 std::vector<llvm::Value *> args;
1020 llvm::Function *fun;
1021 args.push_back(apply(*x.get_arg()));
1022 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1023 llvm::Intrinsic::trunc, 1,
mod);
1024 auto r = builder->CreateCall(fun, args);
1025 r->setTailCall(
true);
1029 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1031 return llvm::Type::getDoubleTy(*context);
1034 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1036 return llvm::Type::getFloatTy(*context);
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long-double visitor uses the x87 80-bit extended-precision type.
// NOTE(review): only meaningful on x86 targets; presumably
// SYMENGINE_HAVE_LLVM_LONG_DOUBLE is only defined there.
llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
{
    llvm::Type *const ty = llvm::Type::getX86_FP80Ty(*context);
    return ty;
}
#endif
The lowest unit of symbolic representation.
Main namespace for SymEngine package.
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
RCP< const Integer > mod(const Integer &n, const Integer &d)
Modulo, with the quotient rounded toward zero.
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh: