Loading...
Searching...
No Matches
llvm_double.cpp
1#include "llvm/ADT/STLExtras.h"
2#include "llvm/Analysis/Passes.h"
3#include "llvm/ExecutionEngine/ExecutionEngine.h"
4#include "llvm/ExecutionEngine/GenericValue.h"
5#include "llvm/ExecutionEngine/MCJIT.h"
6#include "llvm/IR/Argument.h"
7#include "llvm/IR/Attributes.h"
8#include "llvm/IR/BasicBlock.h"
9#include "llvm/IR/Constants.h"
10#include "llvm/IR/DerivedTypes.h"
11#include "llvm/IR/Function.h"
12#include "llvm/IR/IRBuilder.h"
13#include "llvm/IR/Instructions.h"
14#include "llvm/IR/Intrinsics.h"
15#include "llvm/IR/LegacyPassManager.h"
16#include "llvm/IR/LLVMContext.h"
17#include "llvm/IR/Module.h"
18#include "llvm/IR/Type.h"
19#include "llvm/Support/Casting.h"
20#include "llvm/Support/ManagedStatic.h"
21#include "llvm/Support/TargetSelect.h"
22#include "llvm/Support/raw_ostream.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/IR/Verifier.h"
26#include "llvm/Support/TargetSelect.h"
27#include "llvm/Target/TargetMachine.h"
28#include "llvm/Transforms/Scalar.h"
29#include "llvm/Transforms/Vectorize.h"
30#include "llvm/ExecutionEngine/ObjectCache.h"
31#include "llvm/Support/FileSystem.h"
32#include "llvm/Support/Path.h"
33#include <algorithm>
34#include <cassert>
35#include <memory>
36#include <vector>
37#include <fstream>
38
39#if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9) \
40 || (LLVM_VERSION_MAJOR > 3)
41#include <llvm/Transforms/Scalar/GVN.h>
42#endif
43
44#if (LLVM_VERSION_MAJOR >= 7)
45#include <llvm/Transforms/InstCombine/InstCombine.h>
46#include <llvm/Transforms/Scalar/InstSimplifyPass.h>
47#include <llvm/Transforms/Utils.h>
48#endif
49
50#include <symengine/llvm_double.h>
52#include <symengine/eval.h>
53
54namespace SymEngine
55{
56
57#if (LLVM_VERSION_MAJOR >= 10)
58using std::make_unique;
59#else
60using llvm::make_unique;
61#endif
62
63class IRBuilder : public llvm::IRBuilder<>
64{
65};
66
67llvm::Value *LLVMVisitor::apply(const Basic &b)
68{
69 b.accept(*this);
70 return result_;
71}
72
73void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
74 unsigned opt_level)
75{
76 init(x, b, symbolic_cse, LLVMVisitor::create_default_passes(opt_level),
77 opt_level);
78}
79
80void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
81 const std::vector<llvm::Pass *> &passes,
82 unsigned opt_level)
83{
84 init(x, {b.rcp_from_this()}, symbolic_cse, passes, opt_level);
85}
86
87llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
88{
90 for (int i = 0; i < 2; i++) {
91 inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
92 }
93 llvm::FunctionType *function_type = llvm::FunctionType::get(
94 llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
95 auto F = llvm::Function::Create(function_type,
96 llvm::Function::InternalLinkage, "", mod);
97 F->setCallingConv(llvm::CallingConv::C);
98#if (LLVM_VERSION_MAJOR < 5)
99 {
100 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
101 llvm::AttributeSet PAS;
102 {
103 llvm::AttrBuilder B;
104 B.addAttribute(llvm::Attribute::ReadOnly);
105 B.addAttribute(llvm::Attribute::NoCapture);
106 PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
107 }
108
109 attrs.push_back(PAS);
110 {
111 llvm::AttrBuilder B;
112 B.addAttribute(llvm::Attribute::NoCapture);
113 PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
114 }
115
116 attrs.push_back(PAS);
117 {
118 llvm::AttrBuilder B;
119 B.addAttribute(llvm::Attribute::NoUnwind);
120 B.addAttribute(llvm::Attribute::UWTable);
121 PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
122 }
123
124 attrs.push_back(PAS);
125
126 F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
127 }
128#else
129 F->addParamAttr(0, llvm::Attribute::ReadOnly);
130 F->addParamAttr(0, llvm::Attribute::NoCapture);
131 F->addParamAttr(1, llvm::Attribute::NoCapture);
132 F->addFnAttr(llvm::Attribute::NoUnwind);
133#if (LLVM_VERSION_MAJOR < 15)
134 F->addFnAttr(llvm::Attribute::UWTable);
135#else
136 F->addFnAttr(llvm::Attribute::getWithUWTableKind(
137 *context, llvm::UWTableKind::Default));
138#endif
139#endif
140 return F;
141}
142
143std::vector<llvm::Pass *> LLVMVisitor::create_default_passes(int optlevel)
144{
146 if (optlevel == 0) {
147 return passes;
148 }
149#if (LLVM_VERSION_MAJOR < 4)
150 passes.push_back(llvm::createInstructionCombiningPass());
151#else
152 passes.push_back(llvm::createInstructionCombiningPass(optlevel > 1));
153#endif
154 passes.push_back(llvm::createDeadCodeEliminationPass());
155 passes.push_back(llvm::createPromoteMemoryToRegisterPass());
156 passes.push_back(llvm::createReassociatePass());
157 passes.push_back(llvm::createGVNPass());
158 passes.push_back(llvm::createCFGSimplificationPass());
159 passes.push_back(llvm::createPartiallyInlineLibCallsPass());
160#if (LLVM_VERSION_MAJOR < 5)
161 passes.push_back(llvm::createLoadCombinePass());
162#endif
163#if LLVM_VERSION_MAJOR >= 7
164 passes.push_back(llvm::createInstSimplifyLegacyPass());
165#else
166 passes.push_back(llvm::createInstructionSimplifierPass());
167#endif
168 passes.push_back(llvm::createMemCpyOptPass());
169 passes.push_back(llvm::createSROAPass());
170 passes.push_back(llvm::createMergedLoadStoreMotionPass());
171 passes.push_back(llvm::createBitTrackingDCEPass());
172 passes.push_back(llvm::createAggressiveDCEPass());
173 if (optlevel > 2) {
174 passes.push_back(llvm::createSLPVectorizerPass());
175#if LLVM_VERSION_MAJOR >= 7
176 passes.push_back(llvm::createInstSimplifyLegacyPass());
177#else
178 passes.push_back(llvm::createInstructionSimplifierPass());
179#endif
180 }
181 return passes;
182}
183
184void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
185 const bool symbolic_cse, unsigned opt_level)
186{
187 init(inputs, outputs, symbolic_cse,
188 LLVMVisitor::create_default_passes(opt_level), opt_level);
189}
190
191void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
192 const bool symbolic_cse,
193 const std::vector<llvm::Pass *> &passes,
194 unsigned opt_level)
195{
196 executionengine.reset();
197 llvm::InitializeNativeTarget();
198 llvm::InitializeNativeTargetAsmPrinter();
199 llvm::InitializeNativeTargetAsmParser();
200 context = std::make_shared<llvm::LLVMContext>();
201 symbols = inputs;
202
203 // Create some module to put our function into it.
205 = make_unique<llvm::Module>("SymEngine", *context.get());
206 module->setDataLayout("");
207 mod = module.get();
208
209 // Create a new pass manager attached to it.
210 fpm = std::make_shared<llvm::legacy::FunctionPassManager>(mod);
211 for (auto pass : passes) {
212 fpm->add(pass);
213 }
214 fpm->doInitialization();
215
216 auto F = get_function_type(context.get());
217
218 // Add a basic block to the function. As before, it automatically
219 // inserts
220 // because of the last argument.
221 llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
222
223 // Create a basic block builder with default parameters. The builder
224 // will
225 // automatically append instructions to the basic block `BB'.
226 llvm::IRBuilder<> _builder(BB);
227 builder = reinterpret_cast<IRBuilder *>(&_builder);
228 builder->SetInsertPoint(BB);
229 auto fmf = llvm::FastMathFlags();
230 // fmf.setUnsafeAlgebra();
231 builder->setFastMathFlags(fmf);
232
233 // Load all the symbols and create references
234 auto input_arg = &(*(F->args().begin()));
235 for (unsigned i = 0; i < inputs.size(); i++) {
236 if (not is_a<Symbol>(*inputs[i])) {
237 throw SymEngineException("Input contains a non-symbol.");
238 }
239 auto index
240 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
241 auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
242 index);
243 result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
244 symbol_ptrs.push_back(result_);
245 }
246
247 auto it = F->args().begin();
248#if (LLVM_VERSION_MAJOR < 5)
249 auto out = &(*(++it));
250#else
251 auto out = &(*(it + 1));
252#endif
253 std::vector<llvm::Value *> output_vals;
254
255 if (symbolic_cse) {
256 vec_basic reduced_exprs;
257 vec_pair replacements;
258 // cse the outputs
259 SymEngine::cse(replacements, reduced_exprs, outputs);
260 for (auto &rep : replacements) {
261 // Store the replacement symbol values in a dictionary
262 replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
263 }
264 // Generate IR for all the reduced exprs and save references
265 for (unsigned i = 0; i < outputs.size(); i++) {
266 output_vals.push_back(apply(*reduced_exprs[i]));
267 }
268 } else {
269 // Generate IR for all the output exprs and save references
270 for (unsigned i = 0; i < outputs.size(); i++) {
271 output_vals.push_back(apply(*outputs[i]));
272 }
273 }
274
275 // Store all the output exprs at the end
276 for (unsigned i = 0; i < outputs.size(); i++) {
277 auto index
278 = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
279 auto ptr
280 = builder->CreateGEP(get_float_type(context.get()), out, index);
281 builder->CreateStore(output_vals[i], ptr);
282 }
283
284 // Create the return instruction and add it to the basic block
285 builder->CreateRetVoid();
286
287 // Validate the generated code, checking for consistency.
288 llvm::verifyFunction(*F, &llvm::outs());
289
290 // std::cout << "LLVM IR" << std::endl;
291 // #if (LLVM_VERSION_MAJOR < 5)
292 // module->dump();
293 // #else
294 // module->print(llvm::errs(), nullptr);
295 // #endif
296
297 // Optimize the function.
298 fpm->run(*F);
299
300 // std::cout << "Optimized LLVM IR" << std::endl;
301 // module->dump();
302
303 // Now we create the JIT.
304 std::string error;
306 llvm::EngineBuilder(std::move(module))
307 .setEngineKind(llvm::EngineKind::Kind::JIT)
308 .setOptLevel(static_cast<llvm::CodeGenOpt::Level>(opt_level))
309 .setErrorStr(&error)
310 .create());
311
312 // This is a hack to get the MemoryBuffer of a compiled object.
313 class MemoryBufferRefCallback : public llvm::ObjectCache
314 {
315 public:
316 std::string &ss_;
317 MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
318
319 void notifyObjectCompiled(const llvm::Module *M,
320 llvm::MemoryBufferRef obj) override
321 {
322 const char *c = obj.getBufferStart();
323 // Saving the object code in a std::string
324 ss_.assign(c, obj.getBufferSize());
325 }
326
328 getObject(const llvm::Module *M) override
329 {
330 return NULL;
331 }
332 };
333
334 MemoryBufferRefCallback callback(membuffer);
335 executionengine->setObjectCache(&callback);
336 // std::cout << error << std::endl;
337 executionengine->finalizeObject();
338
339 // Get the symbol's address
340 func = (intptr_t)executionengine->getPointerToFunction(F);
341 symbol_ptrs.clear();
342 replacement_symbol_ptrs.clear();
343 symbols.clear();
344}
345
346double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
347{
348 double ret;
349 ((double (*)(const double *, double *))func)(vec.data(), &ret);
350 return ret;
351}
352
353void LLVMDoubleVisitor::call(double *outs, const double *inps) const
354{
355 ((double (*)(const double *, double *))func)(inps, outs);
356}
357
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate the compiled kernel, returning its first output.
// `vec` must hold one value per input symbol, in init() order; only safe
// when exactly one output expression was compiled.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    // The generated function returns void; call it through a void pointer
    // type to avoid undefined behavior.
    using kernel_t = void (*)(const long double *, long double *);
    long double ret;
    reinterpret_cast<kernel_t>(func)(vec.data(), &ret);
    return ret;
}

// Evaluate the compiled kernel: reads inputs from `inps`, writes one value
// per compiled output expression into `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using kernel_t = void (*)(const long double *, long double *);
    reinterpret_cast<kernel_t>(func)(inps, outs);
}
#endif
374
375float LLVMFloatVisitor::call(const std::vector<float> &vec) const
376{
377 float ret;
378 ((float (*)(const float *, float *))func)(vec.data(), &ret);
379 return ret;
380}
381
382void LLVMFloatVisitor::call(float *outs, const float *inps) const
383{
384 ((float (*)(const float *, float *))func)(inps, outs);
385}
386
387void LLVMVisitor::set_double(double d)
388{
389 result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
390}
391
392void LLVMVisitor::bvisit(const Integer &x)
393{
394 result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
395 mp_get_d(x.as_integer_class()));
396}
397
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Emit `x` as a long double constant parsed from the decimal string of a
// 128-bit MPFR evaluation. Without MPFR support this is not implemented.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    RCP<const Basic> evaluated = evalf(x, 128, EvalfDomain::Real);
    llvm::Type *element_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(element_type, evaluated->__str__());
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are emitted by parsing their decimal string representation.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    llvm::Type *element_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(element_type, x.__str__());
}
#endif
416
417void LLVMVisitor::bvisit(const Rational &x)
418{
419 set_double(mp_get_d(x.as_rational_class()));
420}
421
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Rationals go through the MPFR-based conversion rather than a double.
void LLVMLongDoubleVisitor::visit(const Rational &x)
{
    this->convert_from_mpfr(x);
}
#endif
428
429void LLVMVisitor::bvisit(const RealDouble &x)
430{
431 set_double(x.i);
432}
433
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals are rounded to the nearest double.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double rounded = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(rounded);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long double visitor uses the MPFR-based conversion instead.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    this->convert_from_mpfr(x);
}
#endif
#endif
446
447void LLVMVisitor::bvisit(const Add &x)
448{
449 llvm::Value *tmp, *tmp1, *tmp2;
450 auto it = x.get_dict().begin();
451
452 if (eq(*x.get_coef(), *zero)) {
453 // `x + 0.0` is not optimized out
454 if (eq(*one, *(it->second))) {
455 tmp = apply(*(it->first));
456 } else {
457 tmp1 = apply(*(it->first));
458 tmp2 = apply(*(it->second));
459 tmp = builder->CreateFMul(tmp1, tmp2);
460 }
461 ++it;
462 } else {
463 tmp = apply(*x.get_coef());
464 }
465
466 for (; it != x.get_dict().end(); ++it) {
467 if (eq(*one, *(it->second))) {
468 tmp1 = apply(*(it->first));
469 tmp = builder->CreateFAdd(tmp, tmp1);
470 } else {
471 // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
472 // tmp =
473 // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
474 // llvm::Intrinsic::fma,
475 // 3, context), args);
476 tmp1 = apply(*(it->first));
477 tmp2 = apply(*(it->second));
478 tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
479 }
480 }
481 result_ = tmp;
482}
483
484void LLVMVisitor::bvisit(const Mul &x)
485{
486 llvm::Value *tmp = nullptr;
487 bool first = true;
488 for (const auto &p : x.get_args()) {
489 if (first) {
490 tmp = apply(*p);
491 } else {
492 tmp = builder->CreateFMul(tmp, apply(*p));
493 }
494 first = false;
495 }
496 result_ = tmp;
497}
498
499llvm::Function *LLVMVisitor::get_powi()
500{
502 arg_type.push_back(get_float_type(&mod->getContext()));
503#if (LLVM_VERSION_MAJOR > 12)
504 arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
505#endif
506 return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
507 arg_type);
508}
509
510llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
511 unsigned n, llvm::Module *mod)
512{
513 std::vector<llvm::Type *> arg_type(n, type);
514 return llvm::Intrinsic::getDeclaration(mod, id, arg_type);
515}
516
517void LLVMVisitor::bvisit(const Pow &x)
518{
520 llvm::Function *fun;
521 if (eq(*(x.get_base()), *E)) {
522 args.push_back(apply(*x.get_exp()));
523 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
524 llvm::Intrinsic::exp, 1, mod);
525
526 } else if (eq(*(x.get_base()), *integer(2))) {
527 args.push_back(apply(*x.get_exp()));
528 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
529 llvm::Intrinsic::exp2, 1, mod);
530
531 } else {
532 if (is_a<Integer>(*x.get_exp())) {
533 if (eq(*x.get_exp(), *integer(2))) {
534 llvm::Value *tmp = apply(*x.get_base());
535 result_ = builder->CreateFMul(tmp, tmp);
536 return;
537 } else {
538 args.push_back(apply(*x.get_base()));
539 int d = numeric_cast<int>(
540 mp_get_si(static_cast<const Integer &>(*x.get_exp())
541 .as_integer_class()));
542 result_ = llvm::ConstantInt::get(
543 llvm::Type::getInt32Ty(mod->getContext()), d, true);
544 args.push_back(result_);
545 fun = get_powi();
546 }
547 } else {
548 args.push_back(apply(*x.get_base()));
549 args.push_back(apply(*x.get_exp()));
550 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
551 llvm::Intrinsic::pow, 1, mod);
552 }
553 }
554 auto r = builder->CreateCall(fun, args);
555 r->setTailCall(true);
556 result_ = r;
557}
558
559void LLVMVisitor::bvisit(const Sin &x)
560{
562 llvm::Function *fun;
563 args.push_back(apply(*x.get_arg()));
564 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
565 llvm::Intrinsic::sin, 1, mod);
566 auto r = builder->CreateCall(fun, args);
567 r->setTailCall(true);
568 result_ = r;
569}
570
571void LLVMVisitor::bvisit(const Cos &x)
572{
574 llvm::Function *fun;
575 args.push_back(apply(*x.get_arg()));
576 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
577 llvm::Intrinsic::cos, 1, mod);
578 auto r = builder->CreateCall(fun, args);
579 r->setTailCall(true);
580 result_ = r;
581}
582
583void LLVMVisitor::bvisit(const Piecewise &x)
584{
586
587 RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
588
589 if (neq(*pw->get_vec().back().second, *boolTrue)) {
590 throw SymEngineException(
591 "LLVMDouble requires a (Expr, True) at the end of Piecewise");
592 }
593
594 if (pw->get_vec().size() > 2) {
595 PiecewiseVec rest = pw->get_vec();
596 rest.erase(rest.begin());
597 auto rest_pw = piecewise(std::move(rest));
598 PiecewiseVec new_pw;
599 new_pw.push_back(*pw->get_vec().begin());
600 new_pw.push_back({rest_pw, pw->get_vec().back().second});
601 pw = piecewise(std::move(new_pw))
602 ->rcp_from_this_cast<const Piecewise>();
603 } else if (pw->get_vec().size() < 2) {
604 throw SymEngineException("Invalid Piecewise object");
605 }
606
607 auto cond_basic = pw->get_vec().front().second;
608 llvm::Value *cond = apply(*cond_basic);
609 // check if cond != 0.0
610 cond = builder->CreateFCmpONE(
611 cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
612 "ifcond");
613 llvm::Function *function = builder->GetInsertBlock()->getParent();
614 // Create blocks for the then and else cases. Insert the 'then' block at
615 // the
616 // end of the function.
617 llvm::BasicBlock *then_bb
618 = llvm::BasicBlock::Create(mod->getContext(), "then", function);
619 llvm::BasicBlock *else_bb
620 = llvm::BasicBlock::Create(mod->getContext(), "else");
621 llvm::BasicBlock *merge_bb
622 = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
623 builder->CreateCondBr(cond, then_bb, else_bb);
624
625 // Emit then value.
626 builder->SetInsertPoint(then_bb);
627 llvm::Value *then_value = apply(*pw->get_vec().front().first);
628 builder->CreateBr(merge_bb);
629
630 // Codegen of 'then_value' can change the current block, update then_bb for
631 // the PHI.
632 then_bb = builder->GetInsertBlock();
633
634 // Emit else block.
635#if (LLVM_VERSION_MAJOR < 16)
636 function->getBasicBlockList().push_back(else_bb);
637#else
638 function->insert(function->end(), else_bb);
639#endif
640 builder->SetInsertPoint(else_bb);
641 llvm::Value *else_value = apply(*pw->get_vec().back().first);
642 builder->CreateBr(merge_bb);
643
644 // Codegen of 'else_value' can change the current block, update else_bb for
645 // the PHI.
646 else_bb = builder->GetInsertBlock();
647
648 // Emit merge block.
649#if (LLVM_VERSION_MAJOR < 16)
650 function->getBasicBlockList().push_back(merge_bb);
651#else
652 function->insert(function->end(), merge_bb);
653#endif
654 builder->SetInsertPoint(merge_bb);
655 llvm::PHINode *phi_node
656 = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
657
658 phi_node->addIncoming(then_value, then_bb);
659 phi_node->addIncoming(else_value, else_bb);
660 result_ = phi_node;
661}
662
663void LLVMVisitor::bvisit(const Sign &x)
664{
665 const auto x2 = x.get_arg();
666 PiecewiseVec new_pw;
667 new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
668 new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
669 new_pw.push_back({real_double(1.0), boolTrue});
670 auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
671 bvisit(*pw);
672}
673
674void LLVMVisitor::bvisit(const Contains &cts)
675{
676 llvm::Value *expr = apply(*cts.get_expr());
677 const auto set = cts.get_set();
678 if (is_a<Interval>(*set)) {
679 const auto &interv = down_cast<const Interval &>(*set);
680 llvm::Value *start = apply(*interv.get_start());
681 llvm::Value *end = apply(*interv.get_end());
682 const bool left_open = interv.get_left_open();
683 const bool right_open = interv.get_right_open();
684 llvm::Value *left_ok;
685 llvm::Value *right_ok;
686 left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
687 : builder->CreateFCmpOLE(start, expr);
688 right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
689 : builder->CreateFCmpOLE(expr, end);
690 result_ = builder->CreateAnd(left_ok, right_ok);
691 result_ = builder->CreateUIToFP(result_,
692 get_float_type(&mod->getContext()));
693 } else {
694 throw SymEngineException("LLVMVisitor: only ``Interval`` "
695 "implemented for ``Contains``.");
696 }
697}
698
699void LLVMVisitor::bvisit(const Infty &x)
700{
701 if (x.is_negative_infinity()) {
702 result_ = llvm::ConstantFP::getInfinity(
703 get_float_type(&mod->getContext()), true);
704 } else if (x.is_positive_infinity()) {
705 result_ = llvm::ConstantFP::getInfinity(
706 get_float_type(&mod->getContext()), false);
707 } else {
708 throw SymEngineException(
709 "LLVMDouble can only represent real valued infinity");
710 }
711}
712
713void LLVMVisitor::bvisit(const NaN &x)
714{
715 result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
716 /*negative=*/false, /*payload=*/0);
717}
718
719void LLVMVisitor::bvisit(const BooleanAtom &x)
720{
721 const bool val = x.get_val();
722 set_double(val ? 1.0 : 0.0);
723}
724
725void LLVMVisitor::bvisit(const Log &x)
726{
728 llvm::Function *fun;
729 args.push_back(apply(*x.get_arg()));
730 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
731 llvm::Intrinsic::log, 1, mod);
732 auto r = builder->CreateCall(fun, args);
733 r->setTailCall(true);
734 result_ = r;
735}
736
737#define SYMENGINE_LOGIC_FUNCTION(Class, method) \
738 void LLVMVisitor::bvisit(const Class &x) \
739 { \
740 llvm::Value *value = nullptr; \
741 llvm::Value *tmp; \
742 set_double(0.0); \
743 llvm::Value *zero_val = result_; \
744 for (auto &p : x.get_container()) { \
745 tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
746 if (value == nullptr) { \
747 value = tmp; \
748 } else { \
749 value = builder->method(value, tmp); \
750 } \
751 } \
752 result_ = builder->CreateUIToFP(value, \
753 get_float_type(&mod->getContext())); \
754 }
755
756SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
757SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
758SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
759
760void LLVMVisitor::bvisit(const Not &x)
761{
762 builder->CreateNot(apply(*x.get_arg()));
763}
764
765#define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
766 void LLVMVisitor::bvisit(const Class &x) \
767 { \
768 llvm::Value *left = apply(*x.get_arg1()); \
769 llvm::Value *right = apply(*x.get_arg2()); \
770 result_ = builder->method(left, right); \
771 result_ = builder->CreateUIToFP(result_, \
772 get_float_type(&mod->getContext())); \
773 }
774
775SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
776SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
777SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
778SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
779
780#define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
781 void LLVMDoubleVisitor::visit(const Class &x) \
782 { \
783 vec_basic basic_args = x.get_args(); \
784 llvm::Function *func = get_external_function(#ext, basic_args.size()); \
785 std::vector<llvm::Value *> args; \
786 for (const auto &arg : basic_args) { \
787 args.push_back(apply(*arg)); \
788 } \
789 auto r = builder->CreateCall(func, args); \
790 r->setTailCall(true); \
791 result_ = r; \
792 } \
793 void LLVMFloatVisitor::visit(const Class &x) \
794 { \
795 vec_basic basic_args = x.get_args(); \
796 llvm::Function *func = get_external_function(#ext + std::string("f"), \
797 basic_args.size()); \
798 std::vector<llvm::Value *> args; \
799 for (const auto &arg : basic_args) { \
800 args.push_back(apply(*arg)); \
801 } \
802 auto r = builder->CreateCall(func, args); \
803 r->setTailCall(true); \
804 result_ = r; \
805 }
806
807#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
808#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
809 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
810 void LLVMLongDoubleVisitor::visit(const Class &x) \
811 { \
812 vec_basic basic_args = x.get_args(); \
813 llvm::Function *func = get_external_function(#ext + std::string("l"), \
814 basic_args.size()); \
815 std::vector<llvm::Value *> args; \
816 for (const auto &arg : basic_args) { \
817 args.push_back(apply(*arg)); \
818 } \
819 auto r = builder->CreateCall(func, args); \
820 r->setTailCall(true); \
821 result_ = r; \
822 }
823#else
824#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
825 _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
826#endif
827
828SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
829SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
830SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
831SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
832SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
833SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
834SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
835SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
836SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
837SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
838SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
839SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
840SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
841SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
842SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
843
844void LLVMVisitor::bvisit(const Abs &x)
845{
847 llvm::Function *fun;
848 args.push_back(apply(*x.get_arg()));
849 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
850 llvm::Intrinsic::fabs, 1, mod);
851 auto r = builder->CreateCall(fun, args);
852 r->setTailCall(true);
853 result_ = r;
854}
855
856void LLVMVisitor::bvisit(const Min &x)
857{
858 llvm::Value *value = nullptr;
859 llvm::Function *fun;
860 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
861 llvm::Intrinsic::minnum, 1, mod);
862 for (auto &arg : x.get_vec()) {
863 if (value != nullptr) {
865 args.push_back(value);
866 args.push_back(apply(*arg));
867 auto r = builder->CreateCall(fun, args);
868 r->setTailCall(true);
869 value = r;
870 } else {
871 value = apply(*arg);
872 }
873 }
874 result_ = value;
875}
876
877void LLVMVisitor::bvisit(const Max &x)
878{
879 llvm::Value *value = nullptr;
880 llvm::Function *fun;
881 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
882 llvm::Intrinsic::maxnum, 1, mod);
883 for (auto &arg : x.get_vec()) {
884 if (value != nullptr) {
886 args.push_back(value);
887 args.push_back(apply(*arg));
888 auto r = builder->CreateCall(fun, args);
889 r->setTailCall(true);
890 value = r;
891 } else {
892 value = apply(*arg);
893 }
894 }
895 result_ = value;
896}
897
898void LLVMVisitor::bvisit(const Symbol &x)
899{
900 unsigned i = 0;
901 for (auto &symb : symbols) {
902 if (eq(x, *symb)) {
903 result_ = symbol_ptrs[i];
904 return;
905 }
906 ++i;
907 }
908 auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
909 if (it != replacement_symbol_ptrs.end()) {
910 result_ = it->second;
911 return;
912 }
913
914 throw SymEngineException("Symbol " + x.__str__()
915 + " not in the symbols vector.");
916}
917
918llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
919 size_t nargs)
920{
921 std::vector<llvm::Type *> func_args(nargs,
922 get_float_type(&mod->getContext()));
923 llvm::FunctionType *func_type = llvm::FunctionType::get(
924 get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
925
926 llvm::Function *func = mod->getFunction(name);
927 if (!func) {
928 func = llvm::Function::Create(
929 func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
930 func->setCallingConv(llvm::CallingConv::C);
931 }
932#if (LLVM_VERSION_MAJOR < 5)
933 llvm::AttributeSet func_attr_set;
934 {
935 llvm::SmallVector<llvm::AttributeSet, 4> attrs;
936 llvm::AttributeSet attr_set;
937 {
938 llvm::AttrBuilder attr_builder;
939 attr_builder.addAttribute(llvm::Attribute::NoUnwind);
940 attr_set
941 = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
942 }
943
944 attrs.push_back(attr_set);
945 func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
946 }
947 func->setAttributes(func_attr_set);
948#else
949 func->addFnAttr(llvm::Attribute::NoUnwind);
950#endif
951 return func;
952}
953
954void LLVMVisitor::bvisit(const Constant &x)
955{
956 set_double(eval_double(x));
957}
958
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Constants go through the MPFR-based conversion rather than a double.
void LLVMLongDoubleVisitor::visit(const Constant &x)
{
    this->convert_from_mpfr(x);
}
#endif
965
966void LLVMVisitor::bvisit(const Basic &x)
967{
968 throw NotImplementedError(x.__str__());
969}
970
971const std::string &LLVMVisitor::dumps() const
972{
973 return membuffer;
974};
975
976void LLVMVisitor::loads(const std::string &s)
977{
978 membuffer = s;
979 llvm::InitializeNativeTarget();
980 llvm::InitializeNativeTargetAsmPrinter();
981 llvm::InitializeNativeTargetAsmParser();
982 context = std::make_shared<llvm::LLVMContext>();
983
984 // Create some module to put our function into it.
986 = make_unique<llvm::Module>("SymEngine", *context);
987 module->setDataLayout("");
988 mod = module.get();
989
990 // Only defining the prototype for the function here.
991 // Since we know where the function is stored that's enough
992 // llvm::ObjectCache is designed for caching objects, but it
993 // is used here for loading one specific object.
994 auto F = get_function_type(context.get());
995
996 std::string error;
998 llvm::EngineBuilder(std::move(module))
999 .setEngineKind(llvm::EngineKind::Kind::JIT)
1000 .setOptLevel(llvm::CodeGenOpt::Level::Aggressive)
1001 .setErrorStr(&error)
1002 .create());
1003
1004 class MCJITObjectLoader : public llvm::ObjectCache
1005 {
1006 const std::string &s_;
1007
1008 public:
1009 MCJITObjectLoader(const std::string &s) : s_(s) {}
1010 void notifyObjectCompiled(const llvm::Module *M,
1011 llvm::MemoryBufferRef obj) override
1012 {
1013 }
1014
1015 // No need to check M because there is only one function
1016 // Return it after reading from the file.
1018 getObject(const llvm::Module *M) override
1019 {
1020 return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
1021 }
1022 };
1023
1024 MCJITObjectLoader loader(s);
1025 executionengine->setObjectCache(&loader);
1026 executionengine->finalizeObject();
1027 // Set func to compiled function pointer
1028 func = (intptr_t)executionengine->getPointerToFunction(F);
1029}
1030
1031void LLVMVisitor::bvisit(const Floor &x)
1032{
1034 llvm::Function *fun;
1035 args.push_back(apply(*x.get_arg()));
1036 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1037 llvm::Intrinsic::floor, 1, mod);
1038 auto r = builder->CreateCall(fun, args);
1039 r->setTailCall(true);
1040 result_ = r;
1041}
1042
1043void LLVMVisitor::bvisit(const Ceiling &x)
1044{
1046 llvm::Function *fun;
1047 args.push_back(apply(*x.get_arg()));
1048 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1049 llvm::Intrinsic::ceil, 1, mod);
1050 auto r = builder->CreateCall(fun, args);
1051 r->setTailCall(true);
1052 result_ = r;
1053}
1054
1055void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1056{
1057 apply(*x.get_arg());
1058}
1059
1060void LLVMVisitor::bvisit(const Truncate &x)
1061{
1063 llvm::Function *fun;
1064 args.push_back(apply(*x.get_arg()));
1065 fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1066 llvm::Intrinsic::trunc, 1, mod);
1067 auto r = builder->CreateCall(fun, args);
1068 r->setTailCall(true);
1069 result_ = r;
1070}
1071
1072llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1073{
1074 return llvm::Type::getDoubleTy(*context);
1075}
1076
1077llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1078{
1079 return llvm::Type::getFloatTy(*context);
1080}
1081
1082#if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1083llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1084{
1085 return llvm::Type::getX86_FP80Ty(*context);
1086}
1087#endif
1088
1089} // namespace SymEngine
T assign(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T data(T... args)
T end(T... args)
T erase(T... args)
T get(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:768
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:66
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:653
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
std::enable_if< std::is_integral< T >::value, RCP< constInteger > >::type integer(T i)
Definition: integer.h:197
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127
T push_back(T... args)