llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 
50 #if (LLVM_VERSION_MAJOR >= 18)
51 using CodeGenOptLevel = llvm::CodeGenOptLevel;
52 #else
53 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
54 #endif
55 
56 #if (LLVM_VERSION_MAJOR >= 20)
57 auto GetDeclaration = [](llvm::Module *M, llvm::Intrinsic::ID id,
58  llvm::ArrayRef<llvm::Type *> Tys) {
59  return llvm::Intrinsic::getOrInsertDeclaration(M, id, Tys);
60 };
61 #else
62 const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
63 #endif
64 
65 #if (LLVM_VERSION_MAJOR >= 21)
66 #define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
67 #else
68 #define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
69 #endif
70 
71 class IRBuilder : public llvm::IRBuilder<>
72 {
73 };
74 
75 LLVMVisitor::LLVMVisitor() = default;
76 LLVMVisitor::~LLVMVisitor() = default;
77 
78 llvm::Value *LLVMVisitor::apply(const Basic &b)
79 {
80  b.accept(*this);
81  return result_;
82 }
83 
84 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
85  unsigned opt_level)
86 {
87  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
88 }
89 
90 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
91 {
92  std::vector<llvm::Type *> inp;
93  for (int i = 0; i < 2; i++) {
94 #if (LLVM_VERSION_MAJOR >= 22)
95  inp.push_back(llvm::PointerType::get(*context, 0));
96 #else
97  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
98 #endif
99  }
100  llvm::FunctionType *function_type = llvm::FunctionType::get(
101  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
102  auto F = llvm::Function::Create(
103  function_type, llvm::Function::InternalLinkage, "symengine_func", mod);
104  F->setCallingConv(llvm::CallingConv::C);
105  F->addParamAttr(0, llvm::Attribute::ReadOnly);
106 #if (LLVM_VERSION_MAJOR >= 21)
107  F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
108  *context, llvm::CaptureInfo::none()));
109  F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
110  *context, llvm::CaptureInfo::none()));
111 #else
112  F->addParamAttr(0, llvm::Attribute::NoCapture);
113  F->addParamAttr(1, llvm::Attribute::NoCapture);
114 #endif
115  F->addFnAttr(llvm::Attribute::NoUnwind);
116 #if (LLVM_VERSION_MAJOR < 15)
117  F->addFnAttr(llvm::Attribute::UWTable);
118 #else
119  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
120  *context, llvm::UWTableKind::Default));
121 #endif
122  return F;
123 }
124 
125 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
126  const bool symbolic_cse, unsigned opt_level)
127 {
128  executionengine.reset();
129  llvm::InitializeNativeTarget();
130  llvm::InitializeNativeTargetAsmPrinter();
131  llvm::InitializeNativeTargetAsmParser();
132  context = make_unique<llvm::LLVMContext>();
133  symbols = inputs;
134 
135  // Create some module to put our function into it.
136  std::unique_ptr<llvm::Module> module
137  = make_unique<llvm::Module>("SymEngine", *context.get());
138  module->setDataLayout("");
139  mod = module.get();
140 
141  auto F = get_function_type(context.get());
142 
143  // Add a basic block to the function. As before, it automatically
144  // inserts
145  // because of the last argument.
146  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
147 
148  // Create a basic block builder with default parameters. The builder
149  // will
150  // automatically append instructions to the basic block `BB'.
151  llvm::IRBuilder<> _builder(BB);
152  builder = reinterpret_cast<IRBuilder *>(&_builder);
153  builder->SetInsertPoint(BB);
154  auto fmf = llvm::FastMathFlags();
155  // fmf.setUnsafeAlgebra();
156  builder->setFastMathFlags(fmf);
157 
158  // Load all the symbols and create references
159  auto input_arg = &(*(F->args().begin()));
160  for (unsigned i = 0; i < inputs.size(); i++) {
161  if (not is_a<Symbol>(*inputs[i])) {
162  throw SymEngineException("Input contains a non-symbol.");
163  }
164  auto index
165  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
166  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
167  index);
168  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
169  symbol_ptrs.push_back(result_);
170  }
171 
172  auto it = F->args().begin();
173  auto out = &(*(it + 1));
174  std::vector<llvm::Value *> output_vals;
175 
176  if (symbolic_cse) {
177  vec_basic reduced_exprs;
178  vec_pair replacements;
179  // cse the outputs
180  SymEngine::cse(replacements, reduced_exprs, outputs);
181  for (auto &rep : replacements) {
182  // Store the replacement symbol values in a dictionary
183  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
184  }
185  // Generate IR for all the reduced exprs and save references
186  for (unsigned i = 0; i < outputs.size(); i++) {
187  output_vals.push_back(apply(*reduced_exprs[i]));
188  }
189  } else {
190  // Generate IR for all the output exprs and save references
191  for (unsigned i = 0; i < outputs.size(); i++) {
192  output_vals.push_back(apply(*outputs[i]));
193  }
194  }
195 
196  // Store all the output exprs at the end
197  for (unsigned i = 0; i < outputs.size(); i++) {
198  auto index
199  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
200  auto ptr
201  = builder->CreateGEP(get_float_type(context.get()), out, index);
202  builder->CreateStore(output_vals[i], ptr);
203  }
204 
205  // Create the return instruction and add it to the basic block
206  builder->CreateRetVoid();
207 
208  // Validate the generated code, checking for consistency.
209  llvm::verifyFunction(*F, &llvm::outs());
210 
211  // std::cout << "LLVM IR" << std::endl;
212  // module->print(llvm::errs(), nullptr);
213 
214  // Optimize the function using default passes from PassBuilder
215  // FunctionSimplificationPipeline for the opt_level
216 #if (LLVM_VERSION_MAJOR < 14)
217  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
218 #else
219  using OptimizationLevel = llvm::OptimizationLevel;
220 #endif
221  llvm::PassBuilder PB;
222  llvm::ModuleAnalysisManager MAM;
223  llvm::CGSCCAnalysisManager CGAM;
224  llvm::FunctionAnalysisManager FAM;
225  llvm::LoopAnalysisManager LAM;
226  PB.registerModuleAnalyses(MAM);
227  PB.registerCGSCCAnalyses(CGAM);
228  PB.registerFunctionAnalyses(FAM);
229  PB.registerLoopAnalyses(LAM);
230  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
231  llvm::FunctionPassManager FPM;
232  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
233  if (opt_level == 0) {
234  pb_opt_level = OptimizationLevel::O0;
235  } else if (opt_level == 1) {
236  pb_opt_level = OptimizationLevel::O1;
237  } else if (opt_level == 2) {
238  pb_opt_level = OptimizationLevel::O2;
239  }
240 
241  if (opt_level != 0) {
242 #if (LLVM_VERSION_MAJOR < 6)
243  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
244 #elif (LLVM_VERSION_MAJOR < 12)
245  FPM = PB.buildFunctionSimplificationPipeline(
246  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
247 #else
248  FPM = PB.buildFunctionSimplificationPipeline(
249  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
250 #endif
251  FPM.run(*F, FAM);
252  }
253 
254  // std::cout << "Optimized LLVM IR" << std::endl;
255  // module->print(llvm::errs(), nullptr);
256 
257  // Now we create the JIT.
258  std::string error;
259  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
260  llvm::EngineBuilder(std::move(module))
261  .setEngineKind(llvm::EngineKind::Kind::JIT)
262  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
263  .setErrorStr(&error)
264  .create());
265 
266  // This is a hack to get the MemoryBuffer of a compiled object.
267  class MemoryBufferRefCallback : public llvm::ObjectCache
268  {
269  public:
270  std::string &ss_;
271  explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
272 
273  void notifyObjectCompiled(const llvm::Module *M,
274  llvm::MemoryBufferRef obj) override
275  {
276  const char *c = obj.getBufferStart();
277  // Saving the object code in a std::string
278  ss_.assign(c, obj.getBufferSize());
279  }
280 
281  std::unique_ptr<llvm::MemoryBuffer>
282  getObject(const llvm::Module *M) override
283  {
284  return nullptr;
285  }
286  };
287 
288  MemoryBufferRefCallback callback(membuffer);
289  executionengine->setObjectCache(&callback);
290  // std::cout << error << std::endl;
291  executionengine->finalizeObject();
292 
293  // Get the symbol's address
294  func = (intptr_t)executionengine->getPointerToFunction(F);
295  symbol_ptrs.clear();
296  replacement_symbol_ptrs.clear();
297  symbols.clear();
298 }
299 
300 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
301 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
302 
303 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
304 {
305  double ret;
306  ((double (*)(const double *, double *))func)(vec.data(), &ret);
307  return ret;
308 }
309 
310 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
311 {
312  ((double (*)(const double *, double *))func)(inps, outs);
313 }
314 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate a single-output function. `vec` holds the input values in the
// order of the symbols passed to init(); the single output is returned.
// NOTE(review): only safe when exactly one output expression was compiled.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    // The JIT entry point returns void. Calling it through a pointer type
    // with a long double return is undefined behaviour. On targets that
    // return long double on the x87 register stack, it makes the caller pop
    // a value the callee never pushed, which can unbalance that stack.
    using entry_t = void (*)(const long double *, long double *);
    long double ret;
    reinterpret_cast<entry_t>(func)(vec.data(), &ret);
    return ret;
}

// Evaluate a multi-output function: inputs are read from `inps`, output i is
// written to outs[i].
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using entry_t = void (*)(const long double *, long double *);
    reinterpret_cast<entry_t>(func)(inps, outs);
}
#endif
331 
332 LLVMFloatVisitor::LLVMFloatVisitor() = default;
333 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
334 
335 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
336 {
337  float ret;
338  ((float (*)(const float *, float *))func)(vec.data(), &ret);
339  return ret;
340 }
341 
342 void LLVMFloatVisitor::call(float *outs, const float *inps) const
343 {
344  ((float (*)(const float *, float *))func)(inps, outs);
345 }
346 
347 void LLVMVisitor::set_double(double d)
348 {
349  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
350 }
351 
352 void LLVMVisitor::bvisit(const Integer &x)
353 {
354  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
355  mp_get_d(x.as_integer_class()));
356 }
357 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE

LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;

// Turn a numeric expression into a constant of the long double type. The
// expression is evaluated with evalf at 128 bits, and LLVM parses its decimal
// string form. Requires MPFR; otherwise NotImplementedError is thrown.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> approx = evalf(x, 128, EvalfDomain::Real);
    const std::string digits = approx->__str__();
    result_
        = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are parsed by LLVM from their decimal representation (rounded to
// the long double type) instead of going through double.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    const std::string decimal = x.__str__();
    result_
        = llvm::ConstantFP::get(get_float_type(&mod->getContext()), decimal);
}
#endif
380 
381 void LLVMVisitor::bvisit(const Rational &x)
382 {
383  set_double(mp_get_d(x.as_rational_class()));
384 }
385 
386 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
387 void LLVMLongDoubleVisitor::visit(const Rational &x)
388 {
389  convert_from_mpfr(x);
390 }
391 #endif
392 
393 void LLVMVisitor::bvisit(const RealDouble &x)
394 {
395  set_double(x.i);
396 }
397 
398 #ifdef HAVE_SYMENGINE_MPFR
399 void LLVMVisitor::bvisit(const RealMPFR &x)
400 {
401  set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
402 }
403 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
404 void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
405 {
406  convert_from_mpfr(x);
407 }
408 #endif
409 #endif
410 
411 void LLVMVisitor::bvisit(const Add &x)
412 {
413  llvm::Value *tmp, *tmp1, *tmp2;
414  auto it = x.get_dict().begin();
415 
416  if (eq(*x.get_coef(), *zero)) {
417  // `x + 0.0` is not optimized out
418  if (eq(*one, *(it->second))) {
419  tmp = apply(*(it->first));
420  } else {
421  tmp1 = apply(*(it->first));
422  tmp2 = apply(*(it->second));
423  tmp = builder->CreateFMul(tmp1, tmp2);
424  }
425  ++it;
426  } else {
427  tmp = apply(*x.get_coef());
428  }
429 
430  for (; it != x.get_dict().end(); ++it) {
431  if (eq(*one, *(it->second))) {
432  tmp1 = apply(*(it->first));
433  tmp = builder->CreateFAdd(tmp, tmp1);
434  } else {
435  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
436  // tmp =
437  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
438  // llvm::Intrinsic::fma,
439  // 3, context), args);
440  tmp1 = apply(*(it->first));
441  tmp2 = apply(*(it->second));
442  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
443  }
444  }
445  result_ = tmp;
446 }
447 
448 void LLVMVisitor::bvisit(const Mul &x)
449 {
450  llvm::Value *tmp = nullptr;
451  bool first = true;
452  for (const auto &p : x.get_args()) {
453  if (first) {
454  tmp = apply(*p);
455  } else {
456  tmp = builder->CreateFMul(tmp, apply(*p));
457  }
458  first = false;
459  }
460  result_ = tmp;
461 }
462 
463 llvm::Function *LLVMVisitor::get_powi()
464 {
465  std::vector<llvm::Type *> arg_type;
466  arg_type.push_back(get_float_type(&mod->getContext()));
467 #if (LLVM_VERSION_MAJOR > 12)
468  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
469 #endif
470  return GetDeclaration(mod, llvm::Intrinsic::powi, arg_type);
471 }
472 
473 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
474  unsigned n, llvm::Module *mod)
475 {
476  std::vector<llvm::Type *> arg_type(n, type);
477  return GetDeclaration(mod, id, arg_type);
478 }
479 
480 void LLVMVisitor::bvisit(const Pow &x)
481 {
482  std::vector<llvm::Value *> args;
483  llvm::Function *fun;
484  if (eq(*(x.get_base()), *E)) {
485  args.push_back(apply(*x.get_exp()));
486  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
487  llvm::Intrinsic::exp, 1, mod);
488 
489  } else if (eq(*(x.get_base()), *integer(2))) {
490  args.push_back(apply(*x.get_exp()));
491  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
492  llvm::Intrinsic::exp2, 1, mod);
493 
494  } else {
495  if (is_a<Integer>(*x.get_exp())) {
496  if (eq(*x.get_exp(), *integer(2))) {
497  llvm::Value *tmp = apply(*x.get_base());
498  result_ = builder->CreateFMul(tmp, tmp);
499  return;
500  } else {
501  args.push_back(apply(*x.get_base()));
502  int d = numeric_cast<int>(
503  mp_get_si(static_cast<const Integer &>(*x.get_exp())
504  .as_integer_class()));
505  result_ = llvm::ConstantInt::get(
506  llvm::Type::getInt32Ty(mod->getContext()), d, true);
507  args.push_back(result_);
508  fun = get_powi();
509  }
510  } else {
511  args.push_back(apply(*x.get_base()));
512  args.push_back(apply(*x.get_exp()));
513  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
514  llvm::Intrinsic::pow, 1, mod);
515  }
516  }
517  auto r = builder->CreateCall(fun, args);
518  r->setTailCall(true);
519  result_ = r;
520 }
521 
522 void LLVMVisitor::bvisit(const Sin &x)
523 {
524  std::vector<llvm::Value *> args;
525  llvm::Function *fun;
526  args.push_back(apply(*x.get_arg()));
527  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
528  llvm::Intrinsic::sin, 1, mod);
529  auto r = builder->CreateCall(fun, args);
530  r->setTailCall(true);
531  result_ = r;
532 }
533 
534 void LLVMVisitor::bvisit(const Cos &x)
535 {
536  std::vector<llvm::Value *> args;
537  llvm::Function *fun;
538  args.push_back(apply(*x.get_arg()));
539  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
540  llvm::Intrinsic::cos, 1, mod);
541  auto r = builder->CreateCall(fun, args);
542  r->setTailCall(true);
543  result_ = r;
544 }
545 
546 void LLVMVisitor::bvisit(const Piecewise &x)
547 {
548  std::vector<llvm::BasicBlock> blocks;
549 
550  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
551 
552  if (neq(*pw->get_vec().back().second, *boolTrue)) {
553  throw SymEngineException(
554  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
555  }
556 
557  if (pw->get_vec().size() > 2) {
558  PiecewiseVec rest = pw->get_vec();
559  rest.erase(rest.begin());
560  auto rest_pw = piecewise(std::move(rest));
561  PiecewiseVec new_pw;
562  new_pw.push_back(*pw->get_vec().begin());
563  new_pw.push_back({rest_pw, pw->get_vec().back().second});
564  pw = piecewise(std::move(new_pw))
565  ->rcp_from_this_cast<const Piecewise>();
566  } else if (pw->get_vec().size() < 2) {
567  throw SymEngineException("Invalid Piecewise object");
568  }
569 
570  auto cond_basic = pw->get_vec().front().second;
571  llvm::Value *cond = apply(*cond_basic);
572  // check if cond != 0.0
573  cond = builder->CreateFCmpONE(
574  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
575  "ifcond");
576  llvm::Function *function = builder->GetInsertBlock()->getParent();
577  // Create blocks for the then and else cases. Insert the 'then' block at
578  // the
579  // end of the function.
580  llvm::BasicBlock *then_bb
581  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
582  llvm::BasicBlock *else_bb
583  = llvm::BasicBlock::Create(mod->getContext(), "else");
584  llvm::BasicBlock *merge_bb
585  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
586  builder->CreateCondBr(cond, then_bb, else_bb);
587 
588  // Emit then value.
589  builder->SetInsertPoint(then_bb);
590  llvm::Value *then_value = apply(*pw->get_vec().front().first);
591  builder->CreateBr(merge_bb);
592 
593  // Codegen of 'then_value' can change the current block, update then_bb for
594  // the PHI.
595  then_bb = builder->GetInsertBlock();
596 
597  // Emit else block.
598 #if (LLVM_VERSION_MAJOR < 16)
599  function->getBasicBlockList().push_back(else_bb);
600 #else
601  function->insert(function->end(), else_bb);
602 #endif
603  builder->SetInsertPoint(else_bb);
604  llvm::Value *else_value = apply(*pw->get_vec().back().first);
605  builder->CreateBr(merge_bb);
606 
607  // Codegen of 'else_value' can change the current block, update else_bb for
608  // the PHI.
609  else_bb = builder->GetInsertBlock();
610 
611  // Emit merge block.
612 #if (LLVM_VERSION_MAJOR < 16)
613  function->getBasicBlockList().push_back(merge_bb);
614 #else
615  function->insert(function->end(), merge_bb);
616 #endif
617  builder->SetInsertPoint(merge_bb);
618  llvm::PHINode *phi_node
619  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
620 
621  phi_node->addIncoming(then_value, then_bb);
622  phi_node->addIncoming(else_value, else_bb);
623  result_ = phi_node;
624 }
625 
626 void LLVMVisitor::bvisit(const Sign &x)
627 {
628  const auto x2 = x.get_arg();
629  PiecewiseVec new_pw;
630  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
631  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
632  new_pw.push_back({real_double(1.0), boolTrue});
633  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
634  bvisit(*pw);
635 }
636 
637 void LLVMVisitor::bvisit(const Contains &cts)
638 {
639  llvm::Value *expr = apply(*cts.get_expr());
640  const auto set = cts.get_set();
641  if (is_a<Interval>(*set)) {
642  const auto &interv = down_cast<const Interval &>(*set);
643  llvm::Value *start = apply(*interv.get_start());
644  llvm::Value *end = apply(*interv.get_end());
645  const bool left_open = interv.get_left_open();
646  const bool right_open = interv.get_right_open();
647  llvm::Value *left_ok;
648  llvm::Value *right_ok;
649  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
650  : builder->CreateFCmpOLE(start, expr);
651  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
652  : builder->CreateFCmpOLE(expr, end);
653  result_ = builder->CreateAnd(left_ok, right_ok);
654  result_ = builder->CreateUIToFP(result_,
655  get_float_type(&mod->getContext()));
656  } else {
657  throw SymEngineException("LLVMVisitor: only ``Interval`` "
658  "implemented for ``Contains``.");
659  }
660 }
661 
662 void LLVMVisitor::bvisit(const Infty &x)
663 {
664  if (x.is_negative_infinity()) {
665  result_ = llvm::ConstantFP::getInfinity(
666  get_float_type(&mod->getContext()), true);
667  } else if (x.is_positive_infinity()) {
668  result_ = llvm::ConstantFP::getInfinity(
669  get_float_type(&mod->getContext()), false);
670  } else {
671  throw SymEngineException(
672  "LLVMDouble can only represent real valued infinity");
673  }
674 }
675 
676 void LLVMVisitor::bvisit(const NaN &x)
677 {
678  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
679  /*negative=*/false, /*payload=*/0);
680 }
681 
682 void LLVMVisitor::bvisit(const BooleanAtom &x)
683 {
684  const bool val = x.get_val();
685  set_double(val ? 1.0 : 0.0);
686 }
687 
688 void LLVMVisitor::bvisit(const Log &x)
689 {
690  std::vector<llvm::Value *> args;
691  llvm::Function *fun;
692  args.push_back(apply(*x.get_arg()));
693  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
694  llvm::Intrinsic::log, 1, mod);
695  auto r = builder->CreateCall(fun, args);
696  r->setTailCall(true);
697  result_ = r;
698 }
699 
700 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
701  void LLVMVisitor::bvisit(const Class &x) \
702  { \
703  llvm::Value *value = nullptr; \
704  llvm::Value *tmp; \
705  set_double(0.0); \
706  llvm::Value *zero_val = result_; \
707  for (auto &p : x.get_container()) { \
708  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
709  if (value == nullptr) { \
710  value = tmp; \
711  } else { \
712  value = builder->method(value, tmp); \
713  } \
714  } \
715  result_ = builder->CreateUIToFP(value, \
716  get_float_type(&mod->getContext())); \
717  }
718 
719 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
720 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
721 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
722 
723 void LLVMVisitor::bvisit(const Not &x)
724 {
725  builder->CreateNot(apply(*x.get_arg()));
726 }
727 
728 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
729  void LLVMVisitor::bvisit(const Class &x) \
730  { \
731  llvm::Value *left = apply(*x.get_arg1()); \
732  llvm::Value *right = apply(*x.get_arg2()); \
733  result_ = builder->method(left, right); \
734  result_ = builder->CreateUIToFP(result_, \
735  get_float_type(&mod->getContext())); \
736  }
737 
738 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
739 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
740 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
741 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
742 
743 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
744  void LLVMDoubleVisitor::visit(const Class &x) \
745  { \
746  vec_basic basic_args = x.get_args(); \
747  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
748  std::vector<llvm::Value *> args; \
749  for (const auto &arg : basic_args) { \
750  args.push_back(apply(*arg)); \
751  } \
752  auto r = builder->CreateCall(func, args); \
753  r->setTailCall(true); \
754  result_ = r; \
755  } \
756  void LLVMFloatVisitor::visit(const Class &x) \
757  { \
758  vec_basic basic_args = x.get_args(); \
759  llvm::Function *func = get_external_function(#ext + std::string("f"), \
760  basic_args.size()); \
761  std::vector<llvm::Value *> args; \
762  for (const auto &arg : basic_args) { \
763  args.push_back(apply(*arg)); \
764  } \
765  auto r = builder->CreateCall(func, args); \
766  r->setTailCall(true); \
767  result_ = r; \
768  }
769 
770 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
771 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
772  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
773  void LLVMLongDoubleVisitor::visit(const Class &x) \
774  { \
775  vec_basic basic_args = x.get_args(); \
776  llvm::Function *func = get_external_function(#ext + std::string("l"), \
777  basic_args.size()); \
778  std::vector<llvm::Value *> args; \
779  for (const auto &arg : basic_args) { \
780  args.push_back(apply(*arg)); \
781  } \
782  auto r = builder->CreateCall(func, args); \
783  r->setTailCall(true); \
784  result_ = r; \
785  }
786 #else
787 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
788  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
789 #endif
790 
791 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
792 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
793 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
794 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
795 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
796 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
797 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
798 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
799 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
800 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
801 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
802 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
803 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
804 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
805 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
806 
807 void LLVMVisitor::bvisit(const Abs &x)
808 {
809  std::vector<llvm::Value *> args;
810  llvm::Function *fun;
811  args.push_back(apply(*x.get_arg()));
812  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
813  llvm::Intrinsic::fabs, 1, mod);
814  auto r = builder->CreateCall(fun, args);
815  r->setTailCall(true);
816  result_ = r;
817 }
818 
819 void LLVMVisitor::bvisit(const Min &x)
820 {
821  llvm::Value *value = nullptr;
822  llvm::Function *fun;
823  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
824  llvm::Intrinsic::minnum, 1, mod);
825  for (auto &arg : x.get_vec()) {
826  if (value != nullptr) {
827  std::vector<llvm::Value *> args;
828  args.push_back(value);
829  args.push_back(apply(*arg));
830  auto r = builder->CreateCall(fun, args);
831  r->setTailCall(true);
832  value = r;
833  } else {
834  value = apply(*arg);
835  }
836  }
837  result_ = value;
838 }
839 
840 void LLVMVisitor::bvisit(const Max &x)
841 {
842  llvm::Value *value = nullptr;
843  llvm::Function *fun;
844  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
845  llvm::Intrinsic::maxnum, 1, mod);
846  for (auto &arg : x.get_vec()) {
847  if (value != nullptr) {
848  std::vector<llvm::Value *> args;
849  args.push_back(value);
850  args.push_back(apply(*arg));
851  auto r = builder->CreateCall(fun, args);
852  r->setTailCall(true);
853  value = r;
854  } else {
855  value = apply(*arg);
856  }
857  }
858  result_ = value;
859 }
860 
861 void LLVMVisitor::bvisit(const Symbol &x)
862 {
863  unsigned i = 0;
864  for (auto &symb : symbols) {
865  if (eq(x, *symb)) {
866  result_ = symbol_ptrs[i];
867  return;
868  }
869  ++i;
870  }
871  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
872  if (it != replacement_symbol_ptrs.end()) {
873  result_ = it->second;
874  return;
875  }
876 
877  throw SymEngineException("Symbol " + x.__str__()
878  + " not in the symbols vector.");
879 }
880 
881 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
882  size_t nargs)
883 {
884  std::vector<llvm::Type *> func_args(nargs,
885  get_float_type(&mod->getContext()));
886  llvm::FunctionType *func_type = llvm::FunctionType::get(
887  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
888 
889  llvm::Function *func = mod->getFunction(name);
890  if (!func) {
891  func = llvm::Function::Create(
892  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
893  func->setCallingConv(llvm::CallingConv::C);
894  }
895  func->addFnAttr(llvm::Attribute::NoUnwind);
896  return func;
897 }
898 
899 void LLVMVisitor::bvisit(const Constant &x)
900 {
901  set_double(eval_double(x));
902 }
903 
904 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
905 void LLVMLongDoubleVisitor::visit(const Constant &x)
906 {
907  convert_from_mpfr(x);
908 }
909 #endif
910 
911 void LLVMVisitor::bvisit(const Basic &x)
912 {
913  throw NotImplementedError(x.__str__());
914 }
915 
916 const std::string &LLVMVisitor::dumps() const
917 {
918  return membuffer;
919 };
920 
921 void LLVMVisitor::loads(const std::string &s)
922 {
923  membuffer = s;
924  llvm::InitializeNativeTarget();
925  llvm::InitializeNativeTargetAsmPrinter();
926  llvm::InitializeNativeTargetAsmParser();
927  context = make_unique<llvm::LLVMContext>();
928 
929  // Create some module to put our function into it.
930  std::unique_ptr<llvm::Module> module
931  = make_unique<llvm::Module>("SymEngine", *context);
932  module->setDataLayout("");
933  mod = module.get();
934 
935  // Only defining the prototype for the function here.
936  // Since we know where the function is stored that's enough
937  // llvm::ObjectCache is designed for caching objects, but it
938  // is used here for loading one specific object.
939  auto F = get_function_type(context.get());
940 
941  std::string error;
942  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
943  llvm::EngineBuilder(std::move(module))
944  .setEngineKind(llvm::EngineKind::Kind::JIT)
945  .setOptLevel(CodeGenOptLevel::Aggressive)
946  .setErrorStr(&error)
947  .create());
948 
949  class MCJITObjectLoader : public llvm::ObjectCache
950  {
951  const std::string &s_;
952 
953  public:
954  MCJITObjectLoader(const std::string &s) : s_(s) {}
955  void notifyObjectCompiled(const llvm::Module *M,
956  llvm::MemoryBufferRef obj) override
957  {
958  }
959 
960  // No need to check M because there is only one function
961  // Return it after reading from the file.
962  std::unique_ptr<llvm::MemoryBuffer>
963  getObject(const llvm::Module *M) override
964  {
965  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
966  }
967  };
968 
969  MCJITObjectLoader loader(s);
970  executionengine->setObjectCache(&loader);
971  executionengine->finalizeObject();
972  // Set func to compiled function pointer
973  func = (intptr_t)executionengine->getPointerToFunction(F);
974 }
975 
976 void LLVMVisitor::bvisit(const Floor &x)
977 {
978  std::vector<llvm::Value *> args;
979  llvm::Function *fun;
980  args.push_back(apply(*x.get_arg()));
981  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
982  llvm::Intrinsic::floor, 1, mod);
983  auto r = builder->CreateCall(fun, args);
984  r->setTailCall(true);
985  result_ = r;
986 }
987 
988 void LLVMVisitor::bvisit(const Ceiling &x)
989 {
990  std::vector<llvm::Value *> args;
991  llvm::Function *fun;
992  args.push_back(apply(*x.get_arg()));
993  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
994  llvm::Intrinsic::ceil, 1, mod);
995  auto r = builder->CreateCall(fun, args);
996  r->setTailCall(true);
997  result_ = r;
998 }
999 
1000 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1001 {
1002  apply(*x.get_arg());
1003 }
1004 
1005 void LLVMVisitor::bvisit(const Truncate &x)
1006 {
1007  std::vector<llvm::Value *> args;
1008  llvm::Function *fun;
1009  args.push_back(apply(*x.get_arg()));
1010  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1011  llvm::Intrinsic::trunc, 1, mod);
1012  auto r = builder->CreateCall(fun, args);
1013  r->setTailCall(true);
1014  result_ = r;
1015 }
1016 
1017 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1018 {
1019  return llvm::Type::getDoubleTy(*context);
1020 }
1021 
1022 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1023 {
1024  return llvm::Type::getFloatTy(*context);
1025 }
1026 
1027 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1028 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1029 {
1030  return llvm::Type::getX86_FP80Ty(*context);
1031 }
1032 #endif
1033 
1034 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127