llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 
50 #if (LLVM_VERSION_MAJOR >= 18)
51 using CodeGenOptLevel = llvm::CodeGenOptLevel;
52 #else
53 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
54 #endif
55 
56 #if (LLVM_VERSION_MAJOR >= 20)
57 auto GetDeclaration = [](llvm::Module *M, llvm::Intrinsic::ID id,
58  llvm::ArrayRef<llvm::Type *> Tys) {
59  return llvm::Intrinsic::getOrInsertDeclaration(M, id, Tys);
60 };
61 #else
62 const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
63 #endif
64 
65 #if (LLVM_VERSION_MAJOR >= 21)
66 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
67 #else
68 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
69 #endif
70 
71 class IRBuilder : public llvm::IRBuilder<>
72 {
73 };
74 
75 LLVMVisitor::LLVMVisitor() = default;
76 LLVMVisitor::~LLVMVisitor() = default;
77 
78 llvm::Value *LLVMVisitor::apply(const Basic &b)
79 {
80  b.accept(*this);
81  return result_;
82 }
83 
84 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
85  unsigned opt_level)
86 {
87  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
88 }
89 
90 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
91 {
92  std::vector<llvm::Type *> inp;
93  for (int i = 0; i < 2; i++) {
94 #if (LLVM_VERSION_MAJOR >= 22)
95  inp.push_back(llvm::PointerType::get(*context, 0));
96 #else
97  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
98 #endif
99  }
100  llvm::FunctionType *function_type = llvm::FunctionType::get(
101  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
102  auto F = llvm::Function::Create(
103  function_type, llvm::Function::InternalLinkage, "symengine_func", mod);
104  F->setCallingConv(llvm::CallingConv::C);
105  F->addParamAttr(0, llvm::Attribute::ReadOnly);
106 #if (LLVM_VERSION_MAJOR >= 21)
107  F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
108  *context, llvm::CaptureInfo::none()));
109  F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
110  *context, llvm::CaptureInfo::none()));
111 #else
112  F->addParamAttr(0, llvm::Attribute::NoCapture);
113  F->addParamAttr(1, llvm::Attribute::NoCapture);
114 #endif
115  F->addFnAttr(llvm::Attribute::NoUnwind);
116 #if (LLVM_VERSION_MAJOR < 15)
117  F->addFnAttr(llvm::Attribute::UWTable);
118 #else
119  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
120  *context, llvm::UWTableKind::Default));
121 #endif
122  return F;
123 }
124 
125 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
126  const bool symbolic_cse, unsigned opt_level)
127 {
128  executionengine.reset();
129  llvm::InitializeNativeTarget();
130  llvm::InitializeNativeTargetAsmPrinter();
131  llvm::InitializeNativeTargetAsmParser();
132  context = make_unique<llvm::LLVMContext>();
133  symbols = inputs;
134 
135  // Create some module to put our function into it.
136  std::unique_ptr<llvm::Module> module
137  = make_unique<llvm::Module>("SymEngine", *context.get());
138  module->setDataLayout("");
139  mod = module.get();
140 
141  auto F = get_function_type(context.get());
142 
143  // Add a basic block to the function. As before, it automatically
144  // inserts
145  // because of the last argument.
146  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
147 
148  // Create a basic block builder with default parameters. The builder
149  // will
150  // automatically append instructions to the basic block `BB'.
151  llvm::IRBuilder<> _builder(BB);
152  builder = reinterpret_cast<IRBuilder *>(&_builder);
153  builder->SetInsertPoint(BB);
154  auto fmf = llvm::FastMathFlags();
155  // fmf.setUnsafeAlgebra();
156  builder->setFastMathFlags(fmf);
157 
158  // Load all the symbols and create references
159  auto input_arg = &(*(F->args().begin()));
160  for (unsigned i = 0; i < inputs.size(); i++) {
161  if (not is_a<Symbol>(*inputs[i])) {
162  throw SymEngineException("Input contains a non-symbol.");
163  }
164  auto index
165  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
166  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
167  index);
168  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
169  symbol_ptrs.push_back(result_);
170  }
171 
172  auto it = F->args().begin();
173  auto out = &(*(it + 1));
174  std::vector<llvm::Value *> output_vals;
175 
176  if (symbolic_cse) {
177  vec_basic reduced_exprs;
178  vec_pair replacements;
179  // cse the outputs
180  SymEngine::cse(replacements, reduced_exprs, outputs);
181  for (auto &rep : replacements) {
182  // Store the replacement symbol values in a dictionary
183  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
184  }
185  // Generate IR for all the reduced exprs and save references
186  for (unsigned i = 0; i < outputs.size(); i++) {
187  output_vals.push_back(apply(*reduced_exprs[i]));
188  }
189  } else {
190  // Generate IR for all the output exprs and save references
191  for (unsigned i = 0; i < outputs.size(); i++) {
192  output_vals.push_back(apply(*outputs[i]));
193  }
194  }
195 
196  // Store all the output exprs at the end
197  for (unsigned i = 0; i < outputs.size(); i++) {
198  auto index
199  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
200  auto ptr
201  = builder->CreateGEP(get_float_type(context.get()), out, index);
202  builder->CreateStore(output_vals[i], ptr);
203  }
204 
205  // Create the return instruction and add it to the basic block
206  builder->CreateRetVoid();
207 
208  // Validate the generated code, checking for consistency.
209  llvm::verifyFunction(*F, &llvm::outs());
210 
211  // std::cout << "LLVM IR" << std::endl;
212  // module->print(llvm::errs(), nullptr);
213 
214  // Optimize the function using default passes from PassBuilder
215  // FunctionSimplificationPipeline for the opt_level
216 #if (LLVM_VERSION_MAJOR < 14)
217  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
218 #else
219  using OptimizationLevel = llvm::OptimizationLevel;
220 #endif
221  llvm::PassBuilder PB;
222  llvm::ModuleAnalysisManager MAM;
223  llvm::CGSCCAnalysisManager CGAM;
224  llvm::FunctionAnalysisManager FAM;
225  llvm::LoopAnalysisManager LAM;
226  PB.registerModuleAnalyses(MAM);
227  PB.registerCGSCCAnalyses(CGAM);
228  PB.registerFunctionAnalyses(FAM);
229  PB.registerLoopAnalyses(LAM);
230  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
231  llvm::FunctionPassManager FPM;
232  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
233  if (opt_level == 0) {
234  pb_opt_level = OptimizationLevel::O0;
235  } else if (opt_level == 1) {
236  pb_opt_level = OptimizationLevel::O1;
237  } else if (opt_level == 2) {
238  pb_opt_level = OptimizationLevel::O2;
239  }
240 
241  if (opt_level != 0) {
242 #if (LLVM_VERSION_MAJOR < 6)
243  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
244 #elif (LLVM_VERSION_MAJOR < 12)
245  FPM = PB.buildFunctionSimplificationPipeline(
246  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
247 #else
248  FPM = PB.buildFunctionSimplificationPipeline(
249  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
250 #endif
251  FPM.run(*F, FAM);
252  }
253 
254  // std::cout << "Optimized LLVM IR" << std::endl;
255  // module->print(llvm::errs(), nullptr);
256 
257  // Now we create the JIT.
258  std::string error;
259  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
260  llvm::EngineBuilder(std::move(module))
261  .setEngineKind(llvm::EngineKind::Kind::JIT)
262  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
263  .setErrorStr(&error)
264  .create());
265 
266  // Customization point for subclasses: (may e.g. register custom symbol
267  // resolver)
268  modify_execution_engine(executionengine.get());
269 
270  // This is a hack to get the MemoryBuffer of a compiled object.
271  class MemoryBufferRefCallback : public llvm::ObjectCache
272  {
273  public:
274  std::string &ss_;
275  explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
276 
277  void notifyObjectCompiled(const llvm::Module *M,
278  llvm::MemoryBufferRef obj) override
279  {
280  const char *c = obj.getBufferStart();
281  // Saving the object code in a std::string
282  ss_.assign(c, obj.getBufferSize());
283  }
284 
285  std::unique_ptr<llvm::MemoryBuffer>
286  getObject(const llvm::Module *M) override
287  {
288  return nullptr;
289  }
290  };
291 
292  MemoryBufferRefCallback callback(membuffer);
293  executionengine->setObjectCache(&callback);
294  // std::cout << error << std::endl;
295  executionengine->finalizeObject();
296 
297  // Get the symbol's address
298  func = (intptr_t)executionengine->getPointerToFunction(F);
299  symbol_ptrs.clear();
300  replacement_symbol_ptrs.clear();
301  symbols.clear();
302 }
303 
304 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
305 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
306 
307 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
308 {
309  double ret;
310  ((double (*)(const double *, double *))func)(vec.data(), &ret);
311  return ret;
312 }
313 
314 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
315 {
316  ((double (*)(const double *, double *))func)(inps, outs);
317 }
318 
319 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
320 long double
321 LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
322 {
323  long double ret;
324  ((long double (*)(const long double *, long double *))func)(vec.data(),
325  &ret);
326  return ret;
327 }
328 
329 void LLVMLongDoubleVisitor::call(long double *outs,
330  const long double *inps) const
331 {
332  ((long double (*)(const long double *, long double *))func)(inps, outs);
333 }
334 #endif
335 
336 LLVMFloatVisitor::LLVMFloatVisitor() = default;
337 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
338 
339 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
340 {
341  float ret;
342  ((float (*)(const float *, float *))func)(vec.data(), &ret);
343  return ret;
344 }
345 
346 void LLVMFloatVisitor::call(float *outs, const float *inps) const
347 {
348  ((float (*)(const float *, float *))func)(inps, outs);
349 }
350 
351 void LLVMVisitor::set_double(double d)
352 {
353  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
354 }
355 
356 void LLVMVisitor::bvisit(const Integer &x)
357 {
358  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
359  mp_get_d(x.as_integer_class()));
360 }
361 
362 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
363 
364 LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
365 LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;
366 
367 void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
368 {
369 #ifndef HAVE_SYMENGINE_MPFR
370  throw NotImplementedError("Cannot convert to long double without MPFR");
371 #else
372  RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
373  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
374  m->__str__());
375 #endif
376 }
377 
378 void LLVMLongDoubleVisitor::visit(const Integer &x)
379 {
380  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
381  x.__str__());
382 }
383 #endif
384 
385 void LLVMVisitor::bvisit(const Rational &x)
386 {
387  set_double(mp_get_d(x.as_rational_class()));
388 }
389 
390 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
391 void LLVMLongDoubleVisitor::visit(const Rational &x)
392 {
393  convert_from_mpfr(x);
394 }
395 #endif
396 
397 void LLVMVisitor::bvisit(const RealDouble &x)
398 {
399  set_double(x.i);
400 }
401 
402 #ifdef HAVE_SYMENGINE_MPFR
403 void LLVMVisitor::bvisit(const RealMPFR &x)
404 {
405  set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
406 }
407 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
408 void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
409 {
410  convert_from_mpfr(x);
411 }
412 #endif
413 #endif
414 
415 void LLVMVisitor::bvisit(const Add &x)
416 {
417  llvm::Value *tmp, *tmp1, *tmp2;
418  auto it = x.get_dict().begin();
419 
420  if (eq(*x.get_coef(), *zero)) {
421  // `x + 0.0` is not optimized out
422  if (eq(*one, *(it->second))) {
423  tmp = apply(*(it->first));
424  } else {
425  tmp1 = apply(*(it->first));
426  tmp2 = apply(*(it->second));
427  tmp = builder->CreateFMul(tmp1, tmp2);
428  }
429  ++it;
430  } else {
431  tmp = apply(*x.get_coef());
432  }
433 
434  for (; it != x.get_dict().end(); ++it) {
435  if (eq(*one, *(it->second))) {
436  tmp1 = apply(*(it->first));
437  tmp = builder->CreateFAdd(tmp, tmp1);
438  } else {
439  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
440  // tmp =
441  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
442  // llvm::Intrinsic::fma,
443  // 3, context), args);
444  tmp1 = apply(*(it->first));
445  tmp2 = apply(*(it->second));
446  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
447  }
448  }
449  result_ = tmp;
450 }
451 
452 void LLVMVisitor::bvisit(const Mul &x)
453 {
454  llvm::Value *tmp = nullptr;
455  bool first = true;
456  for (const auto &p : x.get_args()) {
457  if (first) {
458  tmp = apply(*p);
459  } else {
460  tmp = builder->CreateFMul(tmp, apply(*p));
461  }
462  first = false;
463  }
464  result_ = tmp;
465 }
466 
467 llvm::Function *LLVMVisitor::get_powi()
468 {
469  std::vector<llvm::Type *> arg_type;
470  arg_type.push_back(get_float_type(&mod->getContext()));
471 #if (LLVM_VERSION_MAJOR > 12)
472  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
473 #endif
474  return GetDeclaration(mod, llvm::Intrinsic::powi, arg_type);
475 }
476 
477 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
478  unsigned n, llvm::Module *mod)
479 {
480  std::vector<llvm::Type *> arg_type(n, type);
481  return GetDeclaration(mod, id, arg_type);
482 }
483 
484 void LLVMVisitor::bvisit(const Pow &x)
485 {
486  std::vector<llvm::Value *> args;
487  llvm::Function *fun;
488  if (eq(*(x.get_base()), *E)) {
489  args.push_back(apply(*x.get_exp()));
490  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
491  llvm::Intrinsic::exp, 1, mod);
492 
493  } else if (eq(*(x.get_base()), *integer(2))) {
494  args.push_back(apply(*x.get_exp()));
495  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
496  llvm::Intrinsic::exp2, 1, mod);
497 
498  } else {
499  if (is_a<Integer>(*x.get_exp())) {
500  if (eq(*x.get_exp(), *integer(2))) {
501  llvm::Value *tmp = apply(*x.get_base());
502  result_ = builder->CreateFMul(tmp, tmp);
503  return;
504  } else {
505  args.push_back(apply(*x.get_base()));
506  int d = numeric_cast<int>(
507  mp_get_si(static_cast<const Integer &>(*x.get_exp())
508  .as_integer_class()));
509  result_ = llvm::ConstantInt::get(
510  llvm::Type::getInt32Ty(mod->getContext()), d, true);
511  args.push_back(result_);
512  fun = get_powi();
513  }
514  } else {
515  args.push_back(apply(*x.get_base()));
516  args.push_back(apply(*x.get_exp()));
517  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
518  llvm::Intrinsic::pow, 1, mod);
519  }
520  }
521  auto r = builder->CreateCall(fun, args);
522  r->setTailCall(true);
523  result_ = r;
524 }
525 
526 void LLVMVisitor::bvisit(const Sin &x)
527 {
528  std::vector<llvm::Value *> args;
529  llvm::Function *fun;
530  args.push_back(apply(*x.get_arg()));
531  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
532  llvm::Intrinsic::sin, 1, mod);
533  auto r = builder->CreateCall(fun, args);
534  r->setTailCall(true);
535  result_ = r;
536 }
537 
538 void LLVMVisitor::bvisit(const Cos &x)
539 {
540  std::vector<llvm::Value *> args;
541  llvm::Function *fun;
542  args.push_back(apply(*x.get_arg()));
543  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
544  llvm::Intrinsic::cos, 1, mod);
545  auto r = builder->CreateCall(fun, args);
546  r->setTailCall(true);
547  result_ = r;
548 }
549 
550 void LLVMVisitor::bvisit(const Piecewise &x)
551 {
552  std::vector<llvm::BasicBlock> blocks;
553 
554  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
555 
556  if (neq(*pw->get_vec().back().second, *boolTrue)) {
557  throw SymEngineException(
558  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
559  }
560 
561  if (pw->get_vec().size() > 2) {
562  PiecewiseVec rest = pw->get_vec();
563  rest.erase(rest.begin());
564  auto rest_pw = piecewise(std::move(rest));
565  PiecewiseVec new_pw;
566  new_pw.push_back(*pw->get_vec().begin());
567  new_pw.push_back({rest_pw, pw->get_vec().back().second});
568  pw = piecewise(std::move(new_pw))
569  ->rcp_from_this_cast<const Piecewise>();
570  } else if (pw->get_vec().size() < 2) {
571  throw SymEngineException("Invalid Piecewise object");
572  }
573 
574  auto cond_basic = pw->get_vec().front().second;
575  llvm::Value *cond = apply(*cond_basic);
576  // check if cond != 0.0
577  cond = builder->CreateFCmpONE(
578  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
579  "ifcond");
580  llvm::Function *function = builder->GetInsertBlock()->getParent();
581  // Create blocks for the then and else cases. Insert the 'then' block at
582  // the
583  // end of the function.
584  llvm::BasicBlock *then_bb
585  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
586  llvm::BasicBlock *else_bb
587  = llvm::BasicBlock::Create(mod->getContext(), "else");
588  llvm::BasicBlock *merge_bb
589  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
590  builder->CreateCondBr(cond, then_bb, else_bb);
591 
592  // Emit then value.
593  builder->SetInsertPoint(then_bb);
594  llvm::Value *then_value = apply(*pw->get_vec().front().first);
595  builder->CreateBr(merge_bb);
596 
597  // Codegen of 'then_value' can change the current block, update then_bb for
598  // the PHI.
599  then_bb = builder->GetInsertBlock();
600 
601  // Emit else block.
602 #if (LLVM_VERSION_MAJOR < 16)
603  function->getBasicBlockList().push_back(else_bb);
604 #else
605  function->insert(function->end(), else_bb);
606 #endif
607  builder->SetInsertPoint(else_bb);
608  llvm::Value *else_value = apply(*pw->get_vec().back().first);
609  builder->CreateBr(merge_bb);
610 
611  // Codegen of 'else_value' can change the current block, update else_bb for
612  // the PHI.
613  else_bb = builder->GetInsertBlock();
614 
615  // Emit merge block.
616 #if (LLVM_VERSION_MAJOR < 16)
617  function->getBasicBlockList().push_back(merge_bb);
618 #else
619  function->insert(function->end(), merge_bb);
620 #endif
621  builder->SetInsertPoint(merge_bb);
622  llvm::PHINode *phi_node
623  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
624 
625  phi_node->addIncoming(then_value, then_bb);
626  phi_node->addIncoming(else_value, else_bb);
627  result_ = phi_node;
628 }
629 
630 void LLVMVisitor::bvisit(const Sign &x)
631 {
632  const auto x2 = x.get_arg();
633  PiecewiseVec new_pw;
634  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
635  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
636  new_pw.push_back({real_double(1.0), boolTrue});
637  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
638  bvisit(*pw);
639 }
640 
641 void LLVMVisitor::bvisit(const Contains &cts)
642 {
643  llvm::Value *expr = apply(*cts.get_expr());
644  const auto set = cts.get_set();
645  if (is_a<Interval>(*set)) {
646  const auto &interv = down_cast<const Interval &>(*set);
647  llvm::Value *start = apply(*interv.get_start());
648  llvm::Value *end = apply(*interv.get_end());
649  const bool left_open = interv.get_left_open();
650  const bool right_open = interv.get_right_open();
651  llvm::Value *left_ok;
652  llvm::Value *right_ok;
653  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
654  : builder->CreateFCmpOLE(start, expr);
655  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
656  : builder->CreateFCmpOLE(expr, end);
657  result_ = builder->CreateAnd(left_ok, right_ok);
658  result_ = builder->CreateUIToFP(result_,
659  get_float_type(&mod->getContext()));
660  } else {
661  throw SymEngineException("LLVMVisitor: only ``Interval`` "
662  "implemented for ``Contains``.");
663  }
664 }
665 
666 void LLVMVisitor::bvisit(const Infty &x)
667 {
668  if (x.is_negative_infinity()) {
669  result_ = llvm::ConstantFP::getInfinity(
670  get_float_type(&mod->getContext()), true);
671  } else if (x.is_positive_infinity()) {
672  result_ = llvm::ConstantFP::getInfinity(
673  get_float_type(&mod->getContext()), false);
674  } else {
675  throw SymEngineException(
676  "LLVMDouble can only represent real valued infinity");
677  }
678 }
679 
680 void LLVMVisitor::bvisit(const NaN &x)
681 {
682  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
683  /*negative=*/false, /*payload=*/0);
684 }
685 
686 void LLVMVisitor::bvisit(const BooleanAtom &x)
687 {
688  const bool val = x.get_val();
689  set_double(val ? 1.0 : 0.0);
690 }
691 
692 void LLVMVisitor::bvisit(const Log &x)
693 {
694  std::vector<llvm::Value *> args;
695  llvm::Function *fun;
696  args.push_back(apply(*x.get_arg()));
697  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
698  llvm::Intrinsic::log, 1, mod);
699  auto r = builder->CreateCall(fun, args);
700  r->setTailCall(true);
701  result_ = r;
702 }
703 
704 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
705  void LLVMVisitor::bvisit(const Class &x) \
706  { \
707  llvm::Value *value = nullptr; \
708  llvm::Value *tmp; \
709  set_double(0.0); \
710  llvm::Value *zero_val = result_; \
711  for (auto &p : x.get_container()) { \
712  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
713  if (value == nullptr) { \
714  value = tmp; \
715  } else { \
716  value = builder->method(value, tmp); \
717  } \
718  } \
719  result_ = builder->CreateUIToFP(value, \
720  get_float_type(&mod->getContext())); \
721  }
722 
723 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
724 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
725 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
726 
727 void LLVMVisitor::bvisit(const Not &x)
728 {
729  set_double(0.0);
730  llvm::Value *zero_val = result_;
731  llvm::Value *value = builder->CreateFCmpONE(apply(*x.get_arg()), zero_val);
732  result_ = builder->CreateUIToFP(builder->CreateNot(value),
733  get_float_type(&mod->getContext()));
734 }
735 
736 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
737  void LLVMVisitor::bvisit(const Class &x) \
738  { \
739  llvm::Value *left = apply(*x.get_arg1()); \
740  llvm::Value *right = apply(*x.get_arg2()); \
741  result_ = builder->method(left, right); \
742  result_ = builder->CreateUIToFP(result_, \
743  get_float_type(&mod->getContext())); \
744  }
745 
746 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
747 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
748 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
749 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
750 
751 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
752  void LLVMDoubleVisitor::visit(const Class &x) \
753  { \
754  vec_basic basic_args = x.get_args(); \
755  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
756  std::vector<llvm::Value *> args; \
757  for (const auto &arg : basic_args) { \
758  args.push_back(apply(*arg)); \
759  } \
760  auto r = builder->CreateCall(func, args); \
761  r->setTailCall(true); \
762  result_ = r; \
763  } \
764  void LLVMFloatVisitor::visit(const Class &x) \
765  { \
766  vec_basic basic_args = x.get_args(); \
767  llvm::Function *func = get_external_function(#ext + std::string("f"), \
768  basic_args.size()); \
769  std::vector<llvm::Value *> args; \
770  for (const auto &arg : basic_args) { \
771  args.push_back(apply(*arg)); \
772  } \
773  auto r = builder->CreateCall(func, args); \
774  r->setTailCall(true); \
775  result_ = r; \
776  }
777 
778 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
779 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
780  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
781  void LLVMLongDoubleVisitor::visit(const Class &x) \
782  { \
783  vec_basic basic_args = x.get_args(); \
784  llvm::Function *func = get_external_function(#ext + std::string("l"), \
785  basic_args.size()); \
786  std::vector<llvm::Value *> args; \
787  for (const auto &arg : basic_args) { \
788  args.push_back(apply(*arg)); \
789  } \
790  auto r = builder->CreateCall(func, args); \
791  r->setTailCall(true); \
792  result_ = r; \
793  }
794 #else
795 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
796  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
797 #endif
798 
799 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
800 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
801 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
802 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
803 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
804 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
805 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
806 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
807 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
808 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
809 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
810 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
811 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
812 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
813 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
814 
815 void LLVMVisitor::bvisit(const Abs &x)
816 {
817  std::vector<llvm::Value *> args;
818  llvm::Function *fun;
819  args.push_back(apply(*x.get_arg()));
820  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
821  llvm::Intrinsic::fabs, 1, mod);
822  auto r = builder->CreateCall(fun, args);
823  r->setTailCall(true);
824  result_ = r;
825 }
826 
827 void LLVMVisitor::bvisit(const Min &x)
828 {
829  llvm::Value *value = nullptr;
830  llvm::Function *fun;
831  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
832  llvm::Intrinsic::minnum, 1, mod);
833  for (auto &arg : x.get_vec()) {
834  if (value != nullptr) {
835  std::vector<llvm::Value *> args;
836  args.push_back(value);
837  args.push_back(apply(*arg));
838  auto r = builder->CreateCall(fun, args);
839  r->setTailCall(true);
840  value = r;
841  } else {
842  value = apply(*arg);
843  }
844  }
845  result_ = value;
846 }
847 
848 void LLVMVisitor::bvisit(const Max &x)
849 {
850  llvm::Value *value = nullptr;
851  llvm::Function *fun;
852  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
853  llvm::Intrinsic::maxnum, 1, mod);
854  for (auto &arg : x.get_vec()) {
855  if (value != nullptr) {
856  std::vector<llvm::Value *> args;
857  args.push_back(value);
858  args.push_back(apply(*arg));
859  auto r = builder->CreateCall(fun, args);
860  r->setTailCall(true);
861  value = r;
862  } else {
863  value = apply(*arg);
864  }
865  }
866  result_ = value;
867 }
868 
869 void LLVMVisitor::bvisit(const Symbol &x)
870 {
871  unsigned i = 0;
872  for (auto &symb : symbols) {
873  if (eq(x, *symb)) {
874  result_ = symbol_ptrs[i];
875  return;
876  }
877  ++i;
878  }
879  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
880  if (it != replacement_symbol_ptrs.end()) {
881  result_ = it->second;
882  return;
883  }
884 
885  throw SymEngineException("Symbol " + x.__str__()
886  + " not in the symbols vector.");
887 }
888 
889 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
890  size_t nargs)
891 {
892  std::vector<llvm::Type *> func_args(nargs,
893  get_float_type(&mod->getContext()));
894  llvm::FunctionType *func_type = llvm::FunctionType::get(
895  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
896 
897  llvm::Function *func = mod->getFunction(name);
898  if (!func) {
899  func = llvm::Function::Create(
900  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
901  func->setCallingConv(llvm::CallingConv::C);
902  }
903  func->addFnAttr(llvm::Attribute::NoUnwind);
904  return func;
905 }
906 
907 void LLVMVisitor::bvisit(const Constant &x)
908 {
909  set_double(eval_double(x));
910 }
911 
912 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
913 void LLVMLongDoubleVisitor::visit(const Constant &x)
914 {
915  convert_from_mpfr(x);
916 }
917 #endif
918 
919 void LLVMVisitor::bvisit(const Basic &x)
920 {
921  throw NotImplementedError(x.__str__());
922 }
923 
924 const std::string &LLVMVisitor::dumps() const
925 {
926  return membuffer;
927 };
928 
929 void LLVMVisitor::loads(const std::string &s)
930 {
931  membuffer = s;
932  llvm::InitializeNativeTarget();
933  llvm::InitializeNativeTargetAsmPrinter();
934  llvm::InitializeNativeTargetAsmParser();
935  context = make_unique<llvm::LLVMContext>();
936 
937  // Create some module to put our function into it.
938  std::unique_ptr<llvm::Module> module
939  = make_unique<llvm::Module>("SymEngine", *context);
940  module->setDataLayout("");
941  mod = module.get();
942 
943  // Only defining the prototype for the function here.
944  // Since we know where the function is stored that's enough
945  // llvm::ObjectCache is designed for caching objects, but it
946  // is used here for loading one specific object.
947  auto F = get_function_type(context.get());
948 
949  std::string error;
950  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
951  llvm::EngineBuilder(std::move(module))
952  .setEngineKind(llvm::EngineKind::Kind::JIT)
953  .setOptLevel(CodeGenOptLevel::Aggressive)
954  .setErrorStr(&error)
955  .create());
956 
957  // Customization point for subclasses: (may e.g. register custom symbol
958  // resolver)
959  modify_execution_engine(executionengine.get());
960 
961  class MCJITObjectLoader : public llvm::ObjectCache
962  {
963  const std::string &s_;
964 
965  public:
966  MCJITObjectLoader(const std::string &s) : s_(s) {}
967  void notifyObjectCompiled(const llvm::Module *M,
968  llvm::MemoryBufferRef obj) override
969  {
970  }
971 
972  // No need to check M because there is only one function
973  // Return it after reading from the file.
974  std::unique_ptr<llvm::MemoryBuffer>
975  getObject(const llvm::Module *M) override
976  {
977  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
978  }
979  };
980 
981  MCJITObjectLoader loader(s);
982  executionengine->setObjectCache(&loader);
983  executionengine->finalizeObject();
984  // Set func to compiled function pointer
985  func = (intptr_t)executionengine->getPointerToFunction(F);
986 }
987 
988 void LLVMVisitor::bvisit(const Floor &x)
989 {
990  std::vector<llvm::Value *> args;
991  llvm::Function *fun;
992  args.push_back(apply(*x.get_arg()));
993  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
994  llvm::Intrinsic::floor, 1, mod);
995  auto r = builder->CreateCall(fun, args);
996  r->setTailCall(true);
997  result_ = r;
998 }
999 
1000 void LLVMVisitor::bvisit(const Ceiling &x)
1001 {
1002  std::vector<llvm::Value *> args;
1003  llvm::Function *fun;
1004  args.push_back(apply(*x.get_arg()));
1005  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1006  llvm::Intrinsic::ceil, 1, mod);
1007  auto r = builder->CreateCall(fun, args);
1008  r->setTailCall(true);
1009  result_ = r;
1010 }
1011 
1012 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1013 {
1014  apply(*x.get_arg());
1015 }
1016 
1017 void LLVMVisitor::bvisit(const Truncate &x)
1018 {
1019  std::vector<llvm::Value *> args;
1020  llvm::Function *fun;
1021  args.push_back(apply(*x.get_arg()));
1022  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1023  llvm::Intrinsic::trunc, 1, mod);
1024  auto r = builder->CreateCall(fun, args);
1025  r->setTailCall(true);
1026  result_ = r;
1027 }
1028 
1029 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1030 {
1031  return llvm::Type::getDoubleTy(*context);
1032 }
1033 
1034 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1035 {
1036  return llvm::Type::getFloatTy(*context);
1037 }
1038 
1039 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1040 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1041 {
1042  return llvm::Type::getX86_FP80Ty(*context);
1043 }
1044 #endif
1045 
1046 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127