llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 
// Named subclass of llvm::IRBuilder<> that adds no members of its own.
// LLVMVisitor::init() points `builder` at a stack-allocated
// llvm::IRBuilder<> through a reinterpret_cast to this type.
50 class IRBuilder : public llvm::IRBuilder<>
51 {
52 };
53 
54 llvm::Value *LLVMVisitor::apply(const Basic &b)
55 {
56  b.accept(*this);
57  return result_;
58 }
59 
60 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
61  unsigned opt_level)
62 {
63  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
64 }
65 
66 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
67 {
69  for (int i = 0; i < 2; i++) {
70  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
71  }
72  llvm::FunctionType *function_type = llvm::FunctionType::get(
73  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
74  auto F = llvm::Function::Create(function_type,
75  llvm::Function::InternalLinkage, "", mod);
76  F->setCallingConv(llvm::CallingConv::C);
77 #if (LLVM_VERSION_MAJOR < 5)
78  {
79  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
80  llvm::AttributeSet PAS;
81  {
82  llvm::AttrBuilder B;
83  B.addAttribute(llvm::Attribute::ReadOnly);
84  B.addAttribute(llvm::Attribute::NoCapture);
85  PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
86  }
87 
88  attrs.push_back(PAS);
89  {
90  llvm::AttrBuilder B;
91  B.addAttribute(llvm::Attribute::NoCapture);
92  PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
93  }
94 
95  attrs.push_back(PAS);
96  {
97  llvm::AttrBuilder B;
98  B.addAttribute(llvm::Attribute::NoUnwind);
99  B.addAttribute(llvm::Attribute::UWTable);
100  PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
101  }
102 
103  attrs.push_back(PAS);
104 
105  F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
106  }
107 #else
108  F->addParamAttr(0, llvm::Attribute::ReadOnly);
109  F->addParamAttr(0, llvm::Attribute::NoCapture);
110  F->addParamAttr(1, llvm::Attribute::NoCapture);
111  F->addFnAttr(llvm::Attribute::NoUnwind);
112 #if (LLVM_VERSION_MAJOR < 15)
113  F->addFnAttr(llvm::Attribute::UWTable);
114 #else
115  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
116  *context, llvm::UWTableKind::Default));
117 #endif
118 #endif
119  return F;
120 }
121 
122 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
123  const bool symbolic_cse, unsigned opt_level)
124 {
125  executionengine.reset();
126  llvm::InitializeNativeTarget();
127  llvm::InitializeNativeTargetAsmPrinter();
128  llvm::InitializeNativeTargetAsmParser();
129  context = std::make_shared<llvm::LLVMContext>();
130  symbols = inputs;
131 
132  // Create some module to put our function into it.
134  = make_unique<llvm::Module>("SymEngine", *context.get());
135  module->setDataLayout("");
136  mod = module.get();
137 
138  auto F = get_function_type(context.get());
139 
140  // Add a basic block to the function. As before, it automatically
141  // inserts
142  // because of the last argument.
143  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
144 
145  // Create a basic block builder with default parameters. The builder
146  // will
147  // automatically append instructions to the basic block `BB'.
148  llvm::IRBuilder<> _builder(BB);
149  builder = reinterpret_cast<IRBuilder *>(&_builder);
150  builder->SetInsertPoint(BB);
151  auto fmf = llvm::FastMathFlags();
152  // fmf.setUnsafeAlgebra();
153  builder->setFastMathFlags(fmf);
154 
155  // Load all the symbols and create references
156  auto input_arg = &(*(F->args().begin()));
157  for (unsigned i = 0; i < inputs.size(); i++) {
158  if (not is_a<Symbol>(*inputs[i])) {
159  throw SymEngineException("Input contains a non-symbol.");
160  }
161  auto index
162  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
163  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
164  index);
165  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
166  symbol_ptrs.push_back(result_);
167  }
168 
169  auto it = F->args().begin();
170 #if (LLVM_VERSION_MAJOR < 5)
171  auto out = &(*(++it));
172 #else
173  auto out = &(*(it + 1));
174 #endif
175  std::vector<llvm::Value *> output_vals;
176 
177  if (symbolic_cse) {
178  vec_basic reduced_exprs;
179  vec_pair replacements;
180  // cse the outputs
181  SymEngine::cse(replacements, reduced_exprs, outputs);
182  for (auto &rep : replacements) {
183  // Store the replacement symbol values in a dictionary
184  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
185  }
186  // Generate IR for all the reduced exprs and save references
187  for (unsigned i = 0; i < outputs.size(); i++) {
188  output_vals.push_back(apply(*reduced_exprs[i]));
189  }
190  } else {
191  // Generate IR for all the output exprs and save references
192  for (unsigned i = 0; i < outputs.size(); i++) {
193  output_vals.push_back(apply(*outputs[i]));
194  }
195  }
196 
197  // Store all the output exprs at the end
198  for (unsigned i = 0; i < outputs.size(); i++) {
199  auto index
200  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
201  auto ptr
202  = builder->CreateGEP(get_float_type(context.get()), out, index);
203  builder->CreateStore(output_vals[i], ptr);
204  }
205 
206  // Create the return instruction and add it to the basic block
207  builder->CreateRetVoid();
208 
209  // Validate the generated code, checking for consistency.
210  llvm::verifyFunction(*F, &llvm::outs());
211 
212  // std::cout << "LLVM IR" << std::endl;
213  // #if (LLVM_VERSION_MAJOR < 5)
214  // module->dump();
215  // #else
216  // module->print(llvm::errs(), nullptr);
217  // #endif
218 
219  // module->print(llvm::errs(), nullptr);
220 
221  // Optimize the function using default passes from PassBuilder
222  // FunctionSimplificationPipeline for the opt_level
223 #if (LLVM_VERSION_MAJOR < 14)
224  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
225 #else
226  using OptimizationLevel = llvm::OptimizationLevel;
227 #endif
228  llvm::PassBuilder PB;
229  llvm::ModuleAnalysisManager MAM;
230  llvm::CGSCCAnalysisManager CGAM;
231  llvm::FunctionAnalysisManager FAM;
232  llvm::LoopAnalysisManager LAM;
233  PB.registerModuleAnalyses(MAM);
234  PB.registerCGSCCAnalyses(CGAM);
235  PB.registerFunctionAnalyses(FAM);
236  PB.registerLoopAnalyses(LAM);
237  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
238  llvm::FunctionPassManager FPM;
239  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
240  if (opt_level == 0) {
241  pb_opt_level = OptimizationLevel::O0;
242  } else if (opt_level == 1) {
243  pb_opt_level = OptimizationLevel::O1;
244  } else if (opt_level == 2) {
245  pb_opt_level = OptimizationLevel::O2;
246  }
247 #if (LLVM_VERSION_MAJOR < 6)
248  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
249 #elif (LLVM_VERSION_MAJOR < 12)
250  FPM = PB.buildFunctionSimplificationPipeline(
251  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
252 #else
253  FPM = PB.buildFunctionSimplificationPipeline(
254  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
255 #endif
256  FPM.run(*F, FAM);
257 
258  // std::cout << "Optimized LLVM IR" << std::endl;
259  // module->print(llvm::errs(), nullptr);
260 
261  // Now we create the JIT.
262  std::string error;
263  executionengine = std::shared_ptr<llvm::ExecutionEngine>(
264  llvm::EngineBuilder(std::move(module))
265  .setEngineKind(llvm::EngineKind::Kind::JIT)
266  .setOptLevel(static_cast<llvm::CodeGenOpt::Level>(opt_level))
267  .setErrorStr(&error)
268  .create());
269 
270  // This is a hack to get the MemoryBuffer of a compiled object.
271  class MemoryBufferRefCallback : public llvm::ObjectCache
272  {
273  public:
274  std::string &ss_;
275  MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
276 
277  void notifyObjectCompiled(const llvm::Module *M,
278  llvm::MemoryBufferRef obj) override
279  {
280  const char *c = obj.getBufferStart();
281  // Saving the object code in a std::string
282  ss_.assign(c, obj.getBufferSize());
283  }
284 
286  getObject(const llvm::Module *M) override
287  {
288  return NULL;
289  }
290  };
291 
292  MemoryBufferRefCallback callback(membuffer);
293  executionengine->setObjectCache(&callback);
294  // std::cout << error << std::endl;
295  executionengine->finalizeObject();
296 
297  // Get the symbol's address
298  func = (intptr_t)executionengine->getPointerToFunction(F);
299  symbol_ptrs.clear();
300  replacement_symbol_ptrs.clear();
301  symbols.clear();
302 }
303 
304 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
305 {
306  double ret;
307  ((double (*)(const double *, double *))func)(vec.data(), &ret);
308  return ret;
309 }
310 
311 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
312 {
313  ((double (*)(const double *, double *))func)(inps, outs);
314 }
315 
316 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
317 long double
318 LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
319 {
320  long double ret;
321  ((long double (*)(const long double *, long double *))func)(vec.data(),
322  &ret);
323  return ret;
324 }
325 
326 void LLVMLongDoubleVisitor::call(long double *outs,
327  const long double *inps) const
328 {
329  ((long double (*)(const long double *, long double *))func)(inps, outs);
330 }
331 #endif
332 
333 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
334 {
335  float ret;
336  ((float (*)(const float *, float *))func)(vec.data(), &ret);
337  return ret;
338 }
339 
340 void LLVMFloatVisitor::call(float *outs, const float *inps) const
341 {
342  ((float (*)(const float *, float *))func)(inps, outs);
343 }
344 
345 void LLVMVisitor::set_double(double d)
346 {
347  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
348 }
349 
350 void LLVMVisitor::bvisit(const Integer &x)
351 {
352  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
353  mp_get_d(x.as_integer_class()));
354 }
355 
356 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
357 void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
358 {
359 #ifndef HAVE_SYMENGINE_MPFR
360  throw NotImplementedError("Cannot convert to long double without MPFR");
361 #else
362  RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
363  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
364  m->__str__());
365 #endif
366 }
367 
368 void LLVMLongDoubleVisitor::visit(const Integer &x)
369 {
370  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
371  x.__str__());
372 }
373 #endif
374 
375 void LLVMVisitor::bvisit(const Rational &x)
376 {
377  set_double(mp_get_d(x.as_rational_class()));
378 }
379 
380 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
381 void LLVMLongDoubleVisitor::visit(const Rational &x)
382 {
383  convert_from_mpfr(x);
384 }
385 #endif
386 
387 void LLVMVisitor::bvisit(const RealDouble &x)
388 {
389  set_double(x.i);
390 }
391 
392 #ifdef HAVE_SYMENGINE_MPFR
393 void LLVMVisitor::bvisit(const RealMPFR &x)
394 {
395  set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
396 }
397 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
398 void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
399 {
400  convert_from_mpfr(x);
401 }
402 #endif
403 #endif
404 
405 void LLVMVisitor::bvisit(const Add &x)
406 {
407  llvm::Value *tmp, *tmp1, *tmp2;
408  auto it = x.get_dict().begin();
409 
410  if (eq(*x.get_coef(), *zero)) {
411  // `x + 0.0` is not optimized out
412  if (eq(*one, *(it->second))) {
413  tmp = apply(*(it->first));
414  } else {
415  tmp1 = apply(*(it->first));
416  tmp2 = apply(*(it->second));
417  tmp = builder->CreateFMul(tmp1, tmp2);
418  }
419  ++it;
420  } else {
421  tmp = apply(*x.get_coef());
422  }
423 
424  for (; it != x.get_dict().end(); ++it) {
425  if (eq(*one, *(it->second))) {
426  tmp1 = apply(*(it->first));
427  tmp = builder->CreateFAdd(tmp, tmp1);
428  } else {
429  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
430  // tmp =
431  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
432  // llvm::Intrinsic::fma,
433  // 3, context), args);
434  tmp1 = apply(*(it->first));
435  tmp2 = apply(*(it->second));
436  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
437  }
438  }
439  result_ = tmp;
440 }
441 
442 void LLVMVisitor::bvisit(const Mul &x)
443 {
444  llvm::Value *tmp = nullptr;
445  bool first = true;
446  for (const auto &p : x.get_args()) {
447  if (first) {
448  tmp = apply(*p);
449  } else {
450  tmp = builder->CreateFMul(tmp, apply(*p));
451  }
452  first = false;
453  }
454  result_ = tmp;
455 }
456 
457 llvm::Function *LLVMVisitor::get_powi()
458 {
459  std::vector<llvm::Type *> arg_type;
460  arg_type.push_back(get_float_type(&mod->getContext()));
461 #if (LLVM_VERSION_MAJOR > 12)
462  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
463 #endif
464  return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
465  arg_type);
466 }
467 
468 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
469  unsigned n, llvm::Module *mod)
470 {
471  std::vector<llvm::Type *> arg_type(n, type);
472  return llvm::Intrinsic::getDeclaration(mod, id, arg_type);
473 }
474 
475 void LLVMVisitor::bvisit(const Pow &x)
476 {
478  llvm::Function *fun;
479  if (eq(*(x.get_base()), *E)) {
480  args.push_back(apply(*x.get_exp()));
481  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
482  llvm::Intrinsic::exp, 1, mod);
483 
484  } else if (eq(*(x.get_base()), *integer(2))) {
485  args.push_back(apply(*x.get_exp()));
486  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
487  llvm::Intrinsic::exp2, 1, mod);
488 
489  } else {
490  if (is_a<Integer>(*x.get_exp())) {
491  if (eq(*x.get_exp(), *integer(2))) {
492  llvm::Value *tmp = apply(*x.get_base());
493  result_ = builder->CreateFMul(tmp, tmp);
494  return;
495  } else {
496  args.push_back(apply(*x.get_base()));
497  int d = numeric_cast<int>(
498  mp_get_si(static_cast<const Integer &>(*x.get_exp())
499  .as_integer_class()));
500  result_ = llvm::ConstantInt::get(
501  llvm::Type::getInt32Ty(mod->getContext()), d, true);
502  args.push_back(result_);
503  fun = get_powi();
504  }
505  } else {
506  args.push_back(apply(*x.get_base()));
507  args.push_back(apply(*x.get_exp()));
508  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
509  llvm::Intrinsic::pow, 1, mod);
510  }
511  }
512  auto r = builder->CreateCall(fun, args);
513  r->setTailCall(true);
514  result_ = r;
515 }
516 
517 void LLVMVisitor::bvisit(const Sin &x)
518 {
520  llvm::Function *fun;
521  args.push_back(apply(*x.get_arg()));
522  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
523  llvm::Intrinsic::sin, 1, mod);
524  auto r = builder->CreateCall(fun, args);
525  r->setTailCall(true);
526  result_ = r;
527 }
528 
529 void LLVMVisitor::bvisit(const Cos &x)
530 {
532  llvm::Function *fun;
533  args.push_back(apply(*x.get_arg()));
534  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
535  llvm::Intrinsic::cos, 1, mod);
536  auto r = builder->CreateCall(fun, args);
537  r->setTailCall(true);
538  result_ = r;
539 }
540 
541 void LLVMVisitor::bvisit(const Piecewise &x)
542 {
544 
545  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
546 
547  if (neq(*pw->get_vec().back().second, *boolTrue)) {
548  throw SymEngineException(
549  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
550  }
551 
552  if (pw->get_vec().size() > 2) {
553  PiecewiseVec rest = pw->get_vec();
554  rest.erase(rest.begin());
555  auto rest_pw = piecewise(std::move(rest));
556  PiecewiseVec new_pw;
557  new_pw.push_back(*pw->get_vec().begin());
558  new_pw.push_back({rest_pw, pw->get_vec().back().second});
559  pw = piecewise(std::move(new_pw))
560  ->rcp_from_this_cast<const Piecewise>();
561  } else if (pw->get_vec().size() < 2) {
562  throw SymEngineException("Invalid Piecewise object");
563  }
564 
565  auto cond_basic = pw->get_vec().front().second;
566  llvm::Value *cond = apply(*cond_basic);
567  // check if cond != 0.0
568  cond = builder->CreateFCmpONE(
569  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
570  "ifcond");
571  llvm::Function *function = builder->GetInsertBlock()->getParent();
572  // Create blocks for the then and else cases. Insert the 'then' block at
573  // the
574  // end of the function.
575  llvm::BasicBlock *then_bb
576  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
577  llvm::BasicBlock *else_bb
578  = llvm::BasicBlock::Create(mod->getContext(), "else");
579  llvm::BasicBlock *merge_bb
580  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
581  builder->CreateCondBr(cond, then_bb, else_bb);
582 
583  // Emit then value.
584  builder->SetInsertPoint(then_bb);
585  llvm::Value *then_value = apply(*pw->get_vec().front().first);
586  builder->CreateBr(merge_bb);
587 
588  // Codegen of 'then_value' can change the current block, update then_bb for
589  // the PHI.
590  then_bb = builder->GetInsertBlock();
591 
592  // Emit else block.
593 #if (LLVM_VERSION_MAJOR < 16)
594  function->getBasicBlockList().push_back(else_bb);
595 #else
596  function->insert(function->end(), else_bb);
597 #endif
598  builder->SetInsertPoint(else_bb);
599  llvm::Value *else_value = apply(*pw->get_vec().back().first);
600  builder->CreateBr(merge_bb);
601 
602  // Codegen of 'else_value' can change the current block, update else_bb for
603  // the PHI.
604  else_bb = builder->GetInsertBlock();
605 
606  // Emit merge block.
607 #if (LLVM_VERSION_MAJOR < 16)
608  function->getBasicBlockList().push_back(merge_bb);
609 #else
610  function->insert(function->end(), merge_bb);
611 #endif
612  builder->SetInsertPoint(merge_bb);
613  llvm::PHINode *phi_node
614  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
615 
616  phi_node->addIncoming(then_value, then_bb);
617  phi_node->addIncoming(else_value, else_bb);
618  result_ = phi_node;
619 }
620 
621 void LLVMVisitor::bvisit(const Sign &x)
622 {
623  const auto x2 = x.get_arg();
624  PiecewiseVec new_pw;
625  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
626  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
627  new_pw.push_back({real_double(1.0), boolTrue});
628  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
629  bvisit(*pw);
630 }
631 
632 void LLVMVisitor::bvisit(const Contains &cts)
633 {
634  llvm::Value *expr = apply(*cts.get_expr());
635  const auto set = cts.get_set();
636  if (is_a<Interval>(*set)) {
637  const auto &interv = down_cast<const Interval &>(*set);
638  llvm::Value *start = apply(*interv.get_start());
639  llvm::Value *end = apply(*interv.get_end());
640  const bool left_open = interv.get_left_open();
641  const bool right_open = interv.get_right_open();
642  llvm::Value *left_ok;
643  llvm::Value *right_ok;
644  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
645  : builder->CreateFCmpOLE(start, expr);
646  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
647  : builder->CreateFCmpOLE(expr, end);
648  result_ = builder->CreateAnd(left_ok, right_ok);
649  result_ = builder->CreateUIToFP(result_,
650  get_float_type(&mod->getContext()));
651  } else {
652  throw SymEngineException("LLVMVisitor: only ``Interval`` "
653  "implemented for ``Contains``.");
654  }
655 }
656 
657 void LLVMVisitor::bvisit(const Infty &x)
658 {
659  if (x.is_negative_infinity()) {
660  result_ = llvm::ConstantFP::getInfinity(
661  get_float_type(&mod->getContext()), true);
662  } else if (x.is_positive_infinity()) {
663  result_ = llvm::ConstantFP::getInfinity(
664  get_float_type(&mod->getContext()), false);
665  } else {
666  throw SymEngineException(
667  "LLVMDouble can only represent real valued infinity");
668  }
669 }
670 
671 void LLVMVisitor::bvisit(const NaN &x)
672 {
673  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
674  /*negative=*/false, /*payload=*/0);
675 }
676 
677 void LLVMVisitor::bvisit(const BooleanAtom &x)
678 {
679  const bool val = x.get_val();
680  set_double(val ? 1.0 : 0.0);
681 }
682 
683 void LLVMVisitor::bvisit(const Log &x)
684 {
686  llvm::Function *fun;
687  args.push_back(apply(*x.get_arg()));
688  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
689  llvm::Intrinsic::log, 1, mod);
690  auto r = builder->CreateCall(fun, args);
691  r->setTailCall(true);
692  result_ = r;
693 }
694 
695 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
696  void LLVMVisitor::bvisit(const Class &x) \
697  { \
698  llvm::Value *value = nullptr; \
699  llvm::Value *tmp; \
700  set_double(0.0); \
701  llvm::Value *zero_val = result_; \
702  for (auto &p : x.get_container()) { \
703  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
704  if (value == nullptr) { \
705  value = tmp; \
706  } else { \
707  value = builder->method(value, tmp); \
708  } \
709  } \
710  result_ = builder->CreateUIToFP(value, \
711  get_float_type(&mod->getContext())); \
712  }
713 
714 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
715 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
716 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
717 
718 void LLVMVisitor::bvisit(const Not &x)
719 {
720  builder->CreateNot(apply(*x.get_arg()));
721 }
722 
723 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
724  void LLVMVisitor::bvisit(const Class &x) \
725  { \
726  llvm::Value *left = apply(*x.get_arg1()); \
727  llvm::Value *right = apply(*x.get_arg2()); \
728  result_ = builder->method(left, right); \
729  result_ = builder->CreateUIToFP(result_, \
730  get_float_type(&mod->getContext())); \
731  }
732 
733 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
734 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
735 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
736 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
737 
738 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
739  void LLVMDoubleVisitor::visit(const Class &x) \
740  { \
741  vec_basic basic_args = x.get_args(); \
742  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
743  std::vector<llvm::Value *> args; \
744  for (const auto &arg : basic_args) { \
745  args.push_back(apply(*arg)); \
746  } \
747  auto r = builder->CreateCall(func, args); \
748  r->setTailCall(true); \
749  result_ = r; \
750  } \
751  void LLVMFloatVisitor::visit(const Class &x) \
752  { \
753  vec_basic basic_args = x.get_args(); \
754  llvm::Function *func = get_external_function(#ext + std::string("f"), \
755  basic_args.size()); \
756  std::vector<llvm::Value *> args; \
757  for (const auto &arg : basic_args) { \
758  args.push_back(apply(*arg)); \
759  } \
760  auto r = builder->CreateCall(func, args); \
761  r->setTailCall(true); \
762  result_ = r; \
763  }
764 
765 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
766 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
767  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
768  void LLVMLongDoubleVisitor::visit(const Class &x) \
769  { \
770  vec_basic basic_args = x.get_args(); \
771  llvm::Function *func = get_external_function(#ext + std::string("l"), \
772  basic_args.size()); \
773  std::vector<llvm::Value *> args; \
774  for (const auto &arg : basic_args) { \
775  args.push_back(apply(*arg)); \
776  } \
777  auto r = builder->CreateCall(func, args); \
778  r->setTailCall(true); \
779  result_ = r; \
780  }
781 #else
782 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
783  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
784 #endif
785 
786 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
787 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
788 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
789 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
790 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
791 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
792 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
793 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
794 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
795 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
796 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
797 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
798 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
799 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
800 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
801 
802 void LLVMVisitor::bvisit(const Abs &x)
803 {
805  llvm::Function *fun;
806  args.push_back(apply(*x.get_arg()));
807  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
808  llvm::Intrinsic::fabs, 1, mod);
809  auto r = builder->CreateCall(fun, args);
810  r->setTailCall(true);
811  result_ = r;
812 }
813 
814 void LLVMVisitor::bvisit(const Min &x)
815 {
816  llvm::Value *value = nullptr;
817  llvm::Function *fun;
818  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
819  llvm::Intrinsic::minnum, 1, mod);
820  for (auto &arg : x.get_vec()) {
821  if (value != nullptr) {
823  args.push_back(value);
824  args.push_back(apply(*arg));
825  auto r = builder->CreateCall(fun, args);
826  r->setTailCall(true);
827  value = r;
828  } else {
829  value = apply(*arg);
830  }
831  }
832  result_ = value;
833 }
834 
835 void LLVMVisitor::bvisit(const Max &x)
836 {
837  llvm::Value *value = nullptr;
838  llvm::Function *fun;
839  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
840  llvm::Intrinsic::maxnum, 1, mod);
841  for (auto &arg : x.get_vec()) {
842  if (value != nullptr) {
844  args.push_back(value);
845  args.push_back(apply(*arg));
846  auto r = builder->CreateCall(fun, args);
847  r->setTailCall(true);
848  value = r;
849  } else {
850  value = apply(*arg);
851  }
852  }
853  result_ = value;
854 }
855 
856 void LLVMVisitor::bvisit(const Symbol &x)
857 {
858  unsigned i = 0;
859  for (auto &symb : symbols) {
860  if (eq(x, *symb)) {
861  result_ = symbol_ptrs[i];
862  return;
863  }
864  ++i;
865  }
866  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
867  if (it != replacement_symbol_ptrs.end()) {
868  result_ = it->second;
869  return;
870  }
871 
872  throw SymEngineException("Symbol " + x.__str__()
873  + " not in the symbols vector.");
874 }
875 
876 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
877  size_t nargs)
878 {
879  std::vector<llvm::Type *> func_args(nargs,
880  get_float_type(&mod->getContext()));
881  llvm::FunctionType *func_type = llvm::FunctionType::get(
882  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
883 
884  llvm::Function *func = mod->getFunction(name);
885  if (!func) {
886  func = llvm::Function::Create(
887  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
888  func->setCallingConv(llvm::CallingConv::C);
889  }
890 #if (LLVM_VERSION_MAJOR < 5)
891  llvm::AttributeSet func_attr_set;
892  {
893  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
894  llvm::AttributeSet attr_set;
895  {
896  llvm::AttrBuilder attr_builder;
897  attr_builder.addAttribute(llvm::Attribute::NoUnwind);
898  attr_set
899  = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
900  }
901 
902  attrs.push_back(attr_set);
903  func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
904  }
905  func->setAttributes(func_attr_set);
906 #else
907  func->addFnAttr(llvm::Attribute::NoUnwind);
908 #endif
909  return func;
910 }
911 
912 void LLVMVisitor::bvisit(const Constant &x)
913 {
914  set_double(eval_double(x));
915 }
916 
917 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
918 void LLVMLongDoubleVisitor::visit(const Constant &x)
919 {
920  convert_from_mpfr(x);
921 }
922 #endif
923 
924 void LLVMVisitor::bvisit(const Basic &x)
925 {
926  throw NotImplementedError(x.__str__());
927 }
928 
929 const std::string &LLVMVisitor::dumps() const
930 {
931  return membuffer;
932 };
933 
934 void LLVMVisitor::loads(const std::string &s)
935 {
936  membuffer = s;
937  llvm::InitializeNativeTarget();
938  llvm::InitializeNativeTargetAsmPrinter();
939  llvm::InitializeNativeTargetAsmParser();
940  context = std::make_shared<llvm::LLVMContext>();
941 
942  // Create some module to put our function into it.
944  = make_unique<llvm::Module>("SymEngine", *context);
945  module->setDataLayout("");
946  mod = module.get();
947 
948  // Only defining the prototype for the function here.
949  // Since we know where the function is stored that's enough
950  // llvm::ObjectCache is designed for caching objects, but it
951  // is used here for loading one specific object.
952  auto F = get_function_type(context.get());
953 
954  std::string error;
955  executionengine = std::shared_ptr<llvm::ExecutionEngine>(
956  llvm::EngineBuilder(std::move(module))
957  .setEngineKind(llvm::EngineKind::Kind::JIT)
958  .setOptLevel(llvm::CodeGenOpt::Level::Aggressive)
959  .setErrorStr(&error)
960  .create());
961 
962  class MCJITObjectLoader : public llvm::ObjectCache
963  {
964  const std::string &s_;
965 
966  public:
967  MCJITObjectLoader(const std::string &s) : s_(s) {}
968  void notifyObjectCompiled(const llvm::Module *M,
969  llvm::MemoryBufferRef obj) override
970  {
971  }
972 
973  // No need to check M because there is only one function
974  // Return it after reading from the file.
976  getObject(const llvm::Module *M) override
977  {
978  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
979  }
980  };
981 
982  MCJITObjectLoader loader(s);
983  executionengine->setObjectCache(&loader);
984  executionengine->finalizeObject();
985  // Set func to compiled function pointer
986  func = (intptr_t)executionengine->getPointerToFunction(F);
987 }
988 
989 void LLVMVisitor::bvisit(const Floor &x)
990 {
992  llvm::Function *fun;
993  args.push_back(apply(*x.get_arg()));
994  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
995  llvm::Intrinsic::floor, 1, mod);
996  auto r = builder->CreateCall(fun, args);
997  r->setTailCall(true);
998  result_ = r;
999 }
1000 
1001 void LLVMVisitor::bvisit(const Ceiling &x)
1002 {
1004  llvm::Function *fun;
1005  args.push_back(apply(*x.get_arg()));
1006  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1007  llvm::Intrinsic::ceil, 1, mod);
1008  auto r = builder->CreateCall(fun, args);
1009  r->setTailCall(true);
1010  result_ = r;
1011 }
1012 
1013 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1014 {
1015  apply(*x.get_arg());
1016 }
1017 
1018 void LLVMVisitor::bvisit(const Truncate &x)
1019 {
1021  llvm::Function *fun;
1022  args.push_back(apply(*x.get_arg()));
1023  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1024  llvm::Intrinsic::trunc, 1, mod);
1025  auto r = builder->CreateCall(fun, args);
1026  r->setTailCall(true);
1027  result_ = r;
1028 }
1029 
1030 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1031 {
1032  return llvm::Type::getDoubleTy(*context);
1033 }
1034 
1035 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1036 {
1037  return llvm::Type::getFloatTy(*context);
1038 }
1039 
1040 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1041 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1042 {
1043  return llvm::Type::getX86_FP80Ty(*context);
1044 }
1045 #endif
1046 
1047 } // namespace SymEngine
T assign(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T data(T... args)
T end(T... args)
T erase(T... args)
T get(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:768
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:653
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127
T push_back(T... args)