1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
37 #include <symengine/llvm_double.h>
// Version-compatibility shims. Each #if selects one spelling of an LLVM API
// that was renamed or moved between major versions, so the rest of the file
// can use a single name: make_unique, CodeGenOptLevel, GetDeclaration
// (intrinsic declaration lookup) and AddNoCapture (no-capture attribute).
// NOTE(review): AddNoCapture is not referenced in the visible code.
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
47 using llvm::make_unique;
49 #if (LLVM_VERSION_MAJOR < 18)
50 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
52 using CodeGenOptLevel = llvm::CodeGenOptLevel;
// LLVM 20 renamed Intrinsic::getDeclaration to getOrInsertDeclaration.
55 #if (LLVM_VERSION_MAJOR < 20)
56 const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
58 const auto &GetDeclaration = llvm::Intrinsic::getOrInsertDeclaration;
// LLVM 21 expresses "no capture" through CaptureInfo instead of the
// NoCapture enum attribute.
61 #if (LLVM_VERSION_MAJOR >= 21)
62 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
64 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
// Defaulted constructor and destructor.
71 LLVMVisitor::LLVMVisitor() =
default;
72 LLVMVisitor::~LLVMVisitor() =
default;
// Lowers a SymEngine expression to an llvm::Value; the visitor methods below
// call it recursively on sub-expressions and combine the returned values.
// NOTE(review): the body is not visible in this view.
74 llvm::Value *LLVMVisitor::apply(
const Basic &b)
// Single-output convenience overload: forwards to the multi-output init()
// with `b` wrapped as a one-element output vector.
80 void LLVMVisitor::init(
const vec_basic &x,
const Basic &b,
bool symbolic_cse,
83 init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
// Creates (in `mod`) the internal-linkage function "symengine_func" with
// signature void(T *inputs, T *outputs), where T is get_float_type(context),
// using the C calling convention, and returns it.
// Despite its name, it returns the llvm::Function, not an llvm::FunctionType.
// Attributes: parameter 0 is read-only, both parameters are marked
// no-capture, and the function is nounwind with an unwind table.
86 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
88 std::vector<llvm::Type *> inp;
// Two pointer parameters (address space 0): inputs, then outputs.
89 for (
int i = 0; i < 2; i++) {
90 inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
92 llvm::FunctionType *function_type = llvm::FunctionType::get(
93 llvm::Type::getVoidTy(*context), inp,
false);
94 auto F = llvm::Function::Create(
95 function_type, llvm::Function::InternalLinkage,
"symengine_func",
mod);
96 F->setCallingConv(llvm::CallingConv::C);
// Pre-LLVM-5 path: attributes are assembled as AttributeSets keyed by the
// old 1-based parameter index (1U, 2U) or ~0U for function attributes.
97 #if (LLVM_VERSION_MAJOR < 5)
99 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
100 llvm::AttributeSet PAS;
103 B.addAttribute(llvm::Attribute::ReadOnly);
105 PAS = llvm::AttributeSet::get(
mod->getContext(), 1U, B);
108 attrs.push_back(PAS);
112 PAS = llvm::AttributeSet::get(
mod->getContext(), 2U, B);
115 attrs.push_back(PAS);
118 B.addAttribute(llvm::Attribute::NoUnwind);
119 B.addAttribute(llvm::Attribute::UWTable);
120 PAS = llvm::AttributeSet::get(
mod->getContext(), ~0U, B);
123 attrs.push_back(PAS);
125 F->setAttributes(llvm::AttributeSet::get(
mod->getContext(), attrs));
// LLVM >= 5 path: 0-based parameter attributes; the input array is read-only.
128 F->addParamAttr(0, llvm::Attribute::ReadOnly);
129 #if (LLVM_VERSION_MAJOR >= 21)
130 F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
131 *context, llvm::CaptureInfo::none()));
132 F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
133 *context, llvm::CaptureInfo::none()));
135 F->addParamAttr(0, llvm::Attribute::NoCapture);
136 F->addParamAttr(1, llvm::Attribute::NoCapture);
138 F->addFnAttr(llvm::Attribute::NoUnwind);
// LLVM 15 replaced the plain UWTable attribute with a UWTableKind variant.
139 #if (LLVM_VERSION_MAJOR < 15)
140 F->addFnAttr(llvm::Attribute::UWTable);
142 F->addFnAttr(llvm::Attribute::getWithUWTableKind(
143 *context, llvm::UWTableKind::Default));
// Compiles `outputs` (expressions in the Symbols listed in `inputs`) into
// native code:
//   1. Creates a fresh LLVMContext and the "SymEngine" module.
//   2. Emits symengine_func(in, out), which loads in[i] for each input
//      symbol and stores the value of outputs[i] into out[i].
//   3. Optionally applies symbolic common-subexpression elimination.
//   4. Builds a function-simplification pipeline for the requested
//      opt_level and JIT-compiles the module.
// The compiled object is copied into `membuffer` and the entry point is
// stored in `func`.
// Throws SymEngineException if any input is not a Symbol.
149 void LLVMVisitor::init(
const vec_basic &inputs,
const vec_basic &outputs,
150 const bool symbolic_cse,
unsigned opt_level)
// Drop any previously compiled engine before rebuilding.
152 executionengine.reset();
153 llvm::InitializeNativeTarget();
154 llvm::InitializeNativeTargetAsmPrinter();
155 llvm::InitializeNativeTargetAsmParser();
156 context = make_unique<llvm::LLVMContext>();
160 std::unique_ptr<llvm::Module> module
161 = make_unique<llvm::Module>(
"SymEngine", *context.get());
162 module->setDataLayout(
"");
165 auto F = get_function_type(context.get());
170 llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context,
"EntryBlock", F);
// NOTE(review): `builder` is pointed at the stack-local `_builder`, so it
// is only valid until init() returns; it must not be used afterwards.
175 llvm::IRBuilder<> _builder(BB);
176 builder =
reinterpret_cast<IRBuilder *
>(&_builder);
177 builder->SetInsertPoint(BB);
178 auto fmf = llvm::FastMathFlags();
180 builder->setFastMathFlags(fmf);
// First argument: input array. Load in[i] for every input symbol, in order,
// into symbol_ptrs (consumed by bvisit(Symbol)).
183 auto input_arg = &(*(F->args().begin()));
184 for (
unsigned i = 0; i < inputs.size(); i++) {
185 if (not is_a<Symbol>(*inputs[i])) {
186 throw SymEngineException(
"Input contains a non-symbol.");
189 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
190 auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
192 result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
193 symbol_ptrs.push_back(result_);
// Second argument: output array.
196 auto it = F->args().begin();
197 #if (LLVM_VERSION_MAJOR < 5)
198 auto out = &(*(++it));
200 auto out = &(*(it + 1));
202 std::vector<llvm::Value *> output_vals;
// CSE path: lower each replacement once into replacement_symbol_ptrs, then
// lower the reduced outputs, which refer to those replacement symbols.
205 vec_basic reduced_exprs;
206 vec_pair replacements;
208 SymEngine::cse(replacements, reduced_exprs, outputs);
209 for (
auto &rep : replacements) {
211 replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
214 for (
unsigned i = 0; i < outputs.size(); i++) {
215 output_vals.push_back(apply(*reduced_exprs[i]));
// Non-CSE path: lower each output expression directly.
219 for (
unsigned i = 0; i < outputs.size(); i++) {
220 output_vals.push_back(apply(*outputs[i]));
// Store outputs[i] into out[i].
225 for (
unsigned i = 0; i < outputs.size(); i++) {
227 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
229 = builder->CreateGEP(get_float_type(context.get()), out, index);
230 builder->CreateStore(output_vals[i], ptr);
234 builder->CreateRetVoid();
// NOTE(review): verifyFunction's return value (true on broken IR) is
// ignored; problems are only printed to stdout.
237 llvm::verifyFunction(*F, &llvm::outs());
// New pass manager: map opt_level 0/1/2 to O0/O1/O2; anything else uses O3.
250 #if (LLVM_VERSION_MAJOR < 14)
251 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
253 using OptimizationLevel = llvm::OptimizationLevel;
255 llvm::PassBuilder PB;
256 llvm::ModuleAnalysisManager MAM;
257 llvm::CGSCCAnalysisManager CGAM;
258 llvm::FunctionAnalysisManager FAM;
259 llvm::LoopAnalysisManager LAM;
260 PB.registerModuleAnalyses(MAM);
261 PB.registerCGSCCAnalyses(CGAM);
262 PB.registerFunctionAnalyses(FAM);
263 PB.registerLoopAnalyses(LAM);
264 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
265 llvm::FunctionPassManager FPM;
266 OptimizationLevel pb_opt_level{OptimizationLevel::O3};
267 if (opt_level == 0) {
268 pb_opt_level = OptimizationLevel::O0;
269 }
else if (opt_level == 1) {
270 pb_opt_level = OptimizationLevel::O1;
271 }
else if (opt_level == 2) {
272 pb_opt_level = OptimizationLevel::O2;
// The simplification pipeline is only built for non-zero opt levels.
275 if (opt_level != 0) {
276 #if (LLVM_VERSION_MAJOR < 6)
277 FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
278 #elif (LLVM_VERSION_MAJOR < 12)
279 FPM = PB.buildFunctionSimplificationPipeline(
280 pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
282 FPM = PB.buildFunctionSimplificationPipeline(
283 pb_opt_level, llvm::ThinOrFullLTOPhase::None);
// NOTE(review): opt_level is cast straight to the codegen enum; values
// above 3 are not valid enumerators and are not clamped here.
293 executionengine = std::unique_ptr<llvm::ExecutionEngine>(
294 llvm::EngineBuilder(std::move(module))
295 .setEngineKind(llvm::EngineKind::Kind::JIT)
296 .setOptLevel(
static_cast<CodeGenOptLevel
>(opt_level))
// ObjectCache hook that copies the compiled object file into the string it
// was constructed with (`membuffer`), so dumps() can return it later.
301 class MemoryBufferRefCallback :
public llvm::ObjectCache
305 explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
307 void notifyObjectCompiled(
const llvm::Module *M,
308 llvm::MemoryBufferRef obj)
override
310 const char *c = obj.getBufferStart();
312 ss_.assign(c, obj.getBufferSize());
315 std::unique_ptr<llvm::MemoryBuffer>
316 getObject(
const llvm::Module *M)
override
// NOTE(review): `callback` is a local object whose address is given to
// setObjectCache(). The engine presumably keeps that pointer after init()
// returns, so any later compilation through this engine would dangle.
322 MemoryBufferRefCallback callback(membuffer);
323 executionengine->setObjectCache(&callback);
325 executionengine->finalizeObject();
// Entry point of the compiled code, stored as an integer address.
328 func = (intptr_t)executionengine->getPointerToFunction(F);
330 replacement_symbol_ptrs.clear();
// Defaulted constructor and destructor.
334 LLVMDoubleVisitor::LLVMDoubleVisitor() =
default;
335 LLVMDoubleVisitor::~LLVMDoubleVisitor() =
default;
// Single-output evaluation: calls the compiled function with `vec` as the
// input array and a local as the one-element output array.
// The line declaring and returning `ret` is not visible here; presumably it
// returns `ret`.
// NOTE(review): the generated function returns void (see get_function_type),
// but it is called through a pointer whose type returns double. The return
// value is ignored, but calling through a mismatched function type is
// technically undefined behaviour; void (*)(const double *, double *) would
// match exactly.
337 double LLVMDoubleVisitor::call(
const std::vector<double> &vec)
const
340 ((double (*)(
const double *,
double *))func)(vec.data(), &ret);
// Multi-output evaluation: reads inputs from `inps` and writes one value per
// compiled output into `outs`. The argument order to the JIT function is
// (inputs, outputs).
344 void LLVMDoubleVisitor::call(
double *outs,
const double *inps)
const
346 ((double (*)(
const double *,
double *))func)(inps, outs);
// long double variants of call(); only built when LLVM long double support
// is configured. Same calling convention as LLVMDoubleVisitor::call:
// (inputs, outputs).
// NOTE(review): as with the double visitor, the function-pointer type
// declares a long double return while the JIT function returns void.
349 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
351 LLVMLongDoubleVisitor::call(
const std::vector<long double> &vec)
const
354 ((
long double (*)(
const long double *,
long double *))func)(vec.data(),
// Multi-output variant: inputs from `inps`, outputs written to `outs`.
359 void LLVMLongDoubleVisitor::call(
long double *outs,
360 const long double *inps)
const
362 ((
long double (*)(
const long double *,
long double *))func)(inps, outs);
// Defaulted constructor and destructor.
366 LLVMFloatVisitor::LLVMFloatVisitor() =
default;
367 LLVMFloatVisitor::~LLVMFloatVisitor() =
default;
// Single-output single-precision evaluation; writes the result into a local
// `ret` (presumably returned; that line is not visible).
// NOTE(review): the pointer type declares a float return while the JIT
// function returns void.
369 float LLVMFloatVisitor::call(
const std::vector<float> &vec)
const
372 ((float (*)(
const float *,
float *))func)(vec.data(), &ret);
// Multi-output variant: inputs from `inps`, outputs written to `outs`.
376 void LLVMFloatVisitor::call(
float *outs,
const float *inps)
const
378 ((float (*)(
const float *,
float *))func)(inps, outs);
// Sets result_ to a floating-point constant of the visitor's float type.
381 void LLVMVisitor::set_double(
double d)
383 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()), d);
// Integers are converted through double (mp_get_d), so large values may be
// rounded before becoming a constant.
386 void LLVMVisitor::bvisit(
const Integer &x)
388 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
389 mp_get_d(x.as_integer_class()));
// long double visitor: numeric constants are evaluated with MPFR where
// possible so that precision beyond double is kept.
392 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
394 LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() =
default;
395 LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() =
default;
// Evaluates `x` numerically at 128 bits in the real domain and stores the
// result in result_. Without MPFR it throws NotImplementedError.
397 void LLVMLongDoubleVisitor::convert_from_mpfr(
const Basic &x)
399 #ifndef HAVE_SYMENGINE_MPFR
400 throw NotImplementedError(
"Cannot convert to long double without MPFR");
402 RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
403 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
// Integer constant for the long double visitor (the remainder of the body is
// not visible here).
408 void LLVMLongDoubleVisitor::visit(
const Integer &x)
410 result_ = llvm::ConstantFP::get(get_float_type(&
mod->getContext()),
// Rationals are lowered as their nearest double.
415 void LLVMVisitor::bvisit(
const Rational &x)
417 set_double(mp_get_d(x.as_rational_class()));
// The long double visitor evaluates rationals through MPFR instead.
420 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
421 void LLVMLongDoubleVisitor::visit(
const Rational &x)
423 convert_from_mpfr(x);
// RealDouble constant (body not visible here).
427 void LLVMVisitor::bvisit(
const RealDouble &x)
// Arbitrary-precision reals: rounded to the nearest double (MPFR_RNDN); the
// long double visitor keeps more precision via convert_from_mpfr.
432 #ifdef HAVE_SYMENGINE_MPFR
433 void LLVMVisitor::bvisit(
const RealMPFR &x)
435 set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
437 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
438 void LLVMLongDoubleVisitor::visit(
const RealMPFR &x)
440 convert_from_mpfr(x);
// Lowers coef + sum(term_i * factor_i) as a left-to-right chain of FAdd.
// In the dict, it->first is the term and it->second its numeric coefficient.
// A zero constant term is skipped by starting from the first dict entry, and
// a coefficient of one skips the FMul.
445 void LLVMVisitor::bvisit(
const Add &x)
447 llvm::Value *tmp, *tmp1, *tmp2;
448 auto it = x.get_dict().begin();
450 if (
eq(*x.get_coef(), *zero)) {
452 if (
eq(*one, *(it->second))) {
453 tmp = apply(*(it->first));
455 tmp1 = apply(*(it->first));
456 tmp2 = apply(*(it->second));
457 tmp = builder->CreateFMul(tmp1, tmp2);
461 tmp = apply(*x.get_coef());
// Accumulate the remaining terms.
464 for (; it != x.get_dict().end(); ++it) {
465 if (
eq(*one, *(it->second))) {
466 tmp1 = apply(*(it->first));
467 tmp = builder->CreateFAdd(tmp, tmp1);
474 tmp1 = apply(*(it->first));
475 tmp2 = apply(*(it->second));
476 tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
// Lowers a product as a left fold of FMul over x.get_args(). The
// initialization of `tmp` from the first argument is not visible here.
482 void LLVMVisitor::bvisit(
const Mul &x)
484 llvm::Value *tmp =
nullptr;
486 for (
const auto &p : x.get_args()) {
490 tmp = builder->CreateFMul(tmp, apply(*p));
// Declares (or fetches) the llvm.powi intrinsic for the visitor's float type.
// From LLVM 13 onward powi is also overloaded on the integer exponent type,
// so i32 is added to the overload list.
497 llvm::Function *LLVMVisitor::get_powi()
499 std::vector<llvm::Type *> arg_type;
500 arg_type.push_back(get_float_type(&
mod->getContext()));
501 #if (LLVM_VERSION_MAJOR > 12)
502 arg_type.push_back(llvm::Type::getInt32Ty(
mod->getContext()));
504 return GetDeclaration(
mod, llvm::Intrinsic::powi, arg_type);
// Declares (or fetches) intrinsic `id` in `mod`, overloaded on `n` copies of
// `type`. `n` is the number of overload types, not the number of call
// arguments (e.g. llvm.pow takes two arguments but is passed n = 1).
507 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID
id,
508 unsigned n, llvm::Module *
mod)
510 std::vector<llvm::Type *> arg_type(n, type);
511 return GetDeclaration(
mod,
id, arg_type);
// Lowers base**exp:
//   E**y            -> llvm.exp(y)
//   2**y            -> llvm.exp2(y)
//   integer exponent -> either b*b (the squaring branch; its guarding
//                       condition is not visible here) or llvm.powi(b, i32)
//   otherwise       -> llvm.pow(b, y)
// The resulting call is marked as a tail call.
514 void LLVMVisitor::bvisit(
const Pow &x)
516 std::vector<llvm::Value *> args;
518 if (
eq(*(x.get_base()), *E)) {
519 args.push_back(apply(*x.get_exp()));
520 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
521 llvm::Intrinsic::exp, 1,
mod);
523 }
else if (
eq(*(x.get_base()), *
integer(2))) {
524 args.push_back(apply(*x.get_exp()));
525 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
526 llvm::Intrinsic::exp2, 1,
mod);
529 if (is_a<Integer>(*x.get_exp())) {
531 llvm::Value *tmp = apply(*x.get_base());
532 result_ = builder->CreateFMul(tmp, tmp);
// Integer exponent: passed to powi as a signed i32 constant.
// NOTE(review): numeric_cast<int> guards the narrowing from the
// arbitrary-precision exponent; confirm its failure mode.
535 args.push_back(apply(*x.get_base()));
536 int d = numeric_cast<int>(
537 mp_get_si(
static_cast<const Integer &
>(*x.get_exp())
538 .as_integer_class()));
539 result_ = llvm::ConstantInt::get(
540 llvm::Type::getInt32Ty(
mod->getContext()), d,
true);
541 args.push_back(result_);
// General case: llvm.pow on the float type (one overload type, two args).
545 args.push_back(apply(*x.get_base()));
546 args.push_back(apply(*x.get_exp()));
547 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
548 llvm::Intrinsic::pow, 1,
mod);
551 auto r = builder->CreateCall(fun, args);
552 r->setTailCall(
true);
// sin(x) lowered to a tail call of the llvm.sin intrinsic.
556 void LLVMVisitor::bvisit(
const Sin &x)
558 std::vector<llvm::Value *> args;
560 args.push_back(apply(*x.get_arg()));
561 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
562 llvm::Intrinsic::sin, 1,
mod);
563 auto r = builder->CreateCall(fun, args);
564 r->setTailCall(
true);
// cos(x) lowered to a tail call of the llvm.cos intrinsic.
568 void LLVMVisitor::bvisit(
const Cos &x)
570 std::vector<llvm::Value *> args;
572 args.push_back(apply(*x.get_arg()));
573 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
574 llvm::Intrinsic::cos, 1,
mod);
575 auto r = builder->CreateCall(fun, args);
576 r->setTailCall(
true);
// Lowers Piecewise((e1, c1), ..., (en, True)) to if/else control flow joined
// by a PHI node. Pieces beyond two are folded recursively:
// ((e1, c1), (Piecewise(rest), True)).
// A condition is "true" when it compares ordered-not-equal to 0.0.
// Throws SymEngineException if the last condition is not True or fewer than
// two pieces remain.
580 void LLVMVisitor::bvisit(
const Piecewise &x)
// NOTE(review): `blocks` is not used in the visible code.
582 std::vector<llvm::BasicBlock> blocks;
584 RCP<const Piecewise> pw = x.rcp_from_this_cast<
const Piecewise>();
586 if (
neq(*pw->get_vec().back().second, *boolTrue)) {
587 throw SymEngineException(
588 "LLVMDouble requires a (Expr, True) at the end of Piecewise");
// Reduce to exactly two pieces by nesting the tail in a new Piecewise.
591 if (pw->get_vec().size() > 2) {
592 PiecewiseVec rest = pw->get_vec();
593 rest.erase(rest.begin());
594 auto rest_pw = piecewise(std::move(rest));
596 new_pw.push_back(*pw->get_vec().begin());
597 new_pw.push_back({rest_pw, pw->get_vec().back().second});
598 pw = piecewise(std::move(new_pw))
599 ->rcp_from_this_cast<
const Piecewise>();
600 }
else if (pw->get_vec().size() < 2) {
601 throw SymEngineException(
"Invalid Piecewise object");
604 auto cond_basic = pw->get_vec().front().second;
605 llvm::Value *cond = apply(*cond_basic);
// Truthiness test: cond != 0.0 (ordered, so NaN is treated as false).
607 cond = builder->CreateFCmpONE(
608 cond, llvm::ConstantFP::get(get_float_type(&
mod->getContext()), 0.0),
610 llvm::Function *
function = builder->GetInsertBlock()->getParent();
// The "then" block is appended now; "else" and "ifcont" are created detached
// and appended after the preceding branch has been emitted.
614 llvm::BasicBlock *then_bb
615 = llvm::BasicBlock::Create(
mod->getContext(),
"then",
function);
616 llvm::BasicBlock *else_bb
617 = llvm::BasicBlock::Create(
mod->getContext(),
"else");
618 llvm::BasicBlock *merge_bb
619 = llvm::BasicBlock::Create(
mod->getContext(),
"ifcont");
620 builder->CreateCondBr(cond, then_bb, else_bb);
623 builder->SetInsertPoint(then_bb);
624 llvm::Value *then_value = apply(*pw->get_vec().front().first);
625 builder->CreateBr(merge_bb);
// Lowering then_value may have created new blocks (nested Piecewise), so the
// PHI must use the block the builder ended in.
629 then_bb = builder->GetInsertBlock();
632 #if (LLVM_VERSION_MAJOR < 16)
633 function->getBasicBlockList().push_back(else_bb);
635 function->insert(function->end(), else_bb);
637 builder->SetInsertPoint(else_bb);
638 llvm::Value *else_value = apply(*pw->get_vec().back().first);
639 builder->CreateBr(merge_bb);
// Same reasoning as for then_bb.
643 else_bb = builder->GetInsertBlock();
646 #if (LLVM_VERSION_MAJOR < 16)
647 function->getBasicBlockList().push_back(merge_bb);
649 function->insert(function->end(), merge_bb);
651 builder->SetInsertPoint(merge_bb);
652 llvm::PHINode *phi_node
653 = builder->CreatePHI(get_float_type(&
mod->getContext()), 2);
655 phi_node->addIncoming(then_value, then_bb);
656 phi_node->addIncoming(else_value, else_bb);
// sign(x) expressed as Piecewise((0, x == 0), (-1, x < 0), (1, True)) and
// lowered through the Piecewise visitor.
660 void LLVMVisitor::bvisit(
const Sign &x)
662 const auto x2 = x.get_arg();
664 new_pw.push_back({real_double(0.0),
Eq(x2, real_double(0.0))});
665 new_pw.push_back({real_double(-1.0),
Lt(x2, real_double(0.0))});
666 new_pw.push_back({real_double(1.0), boolTrue});
667 auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
// Contains(expr, Interval) lowered to 1.0/0.0. Open ends use <, closed ends
// use <=. All comparisons are ordered, so a NaN operand gives 0.0.
// Throws SymEngineException for any set other than Interval.
671 void LLVMVisitor::bvisit(
const Contains &cts)
673 llvm::Value *expr = apply(*cts.get_expr());
674 const auto set = cts.get_set();
675 if (is_a<Interval>(*set)) {
676 const auto &interv = down_cast<const Interval &>(*set);
677 llvm::Value *start = apply(*interv.get_start());
678 llvm::Value *end = apply(*interv.get_end());
679 const bool left_open = interv.get_left_open();
680 const bool right_open = interv.get_right_open();
681 llvm::Value *left_ok;
682 llvm::Value *right_ok;
683 left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
684 : builder->CreateFCmpOLE(start, expr);
685 right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
686 : builder->CreateFCmpOLE(expr, end);
687 result_ = builder->CreateAnd(left_ok, right_ok);
// i1 -> 1.0 / 0.0 in the visitor's float type.
688 result_ = builder->CreateUIToFP(result_,
689 get_float_type(&
mod->getContext()));
691 throw SymEngineException(
"LLVMVisitor: only ``Interval`` "
692 "implemented for ``Contains``.");
// +/-oo become IEEE infinities. Any other infinity (e.g. complex) throws
// SymEngineException.
696 void LLVMVisitor::bvisit(
const Infty &x)
698 if (x.is_negative_infinity()) {
699 result_ = llvm::ConstantFP::getInfinity(
700 get_float_type(&
mod->getContext()),
true);
701 }
else if (x.is_positive_infinity()) {
702 result_ = llvm::ConstantFP::getInfinity(
703 get_float_type(&
mod->getContext()),
false);
705 throw SymEngineException(
706 "LLVMDouble can only represent real valued infinity");
// NaN becomes an IEEE quiet-NaN constant of the float type.
710 void LLVMVisitor::bvisit(
const NaN &x)
712 result_ = llvm::ConstantFP::getNaN(get_float_type(&
mod->getContext()),
// Booleans are represented numerically: True -> 1.0, False -> 0.0.
716 void LLVMVisitor::bvisit(
const BooleanAtom &x)
718 const bool val = x.get_val();
719 set_double(val ? 1.0 : 0.0);
// log(x) lowered to a tail call of the llvm.log intrinsic (natural log).
722 void LLVMVisitor::bvisit(
const Log &x)
724 std::vector<llvm::Value *> args;
726 args.push_back(apply(*x.get_arg()));
727 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
728 llvm::Intrinsic::log, 1,
mod);
729 auto r = builder->CreateCall(fun, args);
730 r->setTailCall(
true);
// Defines bvisit() for an n-ary boolean connective (And/Or/Xor):
//   - each argument is tested for truthiness with an ordered `!= zero_val`
//     comparison, so NaN counts as false;
//   - the i1 results are folded with IRBuilder::`method`;
//   - the final i1 is converted back to 1.0/0.0.
734 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
735 void LLVMVisitor::bvisit(const Class &x) \
737 llvm::Value *value = nullptr; \
740 llvm::Value *zero_val = result_; \
741 for (auto &p : x.get_container()) { \
742 tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
743 if (value == nullptr) { \
746 value = builder->method(value, tmp); \
749 result_ = builder->CreateUIToFP(value, \
750 get_float_type(&mod->getContext())); \
753 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
754 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
755 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
757 void LLVMVisitor::bvisit(
const Not &x)
759 builder->CreateNot(apply(*x.get_arg()));
// Defines bvisit() for a binary relational: compares arg1 against arg2 with
// the given ordered fcmp predicate (false if either side is NaN) and
// converts the i1 result to 1.0/0.0.
762 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
763 void LLVMVisitor::bvisit(const Class &x) \
765 llvm::Value *left = apply(*x.get_arg1()); \
766 llvm::Value *right = apply(*x.get_arg2()); \
767 result_ = builder->method(left, right); \
768 result_ = builder->CreateUIToFP(result_, \
769 get_float_type(&mod->getContext())); \
772 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
773 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
774 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
775 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
// Defines visit() overloads that lower `Class` to a tail call of an external
// C math function, one per precision:
//   - LLVMDoubleVisitor calls `ext`;
//   - LLVMFloatVisitor calls `ext` + "f";
//   - LLVMLongDoubleVisitor calls `ext` + "l" (only when long double support
//     is enabled).
// Every argument of the expression is passed through in order.
777 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
778 void LLVMDoubleVisitor::visit(const Class &x) \
780 vec_basic basic_args = x.get_args(); \
781 llvm::Function *func = get_external_function(#ext, basic_args.size()); \
782 std::vector<llvm::Value *> args; \
783 for (const auto &arg : basic_args) { \
784 args.push_back(apply(*arg)); \
786 auto r = builder->CreateCall(func, args); \
787 r->setTailCall(true); \
790 void LLVMFloatVisitor::visit(const Class &x) \
792 vec_basic basic_args = x.get_args(); \
793 llvm::Function *func = get_external_function(#ext + std::string("f"), \
794 basic_args.size()); \
795 std::vector<llvm::Value *> args; \
796 for (const auto &arg : basic_args) { \
797 args.push_back(apply(*arg)); \
799 auto r = builder->CreateCall(func, args); \
800 r->setTailCall(true); \
804 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
805 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
806 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
807 void LLVMLongDoubleVisitor::visit(const Class &x) \
809 vec_basic basic_args = x.get_args(); \
810 llvm::Function *func = get_external_function(#ext + std::string("l"), \
811 basic_args.size()); \
812 std::vector<llvm::Value *> args; \
813 for (const auto &arg : basic_args) { \
814 args.push_back(apply(*arg)); \
816 auto r = builder->CreateCall(func, args); \
817 r->setTailCall(true); \
821 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
822 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
// Functions without an LLVM intrinsic are lowered to calls into the C math
// library (tan, asin, ..., erfc; Gamma -> tgamma, LogGamma -> lgamma).
825 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan,
tan)
826 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin,
asin)
827 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos,
acos)
828 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan,
atan)
829 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2,
atan2)
830 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh,
sinh)
831 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh,
cosh)
832 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh,
tanh)
833 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh,
asinh)
834 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh,
acosh)
835 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh,
atanh)
836 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
837 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
838 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf,
erf)
839 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc,
erfc)
// |x| lowered to a tail call of the llvm.fabs intrinsic.
841 void LLVMVisitor::bvisit(
const Abs &x)
843 std::vector<llvm::Value *> args;
845 args.push_back(apply(*x.get_arg()));
846 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
847 llvm::Intrinsic::fabs, 1,
mod);
848 auto r = builder->CreateCall(fun, args);
849 r->setTailCall(
true);
// Min over all arguments as a left fold of llvm.minnum. Per LangRef, minnum
// returns the non-NaN operand when exactly one operand is NaN. The branch
// that seeds `value` from the first argument is not visible here.
853 void LLVMVisitor::bvisit(
const Min &x)
855 llvm::Value *value =
nullptr;
857 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
858 llvm::Intrinsic::minnum, 1,
mod);
859 for (
auto &arg : x.get_vec()) {
860 if (value !=
nullptr) {
861 std::vector<llvm::Value *> args;
862 args.push_back(value);
863 args.push_back(apply(*arg));
864 auto r = builder->CreateCall(fun, args);
865 r->setTailCall(
true);
// Max over all arguments as a left fold of llvm.maxnum; same structure as Min.
874 void LLVMVisitor::bvisit(
const Max &x)
876 llvm::Value *value =
nullptr;
878 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
879 llvm::Intrinsic::maxnum, 1,
mod);
880 for (
auto &arg : x.get_vec()) {
881 if (value !=
nullptr) {
882 std::vector<llvm::Value *> args;
883 args.push_back(value);
884 args.push_back(apply(*arg));
885 auto r = builder->CreateCall(fun, args);
886 r->setTailCall(
true);
// Resolves a Symbol to an already-emitted value. It first looks in the input
// symbols (symbol_ptrs, indexed in parallel with `symbols`), then in the CSE
// replacement symbols. Otherwise it throws SymEngineException.
// The matching condition and index bookkeeping in the loop are not visible
// here.
895 void LLVMVisitor::bvisit(
const Symbol &x)
898 for (
auto &symb : symbols) {
900 result_ = symbol_ptrs[i];
905 auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
906 if (it != replacement_symbol_ptrs.end()) {
907 result_ = it->second;
911 throw SymEngineException(
"Symbol " + x.__str__()
912 +
" not in the symbols vector.");
// Returns the external function `name` with signature T(T, ..., T) (nargs
// parameters, T = the visitor's float type). An existing declaration in
// `mod` is reused; otherwise a C-calling-convention declaration with
// external linkage and the nounwind attribute is created.
915 llvm::Function *LLVMVisitor::get_external_function(
const std::string &name,
918 std::vector<llvm::Type *> func_args(nargs,
919 get_float_type(&
mod->getContext()));
920 llvm::FunctionType *func_type = llvm::FunctionType::get(
921 get_float_type(&
mod->getContext()), func_args,
false);
923 llvm::Function *func =
mod->getFunction(name);
925 func = llvm::Function::Create(
926 func_type, llvm::GlobalValue::ExternalLinkage, name,
mod);
927 func->setCallingConv(llvm::CallingConv::C);
// Pre-LLVM-5: function attributes are attached via an AttributeSet at ~0U.
929 #if (LLVM_VERSION_MAJOR < 5)
930 llvm::AttributeSet func_attr_set;
932 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
933 llvm::AttributeSet attr_set;
935 llvm::AttrBuilder attr_builder;
936 attr_builder.addAttribute(llvm::Attribute::NoUnwind);
938 = llvm::AttributeSet::get(
mod->getContext(), ~0U, attr_builder);
941 attrs.push_back(attr_set);
942 func_attr_set = llvm::AttributeSet::get(
mod->getContext(), attrs);
944 func->setAttributes(func_attr_set);
946 func->addFnAttr(llvm::Attribute::NoUnwind);
// Named constants (pi, E, ...) are evaluated to double at compile time.
951 void LLVMVisitor::bvisit(
const Constant &x)
953 set_double(eval_double(x));
// The long double visitor evaluates them through MPFR for extra precision.
956 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
957 void LLVMLongDoubleVisitor::visit(
const Constant &x)
959 convert_from_mpfr(x);
// Fallback for any expression type without a dedicated lowering.
963 void LLVMVisitor::bvisit(
const Basic &x)
965 throw NotImplementedError(x.__str__());
// Returns the serialized object code; the body is not visible here, but
// init() captures the compiled object into `membuffer`.
968 const std::string &LLVMVisitor::dumps()
const
// Restores a visitor from object code produced by dumps(). It rebuilds a
// module containing only the symengine_func declaration. An ObjectCache then
// hands MCJIT a copy of `s` instead of compiling, and `func` is set to the
// loaded entry point.
973 void LLVMVisitor::loads(
const std::string &s)
976 llvm::InitializeNativeTarget();
977 llvm::InitializeNativeTargetAsmPrinter();
978 llvm::InitializeNativeTargetAsmParser();
979 context = make_unique<llvm::LLVMContext>();
982 std::unique_ptr<llvm::Module> module
983 = make_unique<llvm::Module>(
"SymEngine", *context);
984 module->setDataLayout(
"");
991 auto F = get_function_type(context.get());
994 executionengine = std::unique_ptr<llvm::ExecutionEngine>(
995 llvm::EngineBuilder(std::move(module))
996 .setEngineKind(llvm::EngineKind::Kind::JIT)
997 .setOptLevel(CodeGenOptLevel::Aggressive)
// ObjectCache that always returns a copy of the serialized object `s`.
1001 class MCJITObjectLoader :
public llvm::ObjectCache
1003 const std::string &s_;
1006 MCJITObjectLoader(
const std::string &s) : s_(s) {}
1007 void notifyObjectCompiled(
const llvm::Module *M,
1008 llvm::MemoryBufferRef obj)
override
1014 std::unique_ptr<llvm::MemoryBuffer>
1015 getObject(
const llvm::Module *M)
override
1017 return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
// NOTE(review): `loader` is a local object (holding a reference to `s`)
// registered with the engine by address; the engine presumably keeps the
// pointer after loads() returns.
1021 MCJITObjectLoader loader(s);
1022 executionengine->setObjectCache(&loader);
1023 executionengine->finalizeObject();
1025 func = (intptr_t)executionengine->getPointerToFunction(F);
// floor(x) lowered to a tail call of the llvm.floor intrinsic.
1028 void LLVMVisitor::bvisit(
const Floor &x)
1030 std::vector<llvm::Value *> args;
1031 llvm::Function *fun;
1032 args.push_back(apply(*x.get_arg()));
1033 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1034 llvm::Intrinsic::floor, 1,
mod);
1035 auto r = builder->CreateCall(fun, args);
1036 r->setTailCall(
true);
// ceiling(x) lowered to a tail call of the llvm.ceil intrinsic.
1040 void LLVMVisitor::bvisit(
const Ceiling &x)
1042 std::vector<llvm::Value *> args;
1043 llvm::Function *fun;
1044 args.push_back(apply(*x.get_arg()));
1045 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1046 llvm::Intrinsic::ceil, 1,
mod);
1047 auto r = builder->CreateCall(fun, args);
1048 r->setTailCall(
true);
// An UnevaluatedExpr is lowered as its wrapped argument.
1052 void LLVMVisitor::bvisit(
const UnevaluatedExpr &x)
1054 apply(*x.get_arg());
// truncate(x) (round toward zero) lowered to a tail call of llvm.trunc.
1057 void LLVMVisitor::bvisit(
const Truncate &x)
1059 std::vector<llvm::Value *> args;
1060 llvm::Function *fun;
1061 args.push_back(apply(*x.get_arg()));
1062 fun = get_float_intrinsic(get_float_type(&
mod->getContext()),
1063 llvm::Intrinsic::trunc, 1,
mod);
1064 auto r = builder->CreateCall(fun, args);
1065 r->setTailCall(
true);
// Each concrete visitor chooses the LLVM floating-point type used for every
// value it emits: double for LLVMDoubleVisitor.
1069 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1071 return llvm::Type::getDoubleTy(*context);
// float for LLVMFloatVisitor.
1074 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1076 return llvm::Type::getFloatTy(*context);
// x86 80-bit extended precision for LLVMLongDoubleVisitor.
// NOTE(review): this assumes `long double` is x86_fp80 on the target.
1079 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1080 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1082 return llvm::Type::getX86_FP80Ty(*context);
The lowest unit of symbolic representation.
Main namespace for SymEngine package.
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh: