llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 #if (LLVM_VERSION_MAJOR < 18)
50 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
51 #else
52 using CodeGenOptLevel = llvm::CodeGenOptLevel;
53 #endif
54 
55 #if (LLVM_VERSION_MAJOR < 20)
56 const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
57 #else
58 const auto &GetDeclaration = llvm::Intrinsic::getOrInsertDeclaration;
59 #endif
60 
61 #if (LLVM_VERSION_MAJOR >= 21)
62 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
63 #else
64 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
65 #endif
66 
67 class IRBuilder : public llvm::IRBuilder<>
68 {
69 };
70 
71 LLVMVisitor::LLVMVisitor() = default;
72 LLVMVisitor::~LLVMVisitor() = default;
73 
74 llvm::Value *LLVMVisitor::apply(const Basic &b)
75 {
76  b.accept(*this);
77  return result_;
78 }
79 
80 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
81  unsigned opt_level)
82 {
83  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
84 }
85 
86 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
87 {
88  std::vector<llvm::Type *> inp;
89  for (int i = 0; i < 2; i++) {
90  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
91  }
92  llvm::FunctionType *function_type = llvm::FunctionType::get(
93  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
94  auto F = llvm::Function::Create(
95  function_type, llvm::Function::InternalLinkage, "symengine_func", mod);
96  F->setCallingConv(llvm::CallingConv::C);
97 #if (LLVM_VERSION_MAJOR < 5)
98  {
99  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
100  llvm::AttributeSet PAS;
101  {
102  llvm::AttrBuilder B;
103  B.addAttribute(llvm::Attribute::ReadOnly);
104  AddNoCapture(B);
105  PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
106  }
107 
108  attrs.push_back(PAS);
109  {
110  llvm::AttrBuilder B;
111  AddNoCapture(B);
112  PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
113  }
114 
115  attrs.push_back(PAS);
116  {
117  llvm::AttrBuilder B;
118  B.addAttribute(llvm::Attribute::NoUnwind);
119  B.addAttribute(llvm::Attribute::UWTable);
120  PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
121  }
122 
123  attrs.push_back(PAS);
124 
125  F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
126  }
127 #else
128  F->addParamAttr(0, llvm::Attribute::ReadOnly);
129 #if (LLVM_VERSION_MAJOR >= 21)
130  F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
131  *context, llvm::CaptureInfo::none()));
132  F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
133  *context, llvm::CaptureInfo::none()));
134 #else
135  F->addParamAttr(0, llvm::Attribute::NoCapture);
136  F->addParamAttr(1, llvm::Attribute::NoCapture);
137 #endif
138  F->addFnAttr(llvm::Attribute::NoUnwind);
139 #if (LLVM_VERSION_MAJOR < 15)
140  F->addFnAttr(llvm::Attribute::UWTable);
141 #else
142  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
143  *context, llvm::UWTableKind::Default));
144 #endif
145 #endif
146  return F;
147 }
148 
149 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
150  const bool symbolic_cse, unsigned opt_level)
151 {
152  executionengine.reset();
153  llvm::InitializeNativeTarget();
154  llvm::InitializeNativeTargetAsmPrinter();
155  llvm::InitializeNativeTargetAsmParser();
156  context = make_unique<llvm::LLVMContext>();
157  symbols = inputs;
158 
159  // Create some module to put our function into it.
160  std::unique_ptr<llvm::Module> module
161  = make_unique<llvm::Module>("SymEngine", *context.get());
162  module->setDataLayout("");
163  mod = module.get();
164 
165  auto F = get_function_type(context.get());
166 
167  // Add a basic block to the function. As before, it automatically
168  // inserts
169  // because of the last argument.
170  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
171 
172  // Create a basic block builder with default parameters. The builder
173  // will
174  // automatically append instructions to the basic block `BB'.
175  llvm::IRBuilder<> _builder(BB);
176  builder = reinterpret_cast<IRBuilder *>(&_builder);
177  builder->SetInsertPoint(BB);
178  auto fmf = llvm::FastMathFlags();
179  // fmf.setUnsafeAlgebra();
180  builder->setFastMathFlags(fmf);
181 
182  // Load all the symbols and create references
183  auto input_arg = &(*(F->args().begin()));
184  for (unsigned i = 0; i < inputs.size(); i++) {
185  if (not is_a<Symbol>(*inputs[i])) {
186  throw SymEngineException("Input contains a non-symbol.");
187  }
188  auto index
189  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
190  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
191  index);
192  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
193  symbol_ptrs.push_back(result_);
194  }
195 
196  auto it = F->args().begin();
197 #if (LLVM_VERSION_MAJOR < 5)
198  auto out = &(*(++it));
199 #else
200  auto out = &(*(it + 1));
201 #endif
202  std::vector<llvm::Value *> output_vals;
203 
204  if (symbolic_cse) {
205  vec_basic reduced_exprs;
206  vec_pair replacements;
207  // cse the outputs
208  SymEngine::cse(replacements, reduced_exprs, outputs);
209  for (auto &rep : replacements) {
210  // Store the replacement symbol values in a dictionary
211  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
212  }
213  // Generate IR for all the reduced exprs and save references
214  for (unsigned i = 0; i < outputs.size(); i++) {
215  output_vals.push_back(apply(*reduced_exprs[i]));
216  }
217  } else {
218  // Generate IR for all the output exprs and save references
219  for (unsigned i = 0; i < outputs.size(); i++) {
220  output_vals.push_back(apply(*outputs[i]));
221  }
222  }
223 
224  // Store all the output exprs at the end
225  for (unsigned i = 0; i < outputs.size(); i++) {
226  auto index
227  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
228  auto ptr
229  = builder->CreateGEP(get_float_type(context.get()), out, index);
230  builder->CreateStore(output_vals[i], ptr);
231  }
232 
233  // Create the return instruction and add it to the basic block
234  builder->CreateRetVoid();
235 
236  // Validate the generated code, checking for consistency.
237  llvm::verifyFunction(*F, &llvm::outs());
238 
239  // std::cout << "LLVM IR" << std::endl;
240  // #if (LLVM_VERSION_MAJOR < 5)
241  // module->dump();
242  // #else
243  // module->print(llvm::errs(), nullptr);
244  // #endif
245 
246  // module->print(llvm::errs(), nullptr);
247 
248  // Optimize the function using default passes from PassBuilder
249  // FunctionSimplificationPipeline for the opt_level
250 #if (LLVM_VERSION_MAJOR < 14)
251  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
252 #else
253  using OptimizationLevel = llvm::OptimizationLevel;
254 #endif
255  llvm::PassBuilder PB;
256  llvm::ModuleAnalysisManager MAM;
257  llvm::CGSCCAnalysisManager CGAM;
258  llvm::FunctionAnalysisManager FAM;
259  llvm::LoopAnalysisManager LAM;
260  PB.registerModuleAnalyses(MAM);
261  PB.registerCGSCCAnalyses(CGAM);
262  PB.registerFunctionAnalyses(FAM);
263  PB.registerLoopAnalyses(LAM);
264  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
265  llvm::FunctionPassManager FPM;
266  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
267  if (opt_level == 0) {
268  pb_opt_level = OptimizationLevel::O0;
269  } else if (opt_level == 1) {
270  pb_opt_level = OptimizationLevel::O1;
271  } else if (opt_level == 2) {
272  pb_opt_level = OptimizationLevel::O2;
273  }
274 
275  if (opt_level != 0) {
276 #if (LLVM_VERSION_MAJOR < 6)
277  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
278 #elif (LLVM_VERSION_MAJOR < 12)
279  FPM = PB.buildFunctionSimplificationPipeline(
280  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
281 #else
282  FPM = PB.buildFunctionSimplificationPipeline(
283  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
284 #endif
285  FPM.run(*F, FAM);
286  }
287 
288  // std::cout << "Optimized LLVM IR" << std::endl;
289  // module->print(llvm::errs(), nullptr);
290 
291  // Now we create the JIT.
292  std::string error;
293  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
294  llvm::EngineBuilder(std::move(module))
295  .setEngineKind(llvm::EngineKind::Kind::JIT)
296  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
297  .setErrorStr(&error)
298  .create());
299 
300  // This is a hack to get the MemoryBuffer of a compiled object.
301  class MemoryBufferRefCallback : public llvm::ObjectCache
302  {
303  public:
304  std::string &ss_;
305  explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
306 
307  void notifyObjectCompiled(const llvm::Module *M,
308  llvm::MemoryBufferRef obj) override
309  {
310  const char *c = obj.getBufferStart();
311  // Saving the object code in a std::string
312  ss_.assign(c, obj.getBufferSize());
313  }
314 
315  std::unique_ptr<llvm::MemoryBuffer>
316  getObject(const llvm::Module *M) override
317  {
318  return nullptr;
319  }
320  };
321 
322  MemoryBufferRefCallback callback(membuffer);
323  executionengine->setObjectCache(&callback);
324  // std::cout << error << std::endl;
325  executionengine->finalizeObject();
326 
327  // Get the symbol's address
328  func = (intptr_t)executionengine->getPointerToFunction(F);
329  symbol_ptrs.clear();
330  replacement_symbol_ptrs.clear();
331  symbols.clear();
332 }
333 
334 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
335 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
336 
337 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
338 {
339  double ret;
340  ((double (*)(const double *, double *))func)(vec.data(), &ret);
341  return ret;
342 }
343 
344 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
345 {
346  ((double (*)(const double *, double *))func)(inps, outs);
347 }
348 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Single-output convenience overload: `vec` holds one value per input symbol
// (in init() order), and the single output is returned. NOTE: the compiled
// code writes one value per output expression, so this assumes init() was
// given exactly one output.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    // The entry point returns void (see get_function_type). Calling it through
    // a pointer declared to return long double is undefined behaviour; on x86
    // the caller may also pop an x87 return value that was never pushed.
    using entry_point_t = void (*)(const long double *, long double *);
    long double ret;
    reinterpret_cast<entry_point_t>(func)(vec.data(), &ret);
    return ret;
}

// Reads one value per input symbol from `inps` and writes one value per
// output expression to `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using entry_point_t = void (*)(const long double *, long double *);
    reinterpret_cast<entry_point_t>(func)(inps, outs);
}
#endif
365 
366 LLVMFloatVisitor::LLVMFloatVisitor() = default;
367 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
368 
369 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
370 {
371  float ret;
372  ((float (*)(const float *, float *))func)(vec.data(), &ret);
373  return ret;
374 }
375 
376 void LLVMFloatVisitor::call(float *outs, const float *inps) const
377 {
378  ((float (*)(const float *, float *))func)(inps, outs);
379 }
380 
381 void LLVMVisitor::set_double(double d)
382 {
383  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
384 }
385 
386 void LLVMVisitor::bvisit(const Integer &x)
387 {
388  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
389  mp_get_d(x.as_integer_class()));
390 }
391 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE

LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;

// Materializes `x` as a constant of the long double type. It evaluates `x`
// with 128 bits of MPFR precision and lets LLVM parse the decimal string, so
// the constant is not limited to double precision. Without MPFR this raises
// NotImplementedError.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> high_precision = evalf(x, 128, EvalfDomain::Real);
    const std::string digits = high_precision->__str__();
    result_
        = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are parsed by LLVM from their exact decimal representation.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    const std::string digits = x.__str__();
    result_
        = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
}
#endif
414 
415 void LLVMVisitor::bvisit(const Rational &x)
416 {
417  set_double(mp_get_d(x.as_rational_class()));
418 }
419 
420 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
421 void LLVMLongDoubleVisitor::visit(const Rational &x)
422 {
423  convert_from_mpfr(x);
424 }
425 #endif
426 
427 void LLVMVisitor::bvisit(const RealDouble &x)
428 {
429  set_double(x.i);
430 }
431 
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals are rounded to the nearest double.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double nearest = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(nearest);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long double visitor goes through convert_from_mpfr instead of double.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    convert_from_mpfr(x);
}
#endif
#endif
444 
445 void LLVMVisitor::bvisit(const Add &x)
446 {
447  llvm::Value *tmp, *tmp1, *tmp2;
448  auto it = x.get_dict().begin();
449 
450  if (eq(*x.get_coef(), *zero)) {
451  // `x + 0.0` is not optimized out
452  if (eq(*one, *(it->second))) {
453  tmp = apply(*(it->first));
454  } else {
455  tmp1 = apply(*(it->first));
456  tmp2 = apply(*(it->second));
457  tmp = builder->CreateFMul(tmp1, tmp2);
458  }
459  ++it;
460  } else {
461  tmp = apply(*x.get_coef());
462  }
463 
464  for (; it != x.get_dict().end(); ++it) {
465  if (eq(*one, *(it->second))) {
466  tmp1 = apply(*(it->first));
467  tmp = builder->CreateFAdd(tmp, tmp1);
468  } else {
469  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
470  // tmp =
471  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
472  // llvm::Intrinsic::fma,
473  // 3, context), args);
474  tmp1 = apply(*(it->first));
475  tmp2 = apply(*(it->second));
476  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
477  }
478  }
479  result_ = tmp;
480 }
481 
482 void LLVMVisitor::bvisit(const Mul &x)
483 {
484  llvm::Value *tmp = nullptr;
485  bool first = true;
486  for (const auto &p : x.get_args()) {
487  if (first) {
488  tmp = apply(*p);
489  } else {
490  tmp = builder->CreateFMul(tmp, apply(*p));
491  }
492  first = false;
493  }
494  result_ = tmp;
495 }
496 
497 llvm::Function *LLVMVisitor::get_powi()
498 {
499  std::vector<llvm::Type *> arg_type;
500  arg_type.push_back(get_float_type(&mod->getContext()));
501 #if (LLVM_VERSION_MAJOR > 12)
502  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
503 #endif
504  return GetDeclaration(mod, llvm::Intrinsic::powi, arg_type);
505 }
506 
507 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
508  unsigned n, llvm::Module *mod)
509 {
510  std::vector<llvm::Type *> arg_type(n, type);
511  return GetDeclaration(mod, id, arg_type);
512 }
513 
514 void LLVMVisitor::bvisit(const Pow &x)
515 {
516  std::vector<llvm::Value *> args;
517  llvm::Function *fun;
518  if (eq(*(x.get_base()), *E)) {
519  args.push_back(apply(*x.get_exp()));
520  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
521  llvm::Intrinsic::exp, 1, mod);
522 
523  } else if (eq(*(x.get_base()), *integer(2))) {
524  args.push_back(apply(*x.get_exp()));
525  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
526  llvm::Intrinsic::exp2, 1, mod);
527 
528  } else {
529  if (is_a<Integer>(*x.get_exp())) {
530  if (eq(*x.get_exp(), *integer(2))) {
531  llvm::Value *tmp = apply(*x.get_base());
532  result_ = builder->CreateFMul(tmp, tmp);
533  return;
534  } else {
535  args.push_back(apply(*x.get_base()));
536  int d = numeric_cast<int>(
537  mp_get_si(static_cast<const Integer &>(*x.get_exp())
538  .as_integer_class()));
539  result_ = llvm::ConstantInt::get(
540  llvm::Type::getInt32Ty(mod->getContext()), d, true);
541  args.push_back(result_);
542  fun = get_powi();
543  }
544  } else {
545  args.push_back(apply(*x.get_base()));
546  args.push_back(apply(*x.get_exp()));
547  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
548  llvm::Intrinsic::pow, 1, mod);
549  }
550  }
551  auto r = builder->CreateCall(fun, args);
552  r->setTailCall(true);
553  result_ = r;
554 }
555 
556 void LLVMVisitor::bvisit(const Sin &x)
557 {
558  std::vector<llvm::Value *> args;
559  llvm::Function *fun;
560  args.push_back(apply(*x.get_arg()));
561  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
562  llvm::Intrinsic::sin, 1, mod);
563  auto r = builder->CreateCall(fun, args);
564  r->setTailCall(true);
565  result_ = r;
566 }
567 
568 void LLVMVisitor::bvisit(const Cos &x)
569 {
570  std::vector<llvm::Value *> args;
571  llvm::Function *fun;
572  args.push_back(apply(*x.get_arg()));
573  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
574  llvm::Intrinsic::cos, 1, mod);
575  auto r = builder->CreateCall(fun, args);
576  r->setTailCall(true);
577  result_ = r;
578 }
579 
580 void LLVMVisitor::bvisit(const Piecewise &x)
581 {
582  std::vector<llvm::BasicBlock> blocks;
583 
584  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
585 
586  if (neq(*pw->get_vec().back().second, *boolTrue)) {
587  throw SymEngineException(
588  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
589  }
590 
591  if (pw->get_vec().size() > 2) {
592  PiecewiseVec rest = pw->get_vec();
593  rest.erase(rest.begin());
594  auto rest_pw = piecewise(std::move(rest));
595  PiecewiseVec new_pw;
596  new_pw.push_back(*pw->get_vec().begin());
597  new_pw.push_back({rest_pw, pw->get_vec().back().second});
598  pw = piecewise(std::move(new_pw))
599  ->rcp_from_this_cast<const Piecewise>();
600  } else if (pw->get_vec().size() < 2) {
601  throw SymEngineException("Invalid Piecewise object");
602  }
603 
604  auto cond_basic = pw->get_vec().front().second;
605  llvm::Value *cond = apply(*cond_basic);
606  // check if cond != 0.0
607  cond = builder->CreateFCmpONE(
608  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
609  "ifcond");
610  llvm::Function *function = builder->GetInsertBlock()->getParent();
611  // Create blocks for the then and else cases. Insert the 'then' block at
612  // the
613  // end of the function.
614  llvm::BasicBlock *then_bb
615  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
616  llvm::BasicBlock *else_bb
617  = llvm::BasicBlock::Create(mod->getContext(), "else");
618  llvm::BasicBlock *merge_bb
619  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
620  builder->CreateCondBr(cond, then_bb, else_bb);
621 
622  // Emit then value.
623  builder->SetInsertPoint(then_bb);
624  llvm::Value *then_value = apply(*pw->get_vec().front().first);
625  builder->CreateBr(merge_bb);
626 
627  // Codegen of 'then_value' can change the current block, update then_bb for
628  // the PHI.
629  then_bb = builder->GetInsertBlock();
630 
631  // Emit else block.
632 #if (LLVM_VERSION_MAJOR < 16)
633  function->getBasicBlockList().push_back(else_bb);
634 #else
635  function->insert(function->end(), else_bb);
636 #endif
637  builder->SetInsertPoint(else_bb);
638  llvm::Value *else_value = apply(*pw->get_vec().back().first);
639  builder->CreateBr(merge_bb);
640 
641  // Codegen of 'else_value' can change the current block, update else_bb for
642  // the PHI.
643  else_bb = builder->GetInsertBlock();
644 
645  // Emit merge block.
646 #if (LLVM_VERSION_MAJOR < 16)
647  function->getBasicBlockList().push_back(merge_bb);
648 #else
649  function->insert(function->end(), merge_bb);
650 #endif
651  builder->SetInsertPoint(merge_bb);
652  llvm::PHINode *phi_node
653  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
654 
655  phi_node->addIncoming(then_value, then_bb);
656  phi_node->addIncoming(else_value, else_bb);
657  result_ = phi_node;
658 }
659 
660 void LLVMVisitor::bvisit(const Sign &x)
661 {
662  const auto x2 = x.get_arg();
663  PiecewiseVec new_pw;
664  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
665  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
666  new_pw.push_back({real_double(1.0), boolTrue});
667  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
668  bvisit(*pw);
669 }
670 
671 void LLVMVisitor::bvisit(const Contains &cts)
672 {
673  llvm::Value *expr = apply(*cts.get_expr());
674  const auto set = cts.get_set();
675  if (is_a<Interval>(*set)) {
676  const auto &interv = down_cast<const Interval &>(*set);
677  llvm::Value *start = apply(*interv.get_start());
678  llvm::Value *end = apply(*interv.get_end());
679  const bool left_open = interv.get_left_open();
680  const bool right_open = interv.get_right_open();
681  llvm::Value *left_ok;
682  llvm::Value *right_ok;
683  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
684  : builder->CreateFCmpOLE(start, expr);
685  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
686  : builder->CreateFCmpOLE(expr, end);
687  result_ = builder->CreateAnd(left_ok, right_ok);
688  result_ = builder->CreateUIToFP(result_,
689  get_float_type(&mod->getContext()));
690  } else {
691  throw SymEngineException("LLVMVisitor: only ``Interval`` "
692  "implemented for ``Contains``.");
693  }
694 }
695 
696 void LLVMVisitor::bvisit(const Infty &x)
697 {
698  if (x.is_negative_infinity()) {
699  result_ = llvm::ConstantFP::getInfinity(
700  get_float_type(&mod->getContext()), true);
701  } else if (x.is_positive_infinity()) {
702  result_ = llvm::ConstantFP::getInfinity(
703  get_float_type(&mod->getContext()), false);
704  } else {
705  throw SymEngineException(
706  "LLVMDouble can only represent real valued infinity");
707  }
708 }
709 
710 void LLVMVisitor::bvisit(const NaN &x)
711 {
712  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
713  /*negative=*/false, /*payload=*/0);
714 }
715 
716 void LLVMVisitor::bvisit(const BooleanAtom &x)
717 {
718  const bool val = x.get_val();
719  set_double(val ? 1.0 : 0.0);
720 }
721 
722 void LLVMVisitor::bvisit(const Log &x)
723 {
724  std::vector<llvm::Value *> args;
725  llvm::Function *fun;
726  args.push_back(apply(*x.get_arg()));
727  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
728  llvm::Intrinsic::log, 1, mod);
729  auto r = builder->CreateCall(fun, args);
730  r->setTailCall(true);
731  result_ = r;
732 }
733 
734 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
735  void LLVMVisitor::bvisit(const Class &x) \
736  { \
737  llvm::Value *value = nullptr; \
738  llvm::Value *tmp; \
739  set_double(0.0); \
740  llvm::Value *zero_val = result_; \
741  for (auto &p : x.get_container()) { \
742  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
743  if (value == nullptr) { \
744  value = tmp; \
745  } else { \
746  value = builder->method(value, tmp); \
747  } \
748  } \
749  result_ = builder->CreateUIToFP(value, \
750  get_float_type(&mod->getContext())); \
751  }
752 
753 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
754 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
755 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
756 
757 void LLVMVisitor::bvisit(const Not &x)
758 {
759  builder->CreateNot(apply(*x.get_arg()));
760 }
761 
762 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
763  void LLVMVisitor::bvisit(const Class &x) \
764  { \
765  llvm::Value *left = apply(*x.get_arg1()); \
766  llvm::Value *right = apply(*x.get_arg2()); \
767  result_ = builder->method(left, right); \
768  result_ = builder->CreateUIToFP(result_, \
769  get_float_type(&mod->getContext())); \
770  }
771 
772 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
773 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
774 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
775 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
776 
777 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
778  void LLVMDoubleVisitor::visit(const Class &x) \
779  { \
780  vec_basic basic_args = x.get_args(); \
781  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
782  std::vector<llvm::Value *> args; \
783  for (const auto &arg : basic_args) { \
784  args.push_back(apply(*arg)); \
785  } \
786  auto r = builder->CreateCall(func, args); \
787  r->setTailCall(true); \
788  result_ = r; \
789  } \
790  void LLVMFloatVisitor::visit(const Class &x) \
791  { \
792  vec_basic basic_args = x.get_args(); \
793  llvm::Function *func = get_external_function(#ext + std::string("f"), \
794  basic_args.size()); \
795  std::vector<llvm::Value *> args; \
796  for (const auto &arg : basic_args) { \
797  args.push_back(apply(*arg)); \
798  } \
799  auto r = builder->CreateCall(func, args); \
800  r->setTailCall(true); \
801  result_ = r; \
802  }
803 
804 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
805 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
806  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
807  void LLVMLongDoubleVisitor::visit(const Class &x) \
808  { \
809  vec_basic basic_args = x.get_args(); \
810  llvm::Function *func = get_external_function(#ext + std::string("l"), \
811  basic_args.size()); \
812  std::vector<llvm::Value *> args; \
813  for (const auto &arg : basic_args) { \
814  args.push_back(apply(*arg)); \
815  } \
816  auto r = builder->CreateCall(func, args); \
817  r->setTailCall(true); \
818  result_ = r; \
819  }
820 #else
821 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
822  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
823 #endif
824 
825 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
826 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
827 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
828 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
829 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
830 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
831 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
832 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
833 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
834 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
835 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
836 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
837 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
838 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
839 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
840 
841 void LLVMVisitor::bvisit(const Abs &x)
842 {
843  std::vector<llvm::Value *> args;
844  llvm::Function *fun;
845  args.push_back(apply(*x.get_arg()));
846  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
847  llvm::Intrinsic::fabs, 1, mod);
848  auto r = builder->CreateCall(fun, args);
849  r->setTailCall(true);
850  result_ = r;
851 }
852 
853 void LLVMVisitor::bvisit(const Min &x)
854 {
855  llvm::Value *value = nullptr;
856  llvm::Function *fun;
857  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
858  llvm::Intrinsic::minnum, 1, mod);
859  for (auto &arg : x.get_vec()) {
860  if (value != nullptr) {
861  std::vector<llvm::Value *> args;
862  args.push_back(value);
863  args.push_back(apply(*arg));
864  auto r = builder->CreateCall(fun, args);
865  r->setTailCall(true);
866  value = r;
867  } else {
868  value = apply(*arg);
869  }
870  }
871  result_ = value;
872 }
873 
874 void LLVMVisitor::bvisit(const Max &x)
875 {
876  llvm::Value *value = nullptr;
877  llvm::Function *fun;
878  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
879  llvm::Intrinsic::maxnum, 1, mod);
880  for (auto &arg : x.get_vec()) {
881  if (value != nullptr) {
882  std::vector<llvm::Value *> args;
883  args.push_back(value);
884  args.push_back(apply(*arg));
885  auto r = builder->CreateCall(fun, args);
886  r->setTailCall(true);
887  value = r;
888  } else {
889  value = apply(*arg);
890  }
891  }
892  result_ = value;
893 }
894 
895 void LLVMVisitor::bvisit(const Symbol &x)
896 {
897  unsigned i = 0;
898  for (auto &symb : symbols) {
899  if (eq(x, *symb)) {
900  result_ = symbol_ptrs[i];
901  return;
902  }
903  ++i;
904  }
905  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
906  if (it != replacement_symbol_ptrs.end()) {
907  result_ = it->second;
908  return;
909  }
910 
911  throw SymEngineException("Symbol " + x.__str__()
912  + " not in the symbols vector.");
913 }
914 
915 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
916  size_t nargs)
917 {
918  std::vector<llvm::Type *> func_args(nargs,
919  get_float_type(&mod->getContext()));
920  llvm::FunctionType *func_type = llvm::FunctionType::get(
921  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
922 
923  llvm::Function *func = mod->getFunction(name);
924  if (!func) {
925  func = llvm::Function::Create(
926  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
927  func->setCallingConv(llvm::CallingConv::C);
928  }
929 #if (LLVM_VERSION_MAJOR < 5)
930  llvm::AttributeSet func_attr_set;
931  {
932  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
933  llvm::AttributeSet attr_set;
934  {
935  llvm::AttrBuilder attr_builder;
936  attr_builder.addAttribute(llvm::Attribute::NoUnwind);
937  attr_set
938  = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
939  }
940 
941  attrs.push_back(attr_set);
942  func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
943  }
944  func->setAttributes(func_attr_set);
945 #else
946  func->addFnAttr(llvm::Attribute::NoUnwind);
947 #endif
948  return func;
949 }
950 
951 void LLVMVisitor::bvisit(const Constant &x)
952 {
953  set_double(eval_double(x));
954 }
955 
956 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
957 void LLVMLongDoubleVisitor::visit(const Constant &x)
958 {
959  convert_from_mpfr(x);
960 }
961 #endif
962 
963 void LLVMVisitor::bvisit(const Basic &x)
964 {
965  throw NotImplementedError(x.__str__());
966 }
967 
968 const std::string &LLVMVisitor::dumps() const
969 {
970  return membuffer;
971 };
972 
973 void LLVMVisitor::loads(const std::string &s)
974 {
975  membuffer = s;
976  llvm::InitializeNativeTarget();
977  llvm::InitializeNativeTargetAsmPrinter();
978  llvm::InitializeNativeTargetAsmParser();
979  context = make_unique<llvm::LLVMContext>();
980 
981  // Create some module to put our function into it.
982  std::unique_ptr<llvm::Module> module
983  = make_unique<llvm::Module>("SymEngine", *context);
984  module->setDataLayout("");
985  mod = module.get();
986 
987  // Only defining the prototype for the function here.
988  // Since we know where the function is stored that's enough
989  // llvm::ObjectCache is designed for caching objects, but it
990  // is used here for loading one specific object.
991  auto F = get_function_type(context.get());
992 
993  std::string error;
994  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
995  llvm::EngineBuilder(std::move(module))
996  .setEngineKind(llvm::EngineKind::Kind::JIT)
997  .setOptLevel(CodeGenOptLevel::Aggressive)
998  .setErrorStr(&error)
999  .create());
1000 
1001  class MCJITObjectLoader : public llvm::ObjectCache
1002  {
1003  const std::string &s_;
1004 
1005  public:
1006  MCJITObjectLoader(const std::string &s) : s_(s) {}
1007  void notifyObjectCompiled(const llvm::Module *M,
1008  llvm::MemoryBufferRef obj) override
1009  {
1010  }
1011 
1012  // No need to check M because there is only one function
1013  // Return it after reading from the file.
1014  std::unique_ptr<llvm::MemoryBuffer>
1015  getObject(const llvm::Module *M) override
1016  {
1017  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
1018  }
1019  };
1020 
1021  MCJITObjectLoader loader(s);
1022  executionengine->setObjectCache(&loader);
1023  executionengine->finalizeObject();
1024  // Set func to compiled function pointer
1025  func = (intptr_t)executionengine->getPointerToFunction(F);
1026 }
1027 
1028 void LLVMVisitor::bvisit(const Floor &x)
1029 {
1030  std::vector<llvm::Value *> args;
1031  llvm::Function *fun;
1032  args.push_back(apply(*x.get_arg()));
1033  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1034  llvm::Intrinsic::floor, 1, mod);
1035  auto r = builder->CreateCall(fun, args);
1036  r->setTailCall(true);
1037  result_ = r;
1038 }
1039 
1040 void LLVMVisitor::bvisit(const Ceiling &x)
1041 {
1042  std::vector<llvm::Value *> args;
1043  llvm::Function *fun;
1044  args.push_back(apply(*x.get_arg()));
1045  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1046  llvm::Intrinsic::ceil, 1, mod);
1047  auto r = builder->CreateCall(fun, args);
1048  r->setTailCall(true);
1049  result_ = r;
1050 }
1051 
1052 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1053 {
1054  apply(*x.get_arg());
1055 }
1056 
1057 void LLVMVisitor::bvisit(const Truncate &x)
1058 {
1059  std::vector<llvm::Value *> args;
1060  llvm::Function *fun;
1061  args.push_back(apply(*x.get_arg()));
1062  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1063  llvm::Intrinsic::trunc, 1, mod);
1064  auto r = builder->CreateCall(fun, args);
1065  r->setTailCall(true);
1066  result_ = r;
1067 }
1068 
1069 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1070 {
1071  return llvm::Type::getDoubleTy(*context);
1072 }
1073 
1074 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1075 {
1076  return llvm::Type::getFloatTy(*context);
1077 }
1078 
1079 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1080 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1081 {
1082  return llvm::Type::getX86_FP80Ty(*context);
1083 }
1084 #endif
1085 
1086 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127