1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
37 #include <symengine/llvm_double.h>
// Compatibility shims across the supported LLVM versions.
// NOTE(review): the #else/#endif lines of these conditionals (46, 48, 51,
// 53-54, 57, 59) are not present in this listing; each second alternative
// below is presumably the #else arm of the #if directly above it.
// make_unique: std::make_unique when LLVM_VERSION_MAJOR >= 10, otherwise
// (presumably) llvm::make_unique.
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
47 using llvm::make_unique;
// CodeGenOptLevel: llvm::CodeGenOpt::Level before LLVM 18, the
// llvm::CodeGenOptLevel type from LLVM 18 on.
49 #if (LLVM_VERSION_MAJOR < 18)
50 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
52 using CodeGenOptLevel = llvm::CodeGenOptLevel;
// AddNoCapture(A): marks attribute builder `A` as "does not capture" -- via
// addCapturesAttr(CaptureInfo::none()) on LLVM >= 21, via the NoCapture
// attribute otherwise.
55 #if (LLVM_VERSION_MAJOR >= 21)
56 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
58 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
// Constructor and destructor of the base visitor, both defaulted out of line.
65 LLVMVisitor::LLVMVisitor() =
default;
66 LLVMVisitor::~LLVMVisitor() =
default;
// Returns the llvm::Value the generated IR uses for expression `b`; the
// visitors below feed this value straight into IRBuilder instructions.
// NOTE(review): the body (lines 69-73) is not visible in this listing;
// presumably it visits `b` and returns result_ -- confirm.
68 llvm::Value *LLVMVisitor::apply(
const Basic &b)
// Single-output overload of init(): forwards to the vec_basic overload with
// `b` wrapped as the only output expression.
// NOTE(review): the `opt_level` parameter declaration and the braces are on
// lines not visible in this listing (75-76, 78).
74 void LLVMVisitor::init(
const vec_basic &x,
const Basic &b,
bool symbolic_cse,
77 init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
// Declares the kernel that init() fills in: a void function taking two
// pointers (address space 0) to get_float_type(context). Argument 0 is only
// read (the inputs) and argument 1 is written (the outputs). The function is
// created in `mod` with internal linkage, an empty name and the C calling
// convention. Attributes: param 0 readonly, both params no-capture, the
// function nounwind + uwtable.
// NOTE(review): the declaration of `inp`, the #else/#endif lines and the
// final `return F;` are not visible in this listing.
80 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
83 for (
int i = 0; i < 2; i++) {
84 inp.
push_back(llvm::PointerType::get(get_float_type(context), 0));
// Signature: void (float_t *, float_t *), not variadic.
86 llvm::FunctionType *function_type = llvm::FunctionType::get(
87 llvm::Type::getVoidTy(*context), inp,
false);
88 auto F = llvm::Function::Create(function_type,
89 llvm::Function::InternalLinkage,
"",
mod);
90 F->setCallingConv(llvm::CallingConv::C);
// Pre-LLVM-5 attribute API: attributes are gathered per index into
// AttributeSets (1U = first parameter, 2U = second, ~0U = the function).
91 #if (LLVM_VERSION_MAJOR < 5)
93 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
94 llvm::AttributeSet PAS;
97 B.addAttribute(llvm::Attribute::ReadOnly);
99 PAS = llvm::AttributeSet::get(
mod->getContext(), 1U, B);
102 attrs.push_back(PAS);
106 PAS = llvm::AttributeSet::get(
mod->getContext(), 2U, B);
109 attrs.push_back(PAS);
112 B.addAttribute(llvm::Attribute::NoUnwind);
113 B.addAttribute(llvm::Attribute::UWTable);
114 PAS = llvm::AttributeSet::get(
mod->getContext(), ~0U, B);
117 attrs.push_back(PAS);
119 F->setAttributes(llvm::AttributeSet::get(
mod->getContext(), attrs));
// Newer API: attributes are attached directly to F's parameters/function.
122 F->addParamAttr(0, llvm::Attribute::ReadOnly);
// LLVM >= 21 expresses "no capture" as a captures(none) attribute.
123 #if (LLVM_VERSION_MAJOR >= 21)
124 F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
125 *context, llvm::CaptureInfo::none()));
126 F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
127 *context, llvm::CaptureInfo::none()));
129 F->addParamAttr(0, llvm::Attribute::NoCapture);
130 F->addParamAttr(1, llvm::Attribute::NoCapture);
132 F->addFnAttr(llvm::Attribute::NoUnwind);
// UWTable became a valued attribute (UWTableKind) in LLVM 15.
133 #if (LLVM_VERSION_MAJOR < 15)
134 F->addFnAttr(llvm::Attribute::UWTable);
136 F->addFnAttr(llvm::Attribute::getWithUWTableKind(
137 *context, llvm::UWTableKind::Default));
// Builds, optimizes and JIT-compiles a kernel that computes `outputs` from
// the symbols in `inputs`. The generated function reads input i from element
// i of its first argument and writes output i to element i of its second
// argument. Two output paths are visible: one runs SymEngine::cse and emits
// each replacement once before the reduced expressions, the other lowers
// `outputs` directly; presumably `symbolic_cse` selects between them on a
// hidden line. `opt_level` chooses the pass pipeline (0-2 -> O0-O2, any other
// value -> O3) and is also handed to the JIT's code generator.
// Throws SymEngineException if an input is not a Symbol.
// NOTE(review): many lines of this definition (braces, #else/#endif, the
// EngineBuilder construction, the pass-manager run) are not visible in this
// listing; the comments below describe only the visible lines.
143 void LLVMVisitor::init(
const vec_basic &inputs,
const vec_basic &outputs,
144 const bool symbolic_cse,
unsigned opt_level)
// Drop any previously compiled engine before rebuilding the LLVM state.
146 executionengine.reset();
147 llvm::InitializeNativeTarget();
148 llvm::InitializeNativeTargetAsmPrinter();
149 llvm::InitializeNativeTargetAsmParser();
150 context = make_unique<llvm::LLVMContext>();
155 = make_unique<llvm::Module>(
"SymEngine", *context.get());
156 module->setDataLayout(
"");
// Kernel signature: void(float_t *inputs, float_t *outputs).
159 auto F = get_function_type(context.get());
164 llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context,
"EntryBlock", F);
// NOTE(review): `_builder` is a local, so the member `builder` is left
// pointing at stack storage that dies when init() returns unless a hidden
// line resets it -- confirm `builder` is never used after init(). `IRBuilder`
// is declared outside this listing; the reinterpret_cast presumes it is
// layout-compatible with llvm::IRBuilder<>.
169 llvm::IRBuilder<> _builder(BB);
170 builder =
reinterpret_cast<IRBuilder *
>(&_builder);
171 builder->SetInsertPoint(BB);
172 auto fmf = llvm::FastMathFlags();
174 builder->setFastMathFlags(fmf);
// Load each input symbol's value once (element i of the first argument) and
// remember it in symbol_ptrs, which bvisit(Symbol) reads.
177 auto input_arg = &(*(F->args().begin()));
178 for (
unsigned i = 0; i < inputs.size(); i++) {
179 if (not is_a<Symbol>(*inputs[i])) {
180 throw SymEngineException(
"Input contains a non-symbol.");
183 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
184 auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
186 result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
187 symbol_ptrs.push_back(result_);
// The second argument is the output array.
190 auto it = F->args().begin();
191 #if (LLVM_VERSION_MAJOR < 5)
192 auto out = &(*(++it));
194 auto out = &(*(it + 1));
// CSE path: bind every replacement symbol to its lowered value in
// replacement_symbol_ptrs (consulted by bvisit(Symbol)) before the reduced
// output expressions are lowered.
199 vec_basic reduced_exprs;
200 vec_pair replacements;
202 SymEngine::cse(replacements, reduced_exprs, outputs);
203 for (
auto &rep : replacements) {
205 replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
208 for (
unsigned i = 0; i < outputs.size(); i++) {
209 output_vals.
push_back(apply(*reduced_exprs[i]));
// Direct path: lower every output expression as-is.
213 for (
unsigned i = 0; i < outputs.size(); i++) {
214 output_vals.
push_back(apply(*outputs[i]));
// Store output i into element i of the second argument.
219 for (
unsigned i = 0; i < outputs.size(); i++) {
221 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
223 = builder->CreateGEP(get_float_type(context.get()), out, index);
224 builder->CreateStore(output_vals[i], ptr);
228 builder->CreateRetVoid();
// Prints verifier diagnostics to stdout; the boolean result is not checked
// on this visible line.
231 llvm::verifyFunction(*F, &llvm::outs());
// New pass manager: build the function simplification pipeline for the
// requested level. 0/1/2 select O0/O1/O2; every other value keeps O3.
// NOTE(review): further down, opt_level is static_cast to CodeGenOptLevel
// without a range check; values above 3 lie outside its enumerators.
244 #if (LLVM_VERSION_MAJOR < 14)
245 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
247 using OptimizationLevel = llvm::OptimizationLevel;
249 llvm::PassBuilder PB;
250 llvm::ModuleAnalysisManager MAM;
251 llvm::CGSCCAnalysisManager CGAM;
252 llvm::FunctionAnalysisManager FAM;
253 llvm::LoopAnalysisManager LAM;
254 PB.registerModuleAnalyses(MAM);
255 PB.registerCGSCCAnalyses(CGAM);
256 PB.registerFunctionAnalyses(FAM);
257 PB.registerLoopAnalyses(LAM);
258 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
259 llvm::FunctionPassManager FPM;
260 OptimizationLevel pb_opt_level{OptimizationLevel::O3};
261 if (opt_level == 0) {
262 pb_opt_level = OptimizationLevel::O0;
263 }
else if (opt_level == 1) {
264 pb_opt_level = OptimizationLevel::O1;
265 }
else if (opt_level == 2) {
266 pb_opt_level = OptimizationLevel::O2;
268 #if (LLVM_VERSION_MAJOR < 6)
269 FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
270 #elif (LLVM_VERSION_MAJOR < 12)
271 FPM = PB.buildFunctionSimplificationPipeline(
272 pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
274 FPM = PB.buildFunctionSimplificationPipeline(
275 pb_opt_level, llvm::ThinOrFullLTOPhase::None);
286 .setEngineKind(llvm::EngineKind::Kind::JIT)
287 .setOptLevel(
static_cast<CodeGenOptLevel
>(opt_level))
// Object cache that copies the compiled object file into `membuffer` when
// MCJIT reports it. NOTE(review): `callback` is a local registered via
// setObjectCache; the engine keeps the raw pointer after init() returns --
// consider executionengine->setObjectCache(nullptr) after finalizeObject().
292 class MemoryBufferRefCallback :
public llvm::ObjectCache
296 MemoryBufferRefCallback(
std::string &ss) : ss_(ss) {}
298 void notifyObjectCompiled(
const llvm::Module *M,
299 llvm::MemoryBufferRef obj)
override
301 const char *c = obj.getBufferStart();
303 ss_.
assign(c, obj.getBufferSize());
307 getObject(
const llvm::Module *M)
override
313 MemoryBufferRefCallback callback(membuffer);
314 executionengine->setObjectCache(&callback);
316 executionengine->finalizeObject();
// Address of the compiled kernel, stored as an integer; the call() methods
// cast it back to a function pointer.
319 func = (intptr_t)executionengine->getPointerToFunction(F);
321 replacement_symbol_ptrs.clear();
// Constructor and destructor of the double-precision visitor, defaulted out
// of line.
325 LLVMDoubleVisitor::LLVMDoubleVisitor() =
default;
326 LLVMDoubleVisitor::~LLVMDoubleVisitor() =
default;
// Fragment of the single-result call(): runs the kernel on vec.data() and
// writes the result into `ret`.
// NOTE(review): the signature and the declaration/return of `ret` (lines
// 328-330, 332-333) are not visible in this listing. The kernel is created
// with a void return type (get_function_type), so this cast should be to
// `void (*)(const double *, double *)`; calling it through a double-returning
// pointer type is undefined behaviour.
331 ((double (*)(
const double *,
double *))func)(vec.
data(), &ret);
335 void LLVMDoubleVisitor::call(
double *outs,
const double *inps)
const
337 ((double (*)(
const double *,
double *))func)(inps, outs);
340 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Fragment of the long-double single-result call(): runs the kernel on
// vec.data(); the remaining argument and the signature are on lines not
// visible in this listing.
// NOTE(review): the kernel returns void (get_function_type), so the cast
// should be to `void (*)(const long double *, long double *)`; calling it
// through a long-double-returning pointer type is undefined behaviour.
345 ((
long double (*)(
const long double *,
long double *))func)(vec.
data(),
350 void LLVMLongDoubleVisitor::call(
long double *outs,
351 const long double *inps)
const
353 ((
long double (*)(
const long double *,
long double *))func)(inps, outs);
// Constructor and destructor of the single-precision visitor, defaulted out
// of line.
357 LLVMFloatVisitor::LLVMFloatVisitor() =
default;
358 LLVMFloatVisitor::~LLVMFloatVisitor() =
default;
// Fragment of the float single-result call(): runs the kernel on vec.data()
// and writes the result into `ret`.
// NOTE(review): the signature and the declaration/return of `ret` are not
// visible in this listing. The kernel returns void (get_function_type), so
// the cast should be to `void (*)(const float *, float *)`; calling it through
// a float-returning pointer type is undefined behaviour.
363 ((float (*)(
const float *,
float *))func)(vec.
data(), &ret);
367 void LLVMFloatVisitor::call(
float *outs,
const float *inps)
const
369 ((float (*)(
const float *,
float *))func)(inps, outs);
// Sets result_ to a floating-point constant of the visitor's float type
// holding `d`.
372 void LLVMVisitor::set_double(
double d)
374 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()), d);
// Integer literal -> FP constant of the visitor's float type. The value goes
// through mp_get_d (presumably a conversion to double), so integers that do
// not fit a double exactly are rounded; the long-double visitor provides its
// own visit(const Integer &).
377 void LLVMVisitor::bvisit(
const Integer &x)
379 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
380 mp_get_d(x.as_integer_class()));
383 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Constructor and destructor of the long-double visitor, defaulted out of
// line.
385 LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() =
default;
386 LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() =
default;
// Numerically evaluates `x` with evalf (precision argument 128, real domain)
// and materializes the result as a constant of the long-double type. Without
// MPFR support it throws NotImplementedError instead.
// NOTE(review): the #else/#endif and the constant's value argument (lines
// 392, 395-397) are not visible in this listing.
388 void LLVMLongDoubleVisitor::convert_from_mpfr(
const Basic &x)
390 #ifndef HAVE_SYMENGINE_MPFR
391 throw NotImplementedError(
"Cannot convert to long double without MPFR");
393 RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
394 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
// Integer literal -> long-double constant; the value argument is on lines
// not visible in this listing.
399 void LLVMLongDoubleVisitor::visit(
const Integer &x)
401 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
// Rational literal -> FP constant, via mp_get_d on the rational value.
406 void LLVMVisitor::bvisit(
const Rational &x)
408 set_double(mp_get_d(x.as_rational_class()));
411 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Long-double rationals are evaluated with MPFR rather than through a
// double, keeping the extra precision.
412 void LLVMLongDoubleVisitor::visit(
const Rational &x)
414 convert_from_mpfr(x);
// RealDouble literal -> FP constant.
// NOTE(review): the body (lines 419-421) is not visible in this listing.
418 void LLVMVisitor::bvisit(
const RealDouble &x)
// Arbitrary-precision (MPFR) literals; only compiled with MPFR support.
423 #ifdef HAVE_SYMENGINE_MPFR
// Rounded to the nearest double (MPFR_RNDN) for the double/float visitors.
424 void LLVMVisitor::bvisit(
const RealMPFR &x)
426 set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
428 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long-double visitor re-evaluates through convert_from_mpfr instead.
429 void LLVMLongDoubleVisitor::visit(
const RealMPFR &x)
431 convert_from_mpfr(x);
// Lowers an Add (constant coefficient + sum of term*coefficient pairs from
// get_dict()) as a left-to-right chain of FAdd. If the constant coefficient
// is zero, the first dictionary term seeds the accumulator so no `0 + ...` is
// emitted; otherwise the coefficient seeds it. Terms whose coefficient is one
// are added without a multiply.
// NOTE(review): advancing `it` past the seeding term happens on lines not
// visible in this listing (449-451) -- confirm that term is not added twice.
436 void LLVMVisitor::bvisit(
const Add &x)
438 llvm::Value *tmp, *tmp1, *tmp2;
439 auto it = x.get_dict().begin();
441 if (
eq(*x.get_coef(), *zero)) {
443 if (
eq(*one, *(it->second))) {
444 tmp = apply(*(it->first));
446 tmp1 = apply(*(it->first));
447 tmp2 = apply(*(it->second));
448 tmp = builder->CreateFMul(tmp1, tmp2);
// Non-zero constant coefficient: start the sum from it.
452 tmp = apply(*x.get_coef());
// Remaining terms: add term (times its coefficient unless that is one).
455 for (; it != x.get_dict().
end(); ++it) {
456 if (
eq(*one, *(it->second))) {
457 tmp1 = apply(*(it->first));
458 tmp = builder->CreateFAdd(tmp, tmp1);
465 tmp1 = apply(*(it->first));
466 tmp2 = apply(*(it->second));
467 tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
// Lowers a Mul as a chain of FMul over x.get_args().
// NOTE(review): lines 478-480 are not visible; presumably the first factor
// initializes `tmp` there, since `tmp` starts as nullptr.
473 void LLVMVisitor::bvisit(
const Mul &x)
475 llvm::Value *tmp =
nullptr;
477 for (
const auto &p : x.get_args()) {
481 tmp = builder->CreateFMul(tmp, apply(*p));
// Returns the declaration of the llvm.powi intrinsic for the visitor's float
// type. From LLVM 13 on (LLVM_VERSION_MAJOR > 12) the integer exponent type
// (i32) is also an overloaded type, so it is appended to arg_type.
// NOTE(review): the construction of arg_type and the remaining call
// arguments are on lines not visible in this listing. Newer LLVM releases
// rename Intrinsic::getDeclaration to getOrInsertDeclaration -- confirm
// against the supported LLVM range.
488 llvm::Function *LLVMVisitor::get_powi()
492 #if (LLVM_VERSION_MAJOR > 12)
493 arg_type.
push_back(llvm::Type::getInt32Ty(
mod->getContext()));
495 return llvm::Intrinsic::getDeclaration(
mod, llvm::Intrinsic::powi,
// Returns the declaration of intrinsic `id` in `mod`, overloaded on the
// types in arg_type -- presumably `n` copies of `type`, built on lines
// 501-502, which are not visible in this listing. Callers pass n = 1 for the
// float intrinsics (exp, sin, pow, minnum, ...).
// NOTE(review): this is a free function with external linkage; if no header
// declares it, consider `static` or an anonymous namespace.
499 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID
id,
500 unsigned n, llvm::Module *
mod)
503 return llvm::Intrinsic::getDeclaration(
mod,
id, arg_type);
// Lowers base**exp, choosing the cheapest visible form:
//   base == E        -> llvm.exp(exp)
//   base == 2        -> llvm.exp2(exp)
//   integer exponent -> base*base (presumably guarded by exp == 2 on hidden
//                       line 522), otherwise an i32 exponent for llvm.powi
//                       (see get_powi)
//   anything else    -> llvm.pow(base, exp)
// The final intrinsic call is marked as a tail call.
// NOTE(review): numeric_cast<int> presumably throws when the integer
// exponent does not fit an int -- confirm.
506 void LLVMVisitor::bvisit(
const Pow &x)
510 if (
eq(*(x.get_base()), *E)) {
512 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
513 llvm::Intrinsic::exp, 1,
mod);
515 }
else if (
eq(*(x.get_base()), *
integer(2))) {
517 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
518 llvm::Intrinsic::exp2, 1,
mod);
521 if (is_a<Integer>(*x.get_exp())) {
523 llvm::Value *tmp = apply(*x.get_base());
524 result_ = builder->CreateFMul(tmp, tmp);
528 int d = numeric_cast<int>(
529 mp_get_si(
static_cast<const Integer &
>(*x.get_exp())
530 .as_integer_class()));
531 result_ = llvm::ConstantInt::get(
532 llvm::Type::getInt32Ty(
mod->getContext()), d,
true);
539 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
540 llvm::Intrinsic::pow, 1,
mod);
543 auto r = builder->CreateCall(fun, args);
544 r->setTailCall(
true);
// sin(x) -> tail call to the llvm.sin intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
548 void LLVMVisitor::bvisit(
const Sin &x)
553 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
554 llvm::Intrinsic::sin, 1,
mod);
555 auto r = builder->CreateCall(fun, args);
556 r->setTailCall(
true);
// cos(x) -> tail call to the llvm.cos intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
560 void LLVMVisitor::bvisit(
const Cos &x)
565 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
566 llvm::Intrinsic::cos, 1,
mod);
567 auto r = builder->CreateCall(fun, args);
568 r->setTailCall(
true);
// Lowers a Piecewise to if/else control flow joined by a PHI node.
// The last branch must be (expr, True), otherwise SymEngineException is
// thrown. A Piecewise with more than two branches is rewritten as
// (first branch, (nested Piecewise of the rest, True)), so only the
// two-branch case is lowered directly; nested pieces are handled recursively
// when the else value is lowered.
572 void LLVMVisitor::bvisit(
const Piecewise &x)
576 RCP<const Piecewise> pw = x.rcp_from_this_cast<
const Piecewise>();
578 if (
neq(*pw->get_vec().back().second, *boolTrue)) {
579 throw SymEngineException(
580 "LLVMDouble requires a (Expr, True) at the end of Piecewise");
// Reduce n > 2 branches to two; fewer than two is malformed.
583 if (pw->get_vec().size() > 2) {
584 PiecewiseVec rest = pw->get_vec();
585 rest.
erase(rest.begin());
586 auto rest_pw = piecewise(
std::move(rest));
588 new_pw.
push_back(*pw->get_vec().begin());
589 new_pw.push_back({rest_pw, pw->get_vec().back().second});
591 ->rcp_from_this_cast<
const Piecewise>();
592 }
else if (pw->get_vec().size() < 2) {
593 throw SymEngineException(
"Invalid Piecewise object");
// Conditions are floats; "true" means ordered-not-equal to 0.0, so a NaN
// condition selects the else branch.
596 auto cond_basic = pw->get_vec().front().second;
597 llvm::Value *cond = apply(*cond_basic);
599 cond = builder->CreateFCmpONE(
600 cond, llvm::ConstantFP::get(get_float_type(&
mod->getContext()), 0.0),
602 llvm::Function *
function = builder->GetInsertBlock()->getParent();
// Only "then" is attached to the function immediately; "else" and "ifcont"
// are appended below.
606 llvm::BasicBlock *then_bb
607 = llvm::BasicBlock::Create(
mod->getContext(),
"then",
function);
608 llvm::BasicBlock *else_bb
609 = llvm::BasicBlock::Create(
mod->getContext(),
"else");
610 llvm::BasicBlock *merge_bb
611 = llvm::BasicBlock::Create(
mod->getContext(),
"ifcont");
612 builder->CreateCondBr(cond, then_bb, else_bb);
615 builder->SetInsertPoint(then_bb);
616 llvm::Value *then_value = apply(*pw->get_vec().front().first);
617 builder->CreateBr(merge_bb);
// Lowering the branch value may have emitted new blocks (e.g. a nested
// Piecewise), so the PHI's incoming edge is the block the builder ended in.
621 then_bb = builder->GetInsertBlock();
// "else" is appended to the function only now, so it follows any blocks
// emitted while lowering the then-value.
624 #if (LLVM_VERSION_MAJOR < 16)
625 function->getBasicBlockList().push_back(else_bb);
627 function->insert(function->end(), else_bb);
629 builder->SetInsertPoint(else_bb);
630 llvm::Value *else_value = apply(*pw->get_vec().back().first);
631 builder->CreateBr(merge_bb);
// Same reasoning as for then_bb above.
635 else_bb = builder->GetInsertBlock();
638 #if (LLVM_VERSION_MAJOR < 16)
639 function->getBasicBlockList().push_back(merge_bb);
641 function->insert(function->end(), merge_bb);
643 builder->SetInsertPoint(merge_bb);
644 llvm::PHINode *phi_node
645 = builder->CreatePHI(get_float_type(&
mod->getContext()), 2);
647 phi_node->addIncoming(then_value, then_bb);
648 phi_node->addIncoming(else_value, else_bb);
// Lowers sign(x) by building Piecewise((0, x == 0), (-1, x < 0), (1, True))
// and (presumably, on hidden lines) lowering that Piecewise.
// NOTE(review): Equality and StrictLessThan lower to ordered compares
// (OEQ/OLT), which are false for NaN, so sign(NaN) evaluates to 1.0 --
// confirm that is intended.
652 void LLVMVisitor::bvisit(
const Sign &x)
654 const auto x2 = x.get_arg();
656 new_pw.
push_back({real_double(0.0),
Eq(x2, real_double(0.0))});
657 new_pw.push_back({real_double(-1.0),
Lt(x2, real_double(0.0))});
658 new_pw.push_back({real_double(1.0), boolTrue});
659 auto pw = rcp_static_cast<const Piecewise>(piecewise(
std::move(new_pw)));
// Lowers Contains(expr, Interval) to 1.0/0.0: each end is checked with an
// ordered compare (strict for an open end, non-strict for a closed end), the
// two i1 results are ANDed and converted back to the float type. Any set
// other than an Interval throws SymEngineException. A NaN expr or bound
// yields 0.0 because the compares are ordered.
663 void LLVMVisitor::bvisit(
const Contains &cts)
665 llvm::Value *expr = apply(*cts.get_expr());
666 const auto set = cts.get_set();
667 if (is_a<Interval>(*set)) {
668 const auto &interv = down_cast<const Interval &>(*set);
669 llvm::Value *start = apply(*interv.get_start());
670 llvm::Value *
end = apply(*interv.get_end());
671 const bool left_open = interv.get_left_open();
672 const bool right_open = interv.get_right_open();
673 llvm::Value *left_ok;
674 llvm::Value *right_ok;
675 left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
676 : builder->CreateFCmpOLE(start, expr);
677 right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
678 : builder->CreateFCmpOLE(expr, end);
679 result_ = builder->CreateAnd(left_ok, right_ok);
680 result_ = builder->CreateUIToFP(result_,
681 get_float_type(&
mod->getContext()));
683 throw SymEngineException(
"LLVMVisitor: only ``Interval`` "
684 "implemented for ``Contains``.");
// Real infinities become -inf / +inf constants of the visitor's float type;
// any other kind of infinity cannot be represented and throws
// SymEngineException.
688 void LLVMVisitor::bvisit(
const Infty &x)
690 if (x.is_negative_infinity()) {
691 result_ = llvm::ConstantFP::getInfinity(
692 get_float_type(&
mod->getContext()),
true);
693 }
else if (x.is_positive_infinity()) {
694 result_ = llvm::ConstantFP::getInfinity(
695 get_float_type(&
mod->getContext()),
false);
697 throw SymEngineException(
698 "LLVMDouble can only represent real valued infinity");
// NaN -> a NaN constant of the visitor's float type (remaining arguments on
// lines not visible in this listing).
702 void LLVMVisitor::bvisit(
const NaN &x)
704 result_ = llvm::ConstantFP::getNaN(get_float_type(&
mod->getContext()),
// Boolean literals are encoded as floats: True -> 1.0, False -> 0.0.
708 void LLVMVisitor::bvisit(
const BooleanAtom &x)
710 const bool val = x.get_val();
711 set_double(val ? 1.0 : 0.0);
// log(x) -> tail call to the llvm.log intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
714 void LLVMVisitor::bvisit(
const Log &x)
719 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
720 llvm::Intrinsic::log, 1,
mod);
721 auto r = builder->CreateCall(fun, args);
722 r->setTailCall(
true);
// Defines LLVMVisitor::bvisit for an n-ary boolean connective `Class`
// (instantiated below for And/Or/Xor with CreateAnd/CreateOr/CreateXor).
// Each argument is turned into an i1 by comparing it ordered-not-equal to
// zero_val (presumably the 0.0 constant set on hidden lines 730-731), so NaN
// arguments count as false. The i1 values are folded left-to-right with
// `method`, and the result is converted back to 0.0/1.0 with UIToFP.
// (Comments cannot be placed inside the macro body because of its
// backslash line continuations.)
726 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
727 void LLVMVisitor::bvisit(const Class &x) \
729 llvm::Value *value = nullptr; \
732 llvm::Value *zero_val = result_; \
733 for (auto &p : x.get_container()) { \
734 tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
735 if (value == nullptr) { \
738 value = builder->method(value, tmp); \
741 result_ = builder->CreateUIToFP(value, \
742 get_float_type(&mod->getContext())); \
745 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
746 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
747 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
749 void LLVMVisitor::bvisit(
const Not &x)
751 builder->CreateNot(apply(*x.get_arg()));
// Defines LLVMVisitor::bvisit for a binary relational `Class`: both sides
// are lowered, compared with the FCmp predicate `method`, and the i1 result
// is converted to 0.0/1.0 with UIToFP. All four instantiations below use
// ordered predicates, so any comparison involving NaN yields 0.0.
// NOTE(review): for Unequality this means `NaN != y` evaluates to 0.0,
// whereas C's `!=` (predicate UNE) gives 1.0 -- confirm which semantics is
// intended.
// (Comments cannot be placed inside the macro body because of its
// backslash line continuations.)
754 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
755 void LLVMVisitor::bvisit(const Class &x) \
757 llvm::Value *left = apply(*x.get_arg1()); \
758 llvm::Value *right = apply(*x.get_arg2()); \
759 result_ = builder->method(left, right); \
760 result_ = builder->CreateUIToFP(result_, \
761 get_float_type(&mod->getContext())); \
764 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
765 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
766 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
767 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
// Macros that lower a SymEngine function `Class` to a tail call of the
// external C math function `ext`, obtained via get_external_function with
// one float-typed parameter per argument. Arguments are lowered with apply().
//  - _SYMENGINE_MACRO_EXTERNAL_FUNCTION defines the LLVMDoubleVisitor
//    overload (calls `ext`) and the LLVMFloatVisitor overload (calls `ext`
//    with an "f" suffix, e.g. tanf).
//  - SYMENGINE_MACRO_EXTERNAL_FUNCTION adds, when long double support is
//    enabled, an LLVMLongDoubleVisitor overload calling the "l"-suffixed
//    variant (e.g. tanl).
// NOTE(review): this relies on the C library providing the f/l variants of
// every function instantiated below.
// (Comments cannot be placed inside the macro bodies because of their
// backslash line continuations.)
769 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
770 void LLVMDoubleVisitor::visit(const Class &x) \
772 vec_basic basic_args = x.get_args(); \
773 llvm::Function *func = get_external_function(#ext, basic_args.size()); \
774 std::vector<llvm::Value *> args; \
775 for (const auto &arg : basic_args) { \
776 args.push_back(apply(*arg)); \
778 auto r = builder->CreateCall(func, args); \
779 r->setTailCall(true); \
782 void LLVMFloatVisitor::visit(const Class &x) \
784 vec_basic basic_args = x.get_args(); \
785 llvm::Function *func = get_external_function(#ext + std::string("f"), \
786 basic_args.size()); \
787 std::vector<llvm::Value *> args; \
788 for (const auto &arg : basic_args) { \
789 args.push_back(apply(*arg)); \
791 auto r = builder->CreateCall(func, args); \
792 r->setTailCall(true); \
796 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
797 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
798 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
799 void LLVMLongDoubleVisitor::visit(const Class &x) \
801 vec_basic basic_args = x.get_args(); \
802 llvm::Function *func = get_external_function(#ext + std::string("l"), \
803 basic_args.size()); \
804 std::vector<llvm::Value *> args; \
805 for (const auto &arg : basic_args) { \
806 args.push_back(apply(*arg)); \
808 auto r = builder->CreateCall(func, args); \
809 r->setTailCall(true); \
813 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
814 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
// Functions without a suitable LLVM intrinsic are lowered to calls into the
// C math library: SymEngine class -> libm function name.
817 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan,
tan)
818 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin,
asin)
819 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos,
acos)
820 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan,
atan)
821 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2,
atan2)
822 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh,
sinh)
823 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh,
cosh)
824 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh,
tanh)
825 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh,
asinh)
826 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh,
acosh)
827 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh,
atanh)
828 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
829 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
830 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf,
erf)
831 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc,
erfc)
// |x| -> tail call to the llvm.fabs intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
833 void LLVMVisitor::bvisit(
const Abs &x)
838 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
839 llvm::Intrinsic::fabs, 1,
mod);
840 auto r = builder->CreateCall(fun, args);
841 r->setTailCall(
true);
// Min(a, b, ...) -> a left fold of tail calls to llvm.minnum over the
// lowered arguments. Per the LLVM LangRef, minnum returns the other operand
// when exactly one operand is NaN.
// NOTE(review): the seeding of `value` with the first argument and the
// result_ assignment are on lines not visible in this listing.
845 void LLVMVisitor::bvisit(
const Min &x)
847 llvm::Value *value =
nullptr;
849 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
850 llvm::Intrinsic::minnum, 1,
mod);
851 for (
auto &arg : x.get_vec()) {
852 if (value !=
nullptr) {
856 auto r = builder->CreateCall(fun, args);
857 r->setTailCall(
true);
// Max(a, b, ...) -> a left fold of tail calls to llvm.maxnum over the
// lowered arguments. Per the LLVM LangRef, maxnum returns the other operand
// when exactly one operand is NaN.
// NOTE(review): the seeding of `value` with the first argument and the
// result_ assignment are on lines not visible in this listing.
866 void LLVMVisitor::bvisit(
const Max &x)
868 llvm::Value *value =
nullptr;
870 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
871 llvm::Intrinsic::maxnum, 1,
mod);
872 for (
auto &arg : x.get_vec()) {
873 if (value !=
nullptr) {
877 auto r = builder->CreateCall(fun, args);
878 r->setTailCall(
true);
// Resolves a symbol to an already-computed value. Input symbols are looked
// up first, by position in `symbols` (result_ = the value loaded in init()
// from symbol_ptrs). Then the CSE replacement symbols in
// replacement_symbol_ptrs are tried. An unknown symbol throws
// SymEngineException.
// NOTE(review): the input lookup is a linear scan for every occurrence of a
// symbol; the index/compare/return lines are not visible in this listing.
887 void LLVMVisitor::bvisit(
const Symbol &x)
890 for (
auto &symb : symbols) {
892 result_ = symbol_ptrs[i];
897 auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
898 if (it != replacement_symbol_ptrs.end()) {
899 result_ = it->second;
903 throw SymEngineException(
"Symbol " + x.__str__()
904 +
" not in the symbols vector.");
// Returns the declaration of the external C function `name`, whose
// parameters are all of the visitor's float type (count given by a
// parameter on a hidden line) and which returns the float type. An existing
// declaration is looked up with mod->getFunction(name). Otherwise one is
// created (presumably only when the lookup returned null -- guard on a
// hidden line) with external linkage, the C calling convention and the
// nounwind attribute.
// NOTE(review): an existing function named `name` is reused as-is, even if
// its signature differs from func_type -- confirm callers always use a
// consistent arity per name.
907 llvm::Function *LLVMVisitor::get_external_function(
const std::string &name,
911 get_float_type(&
mod->getContext()));
912 llvm::FunctionType *func_type = llvm::FunctionType::get(
913 get_float_type(&
mod->getContext()), func_args,
false);
915 llvm::Function *func =
mod->getFunction(name);
917 func = llvm::Function::Create(
918 func_type, llvm::GlobalValue::ExternalLinkage, name,
mod);
919 func->setCallingConv(llvm::CallingConv::C);
// Pre-LLVM-5 attribute API: nounwind is set at the function index (~0U).
921 #if (LLVM_VERSION_MAJOR < 5)
922 llvm::AttributeSet func_attr_set;
924 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
925 llvm::AttributeSet attr_set;
927 llvm::AttrBuilder attr_builder;
928 attr_builder.addAttribute(llvm::Attribute::NoUnwind);
930 = llvm::AttributeSet::get(
mod->getContext(), ~0U, attr_builder);
933 attrs.push_back(attr_set);
934 func_attr_set = llvm::AttributeSet::get(
mod->getContext(), attrs);
936 func->setAttributes(func_attr_set);
938 func->addFnAttr(llvm::Attribute::NoUnwind);
// Named constants (pi, E, ...) are evaluated numerically: to a double via
// eval_double for the double/float visitors...
943 void LLVMVisitor::bvisit(
const Constant &x)
945 set_double(eval_double(x));
948 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// ...and through MPFR for the long-double visitor, keeping its precision.
949 void LLVMLongDoubleVisitor::visit(
const Constant &x)
951 convert_from_mpfr(x);
// Fallback for every expression type without a dedicated overload: throws
// NotImplementedError carrying the expression's string form.
955 void LLVMVisitor::bvisit(
const Basic &x)
957 throw NotImplementedError(x.__str__());
// NOTE(review): fragment of the routine that restores a previously compiled
// kernel from the string `s`; its signature (lines ~958-967) is not visible
// in this listing. It rebuilds the LLVM context and a module holding only
// the kernel declaration (get_function_type). It then installs
// MCJITObjectLoader, whose getObject hands back a copy of the stored object
// bytes, before finalizing, so the engine can use the cached object.
// Code generation is always requested at CodeGenOptLevel::Aggressive here.
968 llvm::InitializeNativeTarget();
969 llvm::InitializeNativeTargetAsmPrinter();
970 llvm::InitializeNativeTargetAsmParser();
971 context = make_unique<llvm::LLVMContext>();
975 = make_unique<llvm::Module>(
"SymEngine", *context);
976 module->setDataLayout(
"");
983 auto F = get_function_type(context.get());
988 .setEngineKind(llvm::EngineKind::Kind::JIT)
989 .setOptLevel(CodeGenOptLevel::Aggressive)
// Object cache that serves the stored object bytes back to MCJIT.
993 class MCJITObjectLoader :
public llvm::ObjectCache
999 void notifyObjectCompiled(
const llvm::Module *M,
1000 llvm::MemoryBufferRef obj)
override
1007 getObject(
const llvm::Module *M)
override
1009 return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
// NOTE(review): `loader` is a local registered via setObjectCache; the
// engine keeps the raw pointer after this function returns -- consider
// executionengine->setObjectCache(nullptr) after finalizeObject().
1013 MCJITObjectLoader loader(s);
1014 executionengine->setObjectCache(&loader);
1015 executionengine->finalizeObject();
1017 func = (intptr_t)executionengine->getPointerToFunction(F);
// floor(x) -> tail call to the llvm.floor intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
1020 void LLVMVisitor::bvisit(
const Floor &x)
1023 llvm::Function *fun;
1025 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1026 llvm::Intrinsic::floor, 1,
mod);
1027 auto r = builder->CreateCall(fun, args);
1028 r->setTailCall(
true);
// ceiling(x) -> tail call to the llvm.ceil intrinsic on the lowered argument
// (argument lowering on lines not visible in this listing).
1032 void LLVMVisitor::bvisit(
const Ceiling &x)
1035 llvm::Function *fun;
1037 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1038 llvm::Intrinsic::ceil, 1,
mod);
1039 auto r = builder->CreateCall(fun, args);
1040 r->setTailCall(
true);
// UnevaluatedExpr is transparent for code generation: the wrapped argument
// is lowered directly.
// NOTE(review): apply()'s return value is discarded, so this relies on
// apply() leaving the argument's value in result_ -- confirm.
1044 void LLVMVisitor::bvisit(
const UnevaluatedExpr &x)
1046 apply(*x.get_arg());
// truncate(x) -> tail call to the llvm.trunc intrinsic (round toward zero)
// on the lowered argument (argument lowering on lines not visible in this
// listing).
1049 void LLVMVisitor::bvisit(
const Truncate &x)
1052 llvm::Function *fun;
1054 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1055 llvm::Intrinsic::trunc, 1,
mod);
1056 auto r = builder->CreateCall(fun, args);
1057 r->setTailCall(
true);
// The LLVM floating-point type each visitor generates code for: IEEE double
// for LLVMDoubleVisitor...
1061 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1063 return llvm::Type::getDoubleTy(*context);
// ...IEEE single for LLVMFloatVisitor...
1066 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1068 return llvm::Type::getFloatTy(*context);
// ...and the x87 80-bit extended type for LLVMLongDoubleVisitor.
// NOTE(review): x86_fp80 matches C `long double` only on x86 targets;
// presumably SYMENGINE_HAVE_LLVM_LONG_DOUBLE is only defined there -- confirm.
1071 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1072 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1074 return llvm::Type::getX86_FP80Ty(*context);
The lowest unit of symbolic representation.
Main namespace for SymEngine package.
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh: