llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 #if (LLVM_VERSION_MAJOR < 18)
50 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
51 #else
52 using CodeGenOptLevel = llvm::CodeGenOptLevel;
53 #endif
54 
55 #if (LLVM_VERSION_MAJOR >= 21)
56 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
57 #else
58 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
59 #endif
60 
61 class IRBuilder : public llvm::IRBuilder<>
62 {
63 };
64 
65 LLVMVisitor::LLVMVisitor() = default;
66 LLVMVisitor::~LLVMVisitor() = default;
67 
68 llvm::Value *LLVMVisitor::apply(const Basic &b)
69 {
70  b.accept(*this);
71  return result_;
72 }
73 
74 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
75  unsigned opt_level)
76 {
77  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
78 }
79 
80 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
81 {
83  for (int i = 0; i < 2; i++) {
84  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
85  }
86  llvm::FunctionType *function_type = llvm::FunctionType::get(
87  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
88  auto F = llvm::Function::Create(function_type,
89  llvm::Function::InternalLinkage, "", mod);
90  F->setCallingConv(llvm::CallingConv::C);
91 #if (LLVM_VERSION_MAJOR < 5)
92  {
93  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
94  llvm::AttributeSet PAS;
95  {
96  llvm::AttrBuilder B;
97  B.addAttribute(llvm::Attribute::ReadOnly);
98  AddNoCapture(B);
99  PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
100  }
101 
102  attrs.push_back(PAS);
103  {
104  llvm::AttrBuilder B;
105  AddNoCapture(B);
106  PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
107  }
108 
109  attrs.push_back(PAS);
110  {
111  llvm::AttrBuilder B;
112  B.addAttribute(llvm::Attribute::NoUnwind);
113  B.addAttribute(llvm::Attribute::UWTable);
114  PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
115  }
116 
117  attrs.push_back(PAS);
118 
119  F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
120  }
121 #else
122  F->addParamAttr(0, llvm::Attribute::ReadOnly);
123 #if (LLVM_VERSION_MAJOR >= 21)
124  F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
125  *context, llvm::CaptureInfo::none()));
126  F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
127  *context, llvm::CaptureInfo::none()));
128 #else
129  F->addParamAttr(0, llvm::Attribute::NoCapture);
130  F->addParamAttr(1, llvm::Attribute::NoCapture);
131 #endif
132  F->addFnAttr(llvm::Attribute::NoUnwind);
133 #if (LLVM_VERSION_MAJOR < 15)
134  F->addFnAttr(llvm::Attribute::UWTable);
135 #else
136  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
137  *context, llvm::UWTableKind::Default));
138 #endif
139 #endif
140  return F;
141 }
142 
143 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
144  const bool symbolic_cse, unsigned opt_level)
145 {
146  executionengine.reset();
147  llvm::InitializeNativeTarget();
148  llvm::InitializeNativeTargetAsmPrinter();
149  llvm::InitializeNativeTargetAsmParser();
150  context = make_unique<llvm::LLVMContext>();
151  symbols = inputs;
152 
153  // Create some module to put our function into it.
155  = make_unique<llvm::Module>("SymEngine", *context.get());
156  module->setDataLayout("");
157  mod = module.get();
158 
159  auto F = get_function_type(context.get());
160 
161  // Add a basic block to the function. As before, it automatically
162  // inserts
163  // because of the last argument.
164  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
165 
166  // Create a basic block builder with default parameters. The builder
167  // will
168  // automatically append instructions to the basic block `BB'.
169  llvm::IRBuilder<> _builder(BB);
170  builder = reinterpret_cast<IRBuilder *>(&_builder);
171  builder->SetInsertPoint(BB);
172  auto fmf = llvm::FastMathFlags();
173  // fmf.setUnsafeAlgebra();
174  builder->setFastMathFlags(fmf);
175 
176  // Load all the symbols and create references
177  auto input_arg = &(*(F->args().begin()));
178  for (unsigned i = 0; i < inputs.size(); i++) {
179  if (not is_a<Symbol>(*inputs[i])) {
180  throw SymEngineException("Input contains a non-symbol.");
181  }
182  auto index
183  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
184  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
185  index);
186  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
187  symbol_ptrs.push_back(result_);
188  }
189 
190  auto it = F->args().begin();
191 #if (LLVM_VERSION_MAJOR < 5)
192  auto out = &(*(++it));
193 #else
194  auto out = &(*(it + 1));
195 #endif
196  std::vector<llvm::Value *> output_vals;
197 
198  if (symbolic_cse) {
199  vec_basic reduced_exprs;
200  vec_pair replacements;
201  // cse the outputs
202  SymEngine::cse(replacements, reduced_exprs, outputs);
203  for (auto &rep : replacements) {
204  // Store the replacement symbol values in a dictionary
205  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
206  }
207  // Generate IR for all the reduced exprs and save references
208  for (unsigned i = 0; i < outputs.size(); i++) {
209  output_vals.push_back(apply(*reduced_exprs[i]));
210  }
211  } else {
212  // Generate IR for all the output exprs and save references
213  for (unsigned i = 0; i < outputs.size(); i++) {
214  output_vals.push_back(apply(*outputs[i]));
215  }
216  }
217 
218  // Store all the output exprs at the end
219  for (unsigned i = 0; i < outputs.size(); i++) {
220  auto index
221  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
222  auto ptr
223  = builder->CreateGEP(get_float_type(context.get()), out, index);
224  builder->CreateStore(output_vals[i], ptr);
225  }
226 
227  // Create the return instruction and add it to the basic block
228  builder->CreateRetVoid();
229 
230  // Validate the generated code, checking for consistency.
231  llvm::verifyFunction(*F, &llvm::outs());
232 
233  // std::cout << "LLVM IR" << std::endl;
234  // #if (LLVM_VERSION_MAJOR < 5)
235  // module->dump();
236  // #else
237  // module->print(llvm::errs(), nullptr);
238  // #endif
239 
240  // module->print(llvm::errs(), nullptr);
241 
242  // Optimize the function using default passes from PassBuilder
243  // FunctionSimplificationPipeline for the opt_level
244 #if (LLVM_VERSION_MAJOR < 14)
245  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
246 #else
247  using OptimizationLevel = llvm::OptimizationLevel;
248 #endif
249  llvm::PassBuilder PB;
250  llvm::ModuleAnalysisManager MAM;
251  llvm::CGSCCAnalysisManager CGAM;
252  llvm::FunctionAnalysisManager FAM;
253  llvm::LoopAnalysisManager LAM;
254  PB.registerModuleAnalyses(MAM);
255  PB.registerCGSCCAnalyses(CGAM);
256  PB.registerFunctionAnalyses(FAM);
257  PB.registerLoopAnalyses(LAM);
258  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
259  llvm::FunctionPassManager FPM;
260  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
261  if (opt_level == 0) {
262  pb_opt_level = OptimizationLevel::O0;
263  } else if (opt_level == 1) {
264  pb_opt_level = OptimizationLevel::O1;
265  } else if (opt_level == 2) {
266  pb_opt_level = OptimizationLevel::O2;
267  }
268 #if (LLVM_VERSION_MAJOR < 6)
269  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
270 #elif (LLVM_VERSION_MAJOR < 12)
271  FPM = PB.buildFunctionSimplificationPipeline(
272  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
273 #else
274  FPM = PB.buildFunctionSimplificationPipeline(
275  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
276 #endif
277  FPM.run(*F, FAM);
278 
279  // std::cout << "Optimized LLVM IR" << std::endl;
280  // module->print(llvm::errs(), nullptr);
281 
282  // Now we create the JIT.
283  std::string error;
284  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
285  llvm::EngineBuilder(std::move(module))
286  .setEngineKind(llvm::EngineKind::Kind::JIT)
287  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
288  .setErrorStr(&error)
289  .create());
290 
291  // This is a hack to get the MemoryBuffer of a compiled object.
292  class MemoryBufferRefCallback : public llvm::ObjectCache
293  {
294  public:
295  std::string &ss_;
296  MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
297 
298  void notifyObjectCompiled(const llvm::Module *M,
299  llvm::MemoryBufferRef obj) override
300  {
301  const char *c = obj.getBufferStart();
302  // Saving the object code in a std::string
303  ss_.assign(c, obj.getBufferSize());
304  }
305 
307  getObject(const llvm::Module *M) override
308  {
309  return NULL;
310  }
311  };
312 
313  MemoryBufferRefCallback callback(membuffer);
314  executionengine->setObjectCache(&callback);
315  // std::cout << error << std::endl;
316  executionengine->finalizeObject();
317 
318  // Get the symbol's address
319  func = (intptr_t)executionengine->getPointerToFunction(F);
320  symbol_ptrs.clear();
321  replacement_symbol_ptrs.clear();
322  symbols.clear();
323 }
324 
325 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
326 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
327 
328 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
329 {
330  double ret;
331  ((double (*)(const double *, double *))func)(vec.data(), &ret);
332  return ret;
333 }
334 
335 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
336 {
337  ((double (*)(const double *, double *))func)(inps, outs);
338 }
339 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate the compiled function and return its single output.
// Precondition (not checked here): `vec` holds one value per input symbol,
// and the function was compiled with exactly one output.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    long double ret;
    // The generated entry point returns void. Calling it through a
    // long-double-returning pointer is undefined. On ABIs that return long
    // double in the x87 register st(0), the caller would also pop a value
    // that was never pushed.
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        vec.data(), &ret);
    return ret;
}

// Evaluate all outputs. Reads one value per input symbol from `inps` and
// writes one value per compiled output expression to `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        inps, outs);
}
#endif
356 
357 LLVMFloatVisitor::LLVMFloatVisitor() = default;
358 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
359 
360 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
361 {
362  float ret;
363  ((float (*)(const float *, float *))func)(vec.data(), &ret);
364  return ret;
365 }
366 
367 void LLVMFloatVisitor::call(float *outs, const float *inps) const
368 {
369  ((float (*)(const float *, float *))func)(inps, outs);
370 }
371 
372 void LLVMVisitor::set_double(double d)
373 {
374  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
375 }
376 
377 void LLVMVisitor::bvisit(const Integer &x)
378 {
379  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
380  mp_get_d(x.as_integer_class()));
381 }
382 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE

LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;

// Set result_ to `x` converted to the long double type. `x` is evaluated
// numerically with evalf (precision argument 128) and handed to LLVM as a
// decimal string, so it is not first rounded through a double. Without MPFR
// support this conversion is unavailable.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> approx = evalf(x, 128, EvalfDomain::Real);
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, approx->__str__());
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are parsed by LLVM from their exact decimal representation.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, x.__str__());
}
#endif
405 
406 void LLVMVisitor::bvisit(const Rational &x)
407 {
408  set_double(mp_get_d(x.as_rational_class()));
409 }
410 
411 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
412 void LLVMLongDoubleVisitor::visit(const Rational &x)
413 {
414  convert_from_mpfr(x);
415 }
416 #endif
417 
418 void LLVMVisitor::bvisit(const RealDouble &x)
419 {
420  set_double(x.i);
421 }
422 
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals are rounded to the nearest double.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double rounded = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(rounded);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Long double: delegate to convert_from_mpfr instead of rounding via double.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    convert_from_mpfr(x);
}
#endif
#endif
435 
436 void LLVMVisitor::bvisit(const Add &x)
437 {
438  llvm::Value *tmp, *tmp1, *tmp2;
439  auto it = x.get_dict().begin();
440 
441  if (eq(*x.get_coef(), *zero)) {
442  // `x + 0.0` is not optimized out
443  if (eq(*one, *(it->second))) {
444  tmp = apply(*(it->first));
445  } else {
446  tmp1 = apply(*(it->first));
447  tmp2 = apply(*(it->second));
448  tmp = builder->CreateFMul(tmp1, tmp2);
449  }
450  ++it;
451  } else {
452  tmp = apply(*x.get_coef());
453  }
454 
455  for (; it != x.get_dict().end(); ++it) {
456  if (eq(*one, *(it->second))) {
457  tmp1 = apply(*(it->first));
458  tmp = builder->CreateFAdd(tmp, tmp1);
459  } else {
460  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
461  // tmp =
462  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
463  // llvm::Intrinsic::fma,
464  // 3, context), args);
465  tmp1 = apply(*(it->first));
466  tmp2 = apply(*(it->second));
467  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
468  }
469  }
470  result_ = tmp;
471 }
472 
473 void LLVMVisitor::bvisit(const Mul &x)
474 {
475  llvm::Value *tmp = nullptr;
476  bool first = true;
477  for (const auto &p : x.get_args()) {
478  if (first) {
479  tmp = apply(*p);
480  } else {
481  tmp = builder->CreateFMul(tmp, apply(*p));
482  }
483  first = false;
484  }
485  result_ = tmp;
486 }
487 
488 llvm::Function *LLVMVisitor::get_powi()
489 {
490  std::vector<llvm::Type *> arg_type;
491  arg_type.push_back(get_float_type(&mod->getContext()));
492 #if (LLVM_VERSION_MAJOR > 12)
493  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
494 #endif
495  return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
496  arg_type);
497 }
498 
499 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
500  unsigned n, llvm::Module *mod)
501 {
502  std::vector<llvm::Type *> arg_type(n, type);
503  return llvm::Intrinsic::getDeclaration(mod, id, arg_type);
504 }
505 
506 void LLVMVisitor::bvisit(const Pow &x)
507 {
509  llvm::Function *fun;
510  if (eq(*(x.get_base()), *E)) {
511  args.push_back(apply(*x.get_exp()));
512  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
513  llvm::Intrinsic::exp, 1, mod);
514 
515  } else if (eq(*(x.get_base()), *integer(2))) {
516  args.push_back(apply(*x.get_exp()));
517  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
518  llvm::Intrinsic::exp2, 1, mod);
519 
520  } else {
521  if (is_a<Integer>(*x.get_exp())) {
522  if (eq(*x.get_exp(), *integer(2))) {
523  llvm::Value *tmp = apply(*x.get_base());
524  result_ = builder->CreateFMul(tmp, tmp);
525  return;
526  } else {
527  args.push_back(apply(*x.get_base()));
528  int d = numeric_cast<int>(
529  mp_get_si(static_cast<const Integer &>(*x.get_exp())
530  .as_integer_class()));
531  result_ = llvm::ConstantInt::get(
532  llvm::Type::getInt32Ty(mod->getContext()), d, true);
533  args.push_back(result_);
534  fun = get_powi();
535  }
536  } else {
537  args.push_back(apply(*x.get_base()));
538  args.push_back(apply(*x.get_exp()));
539  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
540  llvm::Intrinsic::pow, 1, mod);
541  }
542  }
543  auto r = builder->CreateCall(fun, args);
544  r->setTailCall(true);
545  result_ = r;
546 }
547 
548 void LLVMVisitor::bvisit(const Sin &x)
549 {
551  llvm::Function *fun;
552  args.push_back(apply(*x.get_arg()));
553  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
554  llvm::Intrinsic::sin, 1, mod);
555  auto r = builder->CreateCall(fun, args);
556  r->setTailCall(true);
557  result_ = r;
558 }
559 
560 void LLVMVisitor::bvisit(const Cos &x)
561 {
563  llvm::Function *fun;
564  args.push_back(apply(*x.get_arg()));
565  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
566  llvm::Intrinsic::cos, 1, mod);
567  auto r = builder->CreateCall(fun, args);
568  r->setTailCall(true);
569  result_ = r;
570 }
571 
572 void LLVMVisitor::bvisit(const Piecewise &x)
573 {
575 
576  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
577 
578  if (neq(*pw->get_vec().back().second, *boolTrue)) {
579  throw SymEngineException(
580  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
581  }
582 
583  if (pw->get_vec().size() > 2) {
584  PiecewiseVec rest = pw->get_vec();
585  rest.erase(rest.begin());
586  auto rest_pw = piecewise(std::move(rest));
587  PiecewiseVec new_pw;
588  new_pw.push_back(*pw->get_vec().begin());
589  new_pw.push_back({rest_pw, pw->get_vec().back().second});
590  pw = piecewise(std::move(new_pw))
591  ->rcp_from_this_cast<const Piecewise>();
592  } else if (pw->get_vec().size() < 2) {
593  throw SymEngineException("Invalid Piecewise object");
594  }
595 
596  auto cond_basic = pw->get_vec().front().second;
597  llvm::Value *cond = apply(*cond_basic);
598  // check if cond != 0.0
599  cond = builder->CreateFCmpONE(
600  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
601  "ifcond");
602  llvm::Function *function = builder->GetInsertBlock()->getParent();
603  // Create blocks for the then and else cases. Insert the 'then' block at
604  // the
605  // end of the function.
606  llvm::BasicBlock *then_bb
607  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
608  llvm::BasicBlock *else_bb
609  = llvm::BasicBlock::Create(mod->getContext(), "else");
610  llvm::BasicBlock *merge_bb
611  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
612  builder->CreateCondBr(cond, then_bb, else_bb);
613 
614  // Emit then value.
615  builder->SetInsertPoint(then_bb);
616  llvm::Value *then_value = apply(*pw->get_vec().front().first);
617  builder->CreateBr(merge_bb);
618 
619  // Codegen of 'then_value' can change the current block, update then_bb for
620  // the PHI.
621  then_bb = builder->GetInsertBlock();
622 
623  // Emit else block.
624 #if (LLVM_VERSION_MAJOR < 16)
625  function->getBasicBlockList().push_back(else_bb);
626 #else
627  function->insert(function->end(), else_bb);
628 #endif
629  builder->SetInsertPoint(else_bb);
630  llvm::Value *else_value = apply(*pw->get_vec().back().first);
631  builder->CreateBr(merge_bb);
632 
633  // Codegen of 'else_value' can change the current block, update else_bb for
634  // the PHI.
635  else_bb = builder->GetInsertBlock();
636 
637  // Emit merge block.
638 #if (LLVM_VERSION_MAJOR < 16)
639  function->getBasicBlockList().push_back(merge_bb);
640 #else
641  function->insert(function->end(), merge_bb);
642 #endif
643  builder->SetInsertPoint(merge_bb);
644  llvm::PHINode *phi_node
645  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
646 
647  phi_node->addIncoming(then_value, then_bb);
648  phi_node->addIncoming(else_value, else_bb);
649  result_ = phi_node;
650 }
651 
652 void LLVMVisitor::bvisit(const Sign &x)
653 {
654  const auto x2 = x.get_arg();
655  PiecewiseVec new_pw;
656  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
657  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
658  new_pw.push_back({real_double(1.0), boolTrue});
659  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
660  bvisit(*pw);
661 }
662 
663 void LLVMVisitor::bvisit(const Contains &cts)
664 {
665  llvm::Value *expr = apply(*cts.get_expr());
666  const auto set = cts.get_set();
667  if (is_a<Interval>(*set)) {
668  const auto &interv = down_cast<const Interval &>(*set);
669  llvm::Value *start = apply(*interv.get_start());
670  llvm::Value *end = apply(*interv.get_end());
671  const bool left_open = interv.get_left_open();
672  const bool right_open = interv.get_right_open();
673  llvm::Value *left_ok;
674  llvm::Value *right_ok;
675  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
676  : builder->CreateFCmpOLE(start, expr);
677  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
678  : builder->CreateFCmpOLE(expr, end);
679  result_ = builder->CreateAnd(left_ok, right_ok);
680  result_ = builder->CreateUIToFP(result_,
681  get_float_type(&mod->getContext()));
682  } else {
683  throw SymEngineException("LLVMVisitor: only ``Interval`` "
684  "implemented for ``Contains``.");
685  }
686 }
687 
688 void LLVMVisitor::bvisit(const Infty &x)
689 {
690  if (x.is_negative_infinity()) {
691  result_ = llvm::ConstantFP::getInfinity(
692  get_float_type(&mod->getContext()), true);
693  } else if (x.is_positive_infinity()) {
694  result_ = llvm::ConstantFP::getInfinity(
695  get_float_type(&mod->getContext()), false);
696  } else {
697  throw SymEngineException(
698  "LLVMDouble can only represent real valued infinity");
699  }
700 }
701 
702 void LLVMVisitor::bvisit(const NaN &x)
703 {
704  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
705  /*negative=*/false, /*payload=*/0);
706 }
707 
708 void LLVMVisitor::bvisit(const BooleanAtom &x)
709 {
710  const bool val = x.get_val();
711  set_double(val ? 1.0 : 0.0);
712 }
713 
714 void LLVMVisitor::bvisit(const Log &x)
715 {
717  llvm::Function *fun;
718  args.push_back(apply(*x.get_arg()));
719  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
720  llvm::Intrinsic::log, 1, mod);
721  auto r = builder->CreateCall(fun, args);
722  r->setTailCall(true);
723  result_ = r;
724 }
725 
726 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
727  void LLVMVisitor::bvisit(const Class &x) \
728  { \
729  llvm::Value *value = nullptr; \
730  llvm::Value *tmp; \
731  set_double(0.0); \
732  llvm::Value *zero_val = result_; \
733  for (auto &p : x.get_container()) { \
734  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
735  if (value == nullptr) { \
736  value = tmp; \
737  } else { \
738  value = builder->method(value, tmp); \
739  } \
740  } \
741  result_ = builder->CreateUIToFP(value, \
742  get_float_type(&mod->getContext())); \
743  }
744 
745 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
746 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
747 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
748 
749 void LLVMVisitor::bvisit(const Not &x)
750 {
751  builder->CreateNot(apply(*x.get_arg()));
752 }
753 
754 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
755  void LLVMVisitor::bvisit(const Class &x) \
756  { \
757  llvm::Value *left = apply(*x.get_arg1()); \
758  llvm::Value *right = apply(*x.get_arg2()); \
759  result_ = builder->method(left, right); \
760  result_ = builder->CreateUIToFP(result_, \
761  get_float_type(&mod->getContext())); \
762  }
763 
764 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
765 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
766 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
767 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
768 
769 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
770  void LLVMDoubleVisitor::visit(const Class &x) \
771  { \
772  vec_basic basic_args = x.get_args(); \
773  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
774  std::vector<llvm::Value *> args; \
775  for (const auto &arg : basic_args) { \
776  args.push_back(apply(*arg)); \
777  } \
778  auto r = builder->CreateCall(func, args); \
779  r->setTailCall(true); \
780  result_ = r; \
781  } \
782  void LLVMFloatVisitor::visit(const Class &x) \
783  { \
784  vec_basic basic_args = x.get_args(); \
785  llvm::Function *func = get_external_function(#ext + std::string("f"), \
786  basic_args.size()); \
787  std::vector<llvm::Value *> args; \
788  for (const auto &arg : basic_args) { \
789  args.push_back(apply(*arg)); \
790  } \
791  auto r = builder->CreateCall(func, args); \
792  r->setTailCall(true); \
793  result_ = r; \
794  }
795 
796 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
797 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
798  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
799  void LLVMLongDoubleVisitor::visit(const Class &x) \
800  { \
801  vec_basic basic_args = x.get_args(); \
802  llvm::Function *func = get_external_function(#ext + std::string("l"), \
803  basic_args.size()); \
804  std::vector<llvm::Value *> args; \
805  for (const auto &arg : basic_args) { \
806  args.push_back(apply(*arg)); \
807  } \
808  auto r = builder->CreateCall(func, args); \
809  r->setTailCall(true); \
810  result_ = r; \
811  }
812 #else
813 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
814  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
815 #endif
816 
817 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
818 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
819 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
820 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
821 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
822 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
823 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
824 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
825 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
826 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
827 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
828 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
829 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
830 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
831 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
832 
833 void LLVMVisitor::bvisit(const Abs &x)
834 {
836  llvm::Function *fun;
837  args.push_back(apply(*x.get_arg()));
838  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
839  llvm::Intrinsic::fabs, 1, mod);
840  auto r = builder->CreateCall(fun, args);
841  r->setTailCall(true);
842  result_ = r;
843 }
844 
845 void LLVMVisitor::bvisit(const Min &x)
846 {
847  llvm::Value *value = nullptr;
848  llvm::Function *fun;
849  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
850  llvm::Intrinsic::minnum, 1, mod);
851  for (auto &arg : x.get_vec()) {
852  if (value != nullptr) {
854  args.push_back(value);
855  args.push_back(apply(*arg));
856  auto r = builder->CreateCall(fun, args);
857  r->setTailCall(true);
858  value = r;
859  } else {
860  value = apply(*arg);
861  }
862  }
863  result_ = value;
864 }
865 
866 void LLVMVisitor::bvisit(const Max &x)
867 {
868  llvm::Value *value = nullptr;
869  llvm::Function *fun;
870  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
871  llvm::Intrinsic::maxnum, 1, mod);
872  for (auto &arg : x.get_vec()) {
873  if (value != nullptr) {
875  args.push_back(value);
876  args.push_back(apply(*arg));
877  auto r = builder->CreateCall(fun, args);
878  r->setTailCall(true);
879  value = r;
880  } else {
881  value = apply(*arg);
882  }
883  }
884  result_ = value;
885 }
886 
887 void LLVMVisitor::bvisit(const Symbol &x)
888 {
889  unsigned i = 0;
890  for (auto &symb : symbols) {
891  if (eq(x, *symb)) {
892  result_ = symbol_ptrs[i];
893  return;
894  }
895  ++i;
896  }
897  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
898  if (it != replacement_symbol_ptrs.end()) {
899  result_ = it->second;
900  return;
901  }
902 
903  throw SymEngineException("Symbol " + x.__str__()
904  + " not in the symbols vector.");
905 }
906 
907 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
908  size_t nargs)
909 {
910  std::vector<llvm::Type *> func_args(nargs,
911  get_float_type(&mod->getContext()));
912  llvm::FunctionType *func_type = llvm::FunctionType::get(
913  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
914 
915  llvm::Function *func = mod->getFunction(name);
916  if (!func) {
917  func = llvm::Function::Create(
918  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
919  func->setCallingConv(llvm::CallingConv::C);
920  }
921 #if (LLVM_VERSION_MAJOR < 5)
922  llvm::AttributeSet func_attr_set;
923  {
924  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
925  llvm::AttributeSet attr_set;
926  {
927  llvm::AttrBuilder attr_builder;
928  attr_builder.addAttribute(llvm::Attribute::NoUnwind);
929  attr_set
930  = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
931  }
932 
933  attrs.push_back(attr_set);
934  func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
935  }
936  func->setAttributes(func_attr_set);
937 #else
938  func->addFnAttr(llvm::Attribute::NoUnwind);
939 #endif
940  return func;
941 }
942 
943 void LLVMVisitor::bvisit(const Constant &x)
944 {
945  set_double(eval_double(x));
946 }
947 
948 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
949 void LLVMLongDoubleVisitor::visit(const Constant &x)
950 {
951  convert_from_mpfr(x);
952 }
953 #endif
954 
955 void LLVMVisitor::bvisit(const Basic &x)
956 {
957  throw NotImplementedError(x.__str__());
958 }
959 
960 const std::string &LLVMVisitor::dumps() const
961 {
962  return membuffer;
963 };
964 
965 void LLVMVisitor::loads(const std::string &s)
966 {
967  membuffer = s;
968  llvm::InitializeNativeTarget();
969  llvm::InitializeNativeTargetAsmPrinter();
970  llvm::InitializeNativeTargetAsmParser();
971  context = make_unique<llvm::LLVMContext>();
972 
973  // Create some module to put our function into it.
975  = make_unique<llvm::Module>("SymEngine", *context);
976  module->setDataLayout("");
977  mod = module.get();
978 
979  // Only defining the prototype for the function here.
980  // Since we know where the function is stored that's enough
981  // llvm::ObjectCache is designed for caching objects, but it
982  // is used here for loading one specific object.
983  auto F = get_function_type(context.get());
984 
985  std::string error;
986  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
987  llvm::EngineBuilder(std::move(module))
988  .setEngineKind(llvm::EngineKind::Kind::JIT)
989  .setOptLevel(CodeGenOptLevel::Aggressive)
990  .setErrorStr(&error)
991  .create());
992 
993  class MCJITObjectLoader : public llvm::ObjectCache
994  {
995  const std::string &s_;
996 
997  public:
998  MCJITObjectLoader(const std::string &s) : s_(s) {}
999  void notifyObjectCompiled(const llvm::Module *M,
1000  llvm::MemoryBufferRef obj) override
1001  {
1002  }
1003 
1004  // No need to check M because there is only one function
1005  // Return it after reading from the file.
1007  getObject(const llvm::Module *M) override
1008  {
1009  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
1010  }
1011  };
1012 
1013  MCJITObjectLoader loader(s);
1014  executionengine->setObjectCache(&loader);
1015  executionengine->finalizeObject();
1016  // Set func to compiled function pointer
1017  func = (intptr_t)executionengine->getPointerToFunction(F);
1018 }
1019 
1020 void LLVMVisitor::bvisit(const Floor &x)
1021 {
1023  llvm::Function *fun;
1024  args.push_back(apply(*x.get_arg()));
1025  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1026  llvm::Intrinsic::floor, 1, mod);
1027  auto r = builder->CreateCall(fun, args);
1028  r->setTailCall(true);
1029  result_ = r;
1030 }
1031 
1032 void LLVMVisitor::bvisit(const Ceiling &x)
1033 {
1035  llvm::Function *fun;
1036  args.push_back(apply(*x.get_arg()));
1037  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1038  llvm::Intrinsic::ceil, 1, mod);
1039  auto r = builder->CreateCall(fun, args);
1040  r->setTailCall(true);
1041  result_ = r;
1042 }
1043 
1044 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1045 {
1046  apply(*x.get_arg());
1047 }
1048 
1049 void LLVMVisitor::bvisit(const Truncate &x)
1050 {
1052  llvm::Function *fun;
1053  args.push_back(apply(*x.get_arg()));
1054  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1055  llvm::Intrinsic::trunc, 1, mod);
1056  auto r = builder->CreateCall(fun, args);
1057  r->setTailCall(true);
1058  result_ = r;
1059 }
1060 
1061 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1062 {
1063  return llvm::Type::getDoubleTy(*context);
1064 }
1065 
1066 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1067 {
1068  return llvm::Type::getFloatTy(*context);
1069 }
1070 
1071 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1072 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1073 {
1074  return llvm::Type::getX86_FP80Ty(*context);
1075 }
1076 #endif
1077 
1078 } // namespace SymEngine
T assign(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T data(T... args)
T end(T... args)
T erase(T... args)
T get(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127
T push_back(T... args)