llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
// Compatibility shims so the rest of this file compiles against several LLVM
// major versions.

// make_unique: std::make_unique from LLVM 10 on, llvm::make_unique before.
#if (LLVM_VERSION_MAJOR >= 10)
using std::make_unique;
#else
using llvm::make_unique;
#endif

// Code-generation optimisation level enum: llvm::CodeGenOptLevel from LLVM 18
// on, llvm::CodeGenOpt::Level before that.
#if (LLVM_VERSION_MAJOR >= 18)
using CodeGenOptLevel = llvm::CodeGenOptLevel;
#else
using CodeGenOptLevel = llvm::CodeGenOpt::Level;
#endif

// Look up (inserting if necessary) the declaration of intrinsic `id`
// specialised for the overload types `Tys`. LLVM 20 renamed the function to
// Intrinsic::getOrInsertDeclaration; older releases call it getDeclaration.
#if (LLVM_VERSION_MAJOR >= 20)
auto GetDeclaration = [](llvm::Module *M, llvm::Intrinsic::ID id,
                         llvm::ArrayRef<llvm::Type *> Tys) {
    return llvm::Intrinsic::getOrInsertDeclaration(M, id, Tys);
};
#else
const auto &GetDeclaration = llvm::Intrinsic::getDeclaration;
#endif

// Mark `A` (presumably an llvm::AttrBuilder -- TODO confirm) as not capturing
// the pointer: from LLVM 21 this goes through addCapturesAttr with
// CaptureInfo::none(), earlier versions use the NoCapture enum attribute.
// NOTE(review): this macro is not referenced anywhere in the visible source.
#if (LLVM_VERSION_MAJOR >= 21)
#define AddNoCapture(A) A.addCapturesAttr(llvm::CaptureInfo::none())
#else
#define AddNoCapture(A) A.addAttribute(llvm::Attribute::NoCapture)
#endif
70 
// Empty subclass of llvm::IRBuilder<> that adds no members.
// init() stores a reinterpret_cast'ed pointer to a plain llvm::IRBuilder<> in
// `builder`, which only works because this class adds nothing to its base.
// NOTE(review): that downcast is technically undefined behaviour in C++; it
// relies on the two types having identical layout.
class IRBuilder : public llvm::IRBuilder<>
{
};
74 
75 LLVMVisitor::LLVMVisitor() = default;
76 LLVMVisitor::~LLVMVisitor() = default;
77 
78 llvm::Value *LLVMVisitor::apply(const Basic &b)
79 {
80  b.accept(*this);
81  return result_;
82 }
83 
84 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
85  unsigned opt_level)
86 {
87  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
88 }
89 
90 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
91 {
92  std::vector<llvm::Type *> inp;
93  for (int i = 0; i < 2; i++) {
94 #if (LLVM_VERSION_MAJOR >= 22)
95  inp.push_back(llvm::PointerType::get(*context, 0));
96 #else
97  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
98 #endif
99  }
100  llvm::FunctionType *function_type = llvm::FunctionType::get(
101  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
102  auto F = llvm::Function::Create(
103  function_type, llvm::Function::InternalLinkage, "symengine_func", mod);
104  F->setCallingConv(llvm::CallingConv::C);
105  F->addParamAttr(0, llvm::Attribute::ReadOnly);
106 #if (LLVM_VERSION_MAJOR >= 21)
107  F->addParamAttr(1, llvm::Attribute::getWithCaptureInfo(
108  *context, llvm::CaptureInfo::none()));
109  F->addParamAttr(0, llvm::Attribute::getWithCaptureInfo(
110  *context, llvm::CaptureInfo::none()));
111 #else
112  F->addParamAttr(0, llvm::Attribute::NoCapture);
113  F->addParamAttr(1, llvm::Attribute::NoCapture);
114 #endif
115  F->addFnAttr(llvm::Attribute::NoUnwind);
116 #if (LLVM_VERSION_MAJOR < 15)
117  F->addFnAttr(llvm::Attribute::UWTable);
118 #else
119  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
120  *context, llvm::UWTableKind::Default));
121 #endif
122  return F;
123 }
124 
125 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
126  const bool symbolic_cse, unsigned opt_level)
127 {
128  executionengine.reset();
129  llvm::InitializeNativeTarget();
130  llvm::InitializeNativeTargetAsmPrinter();
131  llvm::InitializeNativeTargetAsmParser();
132  context = make_unique<llvm::LLVMContext>();
133  symbols = inputs;
134 
135  // Create some module to put our function into it.
136  std::unique_ptr<llvm::Module> module
137  = make_unique<llvm::Module>("SymEngine", *context.get());
138  module->setDataLayout("");
139  mod = module.get();
140 
141  auto F = get_function_type(context.get());
142 
143  // Add a basic block to the function. As before, it automatically
144  // inserts
145  // because of the last argument.
146  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
147 
148  // Create a basic block builder with default parameters. The builder
149  // will
150  // automatically append instructions to the basic block `BB'.
151  llvm::IRBuilder<> _builder(BB);
152  builder = reinterpret_cast<IRBuilder *>(&_builder);
153  builder->SetInsertPoint(BB);
154  auto fmf = llvm::FastMathFlags();
155  // fmf.setUnsafeAlgebra();
156  builder->setFastMathFlags(fmf);
157 
158  // Load all the symbols and create references
159  auto input_arg = &(*(F->args().begin()));
160  for (unsigned i = 0; i < inputs.size(); i++) {
161  if (not is_a<Symbol>(*inputs[i])) {
162  throw SymEngineException("Input contains a non-symbol.");
163  }
164  auto index
165  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
166  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
167  index);
168  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
169  symbol_ptrs.push_back(result_);
170  }
171 
172  auto it = F->args().begin();
173  auto out = &(*(it + 1));
174  std::vector<llvm::Value *> output_vals;
175 
176  if (symbolic_cse) {
177  vec_basic reduced_exprs;
178  vec_pair replacements;
179  // cse the outputs
180  SymEngine::cse(replacements, reduced_exprs, outputs);
181  for (auto &rep : replacements) {
182  // Store the replacement symbol values in a dictionary
183  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
184  }
185  // Generate IR for all the reduced exprs and save references
186  for (unsigned i = 0; i < outputs.size(); i++) {
187  output_vals.push_back(apply(*reduced_exprs[i]));
188  }
189  } else {
190  // Generate IR for all the output exprs and save references
191  for (unsigned i = 0; i < outputs.size(); i++) {
192  output_vals.push_back(apply(*outputs[i]));
193  }
194  }
195 
196  // Store all the output exprs at the end
197  for (unsigned i = 0; i < outputs.size(); i++) {
198  auto index
199  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
200  auto ptr
201  = builder->CreateGEP(get_float_type(context.get()), out, index);
202  builder->CreateStore(output_vals[i], ptr);
203  }
204 
205  // Create the return instruction and add it to the basic block
206  builder->CreateRetVoid();
207 
208  // Validate the generated code, checking for consistency.
209  llvm::verifyFunction(*F, &llvm::outs());
210 
211  // std::cout << "LLVM IR" << std::endl;
212  // module->print(llvm::errs(), nullptr);
213 
214  // Optimize the function using default passes from PassBuilder
215  // FunctionSimplificationPipeline for the opt_level
216 #if (LLVM_VERSION_MAJOR < 14)
217  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
218 #else
219  using OptimizationLevel = llvm::OptimizationLevel;
220 #endif
221  llvm::PassBuilder PB;
222  llvm::ModuleAnalysisManager MAM;
223  llvm::CGSCCAnalysisManager CGAM;
224  llvm::FunctionAnalysisManager FAM;
225  llvm::LoopAnalysisManager LAM;
226  PB.registerModuleAnalyses(MAM);
227  PB.registerCGSCCAnalyses(CGAM);
228  PB.registerFunctionAnalyses(FAM);
229  PB.registerLoopAnalyses(LAM);
230  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
231  llvm::FunctionPassManager FPM;
232  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
233  if (opt_level == 0) {
234  pb_opt_level = OptimizationLevel::O0;
235  } else if (opt_level == 1) {
236  pb_opt_level = OptimizationLevel::O1;
237  } else if (opt_level == 2) {
238  pb_opt_level = OptimizationLevel::O2;
239  }
240 
241  if (opt_level != 0) {
242 #if (LLVM_VERSION_MAJOR < 6)
243  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
244 #elif (LLVM_VERSION_MAJOR < 12)
245  FPM = PB.buildFunctionSimplificationPipeline(
246  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
247 #else
248  FPM = PB.buildFunctionSimplificationPipeline(
249  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
250 #endif
251  FPM.run(*F, FAM);
252  }
253 
254  // std::cout << "Optimized LLVM IR" << std::endl;
255  // module->print(llvm::errs(), nullptr);
256 
257  // Now we create the JIT.
258  std::string error;
259  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
260  llvm::EngineBuilder(std::move(module))
261  .setEngineKind(llvm::EngineKind::Kind::JIT)
262  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
263  .setErrorStr(&error)
264  .create());
265 
266  // This is a hack to get the MemoryBuffer of a compiled object.
267  class MemoryBufferRefCallback : public llvm::ObjectCache
268  {
269  public:
270  std::string &ss_;
271  explicit MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
272 
273  void notifyObjectCompiled(const llvm::Module *M,
274  llvm::MemoryBufferRef obj) override
275  {
276  const char *c = obj.getBufferStart();
277  // Saving the object code in a std::string
278  ss_.assign(c, obj.getBufferSize());
279  }
280 
281  std::unique_ptr<llvm::MemoryBuffer>
282  getObject(const llvm::Module *M) override
283  {
284  return nullptr;
285  }
286  };
287 
288  MemoryBufferRefCallback callback(membuffer);
289  executionengine->setObjectCache(&callback);
290  // std::cout << error << std::endl;
291  executionengine->finalizeObject();
292 
293  // Get the symbol's address
294  func = (intptr_t)executionengine->getPointerToFunction(F);
295  symbol_ptrs.clear();
296  replacement_symbol_ptrs.clear();
297  symbols.clear();
298 }
299 
300 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
301 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
302 
303 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
304 {
305  double ret;
306  ((double (*)(const double *, double *))func)(vec.data(), &ret);
307  return ret;
308 }
309 
310 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
311 {
312  ((double (*)(const double *, double *))func)(inps, outs);
313 }
314 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate the compiled function for a single output.
// `vec` holds one value per input symbol, in the order given to init(). The
// compiled function writes one value per init() output, so this overload is
// only valid when exactly one output was compiled.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    // symengine_func returns void; call it through a pointer of exactly that
    // type rather than one claiming a long double return value (UB).
    using kernel_type = void (*)(const long double *, long double *);
    long double ret = 0.0L;
    reinterpret_cast<kernel_type>(func)(vec.data(), &ret);
    return ret;
}

// Evaluate the compiled function: read one value per input symbol from
// `inps` and write one value per output expression into `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using kernel_type = void (*)(const long double *, long double *);
    reinterpret_cast<kernel_type>(func)(inps, outs);
}
#endif
331 
332 LLVMFloatVisitor::LLVMFloatVisitor() = default;
333 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
334 
335 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
336 {
337  float ret;
338  ((float (*)(const float *, float *))func)(vec.data(), &ret);
339  return ret;
340 }
341 
342 void LLVMFloatVisitor::call(float *outs, const float *inps) const
343 {
344  ((float (*)(const float *, float *))func)(inps, outs);
345 }
346 
347 void LLVMVisitor::set_double(double d)
348 {
349  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
350 }
351 
352 void LLVMVisitor::bvisit(const Integer &x)
353 {
354  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
355  mp_get_d(x.as_integer_class()));
356 }
357 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE

LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;

// Emit `x` as a long double constant without going through double: `x` is
// evaluated with MPFR at 128 bits, and LLVM parses the resulting decimal
// string into the target float type. Without MPFR this throws
// NotImplementedError.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> approx = evalf(x, 128, EvalfDomain::Real);
    const std::string digits = approx->__str__();
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are emitted from their exact decimal text, which LLVM parses
// directly into the long double type.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    const std::string digits = x.__str__();
    result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), digits);
}
#endif
380 
381 void LLVMVisitor::bvisit(const Rational &x)
382 {
383  set_double(mp_get_d(x.as_rational_class()));
384 }
385 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long double visitor routes rationals through the MPFR path instead of
// converting them to double first.
void LLVMLongDoubleVisitor::visit(const Rational &x)
{
    this->convert_from_mpfr(x);
}
#endif
392 
393 void LLVMVisitor::bvisit(const RealDouble &x)
394 {
395  set_double(x.i);
396 }
397 
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals are rounded to the nearest double.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double value = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(value);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long double visitor takes the higher-precision MPFR path instead.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    this->convert_from_mpfr(x);
}
#endif
#endif
410 
411 void LLVMVisitor::bvisit(const Add &x)
412 {
413  llvm::Value *tmp, *tmp1, *tmp2;
414  auto it = x.get_dict().begin();
415 
416  if (eq(*x.get_coef(), *zero)) {
417  // `x + 0.0` is not optimized out
418  if (eq(*one, *(it->second))) {
419  tmp = apply(*(it->first));
420  } else {
421  tmp1 = apply(*(it->first));
422  tmp2 = apply(*(it->second));
423  tmp = builder->CreateFMul(tmp1, tmp2);
424  }
425  ++it;
426  } else {
427  tmp = apply(*x.get_coef());
428  }
429 
430  for (; it != x.get_dict().end(); ++it) {
431  if (eq(*one, *(it->second))) {
432  tmp1 = apply(*(it->first));
433  tmp = builder->CreateFAdd(tmp, tmp1);
434  } else {
435  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
436  // tmp =
437  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
438  // llvm::Intrinsic::fma,
439  // 3, context), args);
440  tmp1 = apply(*(it->first));
441  tmp2 = apply(*(it->second));
442  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
443  }
444  }
445  result_ = tmp;
446 }
447 
448 void LLVMVisitor::bvisit(const Mul &x)
449 {
450  llvm::Value *tmp = nullptr;
451  bool first = true;
452  for (const auto &p : x.get_args()) {
453  if (first) {
454  tmp = apply(*p);
455  } else {
456  tmp = builder->CreateFMul(tmp, apply(*p));
457  }
458  first = false;
459  }
460  result_ = tmp;
461 }
462 
463 llvm::Function *LLVMVisitor::get_powi()
464 {
465  std::vector<llvm::Type *> arg_type;
466  arg_type.push_back(get_float_type(&mod->getContext()));
467 #if (LLVM_VERSION_MAJOR > 12)
468  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
469 #endif
470  return GetDeclaration(mod, llvm::Intrinsic::powi, arg_type);
471 }
472 
473 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
474  unsigned n, llvm::Module *mod)
475 {
476  std::vector<llvm::Type *> arg_type(n, type);
477  return GetDeclaration(mod, id, arg_type);
478 }
479 
480 void LLVMVisitor::bvisit(const Pow &x)
481 {
482  std::vector<llvm::Value *> args;
483  llvm::Function *fun;
484  if (eq(*(x.get_base()), *E)) {
485  args.push_back(apply(*x.get_exp()));
486  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
487  llvm::Intrinsic::exp, 1, mod);
488 
489  } else if (eq(*(x.get_base()), *integer(2))) {
490  args.push_back(apply(*x.get_exp()));
491  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
492  llvm::Intrinsic::exp2, 1, mod);
493 
494  } else {
495  if (is_a<Integer>(*x.get_exp())) {
496  if (eq(*x.get_exp(), *integer(2))) {
497  llvm::Value *tmp = apply(*x.get_base());
498  result_ = builder->CreateFMul(tmp, tmp);
499  return;
500  } else {
501  args.push_back(apply(*x.get_base()));
502  int d = numeric_cast<int>(
503  mp_get_si(static_cast<const Integer &>(*x.get_exp())
504  .as_integer_class()));
505  result_ = llvm::ConstantInt::get(
506  llvm::Type::getInt32Ty(mod->getContext()), d, true);
507  args.push_back(result_);
508  fun = get_powi();
509  }
510  } else {
511  args.push_back(apply(*x.get_base()));
512  args.push_back(apply(*x.get_exp()));
513  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
514  llvm::Intrinsic::pow, 1, mod);
515  }
516  }
517  auto r = builder->CreateCall(fun, args);
518  r->setTailCall(true);
519  result_ = r;
520 }
521 
522 void LLVMVisitor::bvisit(const Sin &x)
523 {
524  std::vector<llvm::Value *> args;
525  llvm::Function *fun;
526  args.push_back(apply(*x.get_arg()));
527  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
528  llvm::Intrinsic::sin, 1, mod);
529  auto r = builder->CreateCall(fun, args);
530  r->setTailCall(true);
531  result_ = r;
532 }
533 
534 void LLVMVisitor::bvisit(const Cos &x)
535 {
536  std::vector<llvm::Value *> args;
537  llvm::Function *fun;
538  args.push_back(apply(*x.get_arg()));
539  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
540  llvm::Intrinsic::cos, 1, mod);
541  auto r = builder->CreateCall(fun, args);
542  r->setTailCall(true);
543  result_ = r;
544 }
545 
546 void LLVMVisitor::bvisit(const Piecewise &x)
547 {
548  std::vector<llvm::BasicBlock> blocks;
549 
550  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
551 
552  if (neq(*pw->get_vec().back().second, *boolTrue)) {
553  throw SymEngineException(
554  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
555  }
556 
557  if (pw->get_vec().size() > 2) {
558  PiecewiseVec rest = pw->get_vec();
559  rest.erase(rest.begin());
560  auto rest_pw = piecewise(std::move(rest));
561  PiecewiseVec new_pw;
562  new_pw.push_back(*pw->get_vec().begin());
563  new_pw.push_back({rest_pw, pw->get_vec().back().second});
564  pw = piecewise(std::move(new_pw))
565  ->rcp_from_this_cast<const Piecewise>();
566  } else if (pw->get_vec().size() < 2) {
567  throw SymEngineException("Invalid Piecewise object");
568  }
569 
570  auto cond_basic = pw->get_vec().front().second;
571  llvm::Value *cond = apply(*cond_basic);
572  // check if cond != 0.0
573  cond = builder->CreateFCmpONE(
574  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
575  "ifcond");
576  llvm::Function *function = builder->GetInsertBlock()->getParent();
577  // Create blocks for the then and else cases. Insert the 'then' block at
578  // the
579  // end of the function.
580  llvm::BasicBlock *then_bb
581  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
582  llvm::BasicBlock *else_bb
583  = llvm::BasicBlock::Create(mod->getContext(), "else");
584  llvm::BasicBlock *merge_bb
585  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
586  builder->CreateCondBr(cond, then_bb, else_bb);
587 
588  // Emit then value.
589  builder->SetInsertPoint(then_bb);
590  llvm::Value *then_value = apply(*pw->get_vec().front().first);
591  builder->CreateBr(merge_bb);
592 
593  // Codegen of 'then_value' can change the current block, update then_bb for
594  // the PHI.
595  then_bb = builder->GetInsertBlock();
596 
597  // Emit else block.
598 #if (LLVM_VERSION_MAJOR < 16)
599  function->getBasicBlockList().push_back(else_bb);
600 #else
601  function->insert(function->end(), else_bb);
602 #endif
603  builder->SetInsertPoint(else_bb);
604  llvm::Value *else_value = apply(*pw->get_vec().back().first);
605  builder->CreateBr(merge_bb);
606 
607  // Codegen of 'else_value' can change the current block, update else_bb for
608  // the PHI.
609  else_bb = builder->GetInsertBlock();
610 
611  // Emit merge block.
612 #if (LLVM_VERSION_MAJOR < 16)
613  function->getBasicBlockList().push_back(merge_bb);
614 #else
615  function->insert(function->end(), merge_bb);
616 #endif
617  builder->SetInsertPoint(merge_bb);
618  llvm::PHINode *phi_node
619  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
620 
621  phi_node->addIncoming(then_value, then_bb);
622  phi_node->addIncoming(else_value, else_bb);
623  result_ = phi_node;
624 }
625 
626 void LLVMVisitor::bvisit(const Sign &x)
627 {
628  const auto x2 = x.get_arg();
629  PiecewiseVec new_pw;
630  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
631  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
632  new_pw.push_back({real_double(1.0), boolTrue});
633  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
634  bvisit(*pw);
635 }
636 
637 void LLVMVisitor::bvisit(const Contains &cts)
638 {
639  llvm::Value *expr = apply(*cts.get_expr());
640  const auto set = cts.get_set();
641  if (is_a<Interval>(*set)) {
642  const auto &interv = down_cast<const Interval &>(*set);
643  llvm::Value *start = apply(*interv.get_start());
644  llvm::Value *end = apply(*interv.get_end());
645  const bool left_open = interv.get_left_open();
646  const bool right_open = interv.get_right_open();
647  llvm::Value *left_ok;
648  llvm::Value *right_ok;
649  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
650  : builder->CreateFCmpOLE(start, expr);
651  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
652  : builder->CreateFCmpOLE(expr, end);
653  result_ = builder->CreateAnd(left_ok, right_ok);
654  result_ = builder->CreateUIToFP(result_,
655  get_float_type(&mod->getContext()));
656  } else {
657  throw SymEngineException("LLVMVisitor: only ``Interval`` "
658  "implemented for ``Contains``.");
659  }
660 }
661 
662 void LLVMVisitor::bvisit(const Infty &x)
663 {
664  if (x.is_negative_infinity()) {
665  result_ = llvm::ConstantFP::getInfinity(
666  get_float_type(&mod->getContext()), true);
667  } else if (x.is_positive_infinity()) {
668  result_ = llvm::ConstantFP::getInfinity(
669  get_float_type(&mod->getContext()), false);
670  } else {
671  throw SymEngineException(
672  "LLVMDouble can only represent real valued infinity");
673  }
674 }
675 
676 void LLVMVisitor::bvisit(const NaN &x)
677 {
678  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
679  /*negative=*/false, /*payload=*/0);
680 }
681 
682 void LLVMVisitor::bvisit(const BooleanAtom &x)
683 {
684  const bool val = x.get_val();
685  set_double(val ? 1.0 : 0.0);
686 }
687 
688 void LLVMVisitor::bvisit(const Log &x)
689 {
690  std::vector<llvm::Value *> args;
691  llvm::Function *fun;
692  args.push_back(apply(*x.get_arg()));
693  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
694  llvm::Intrinsic::log, 1, mod);
695  auto r = builder->CreateCall(fun, args);
696  r->setTailCall(true);
697  result_ = r;
698 }
699 
700 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
701  void LLVMVisitor::bvisit(const Class &x) \
702  { \
703  llvm::Value *value = nullptr; \
704  llvm::Value *tmp; \
705  set_double(0.0); \
706  llvm::Value *zero_val = result_; \
707  for (auto &p : x.get_container()) { \
708  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
709  if (value == nullptr) { \
710  value = tmp; \
711  } else { \
712  value = builder->method(value, tmp); \
713  } \
714  } \
715  result_ = builder->CreateUIToFP(value, \
716  get_float_type(&mod->getContext())); \
717  }
718 
719 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
720 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
721 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
722 
723 void LLVMVisitor::bvisit(const Not &x)
724 {
725  set_double(0.0);
726  llvm::Value *zero_val = result_;
727  llvm::Value *value = builder->CreateFCmpONE(apply(*x.get_arg()), zero_val);
728  result_ = builder->CreateUIToFP(builder->CreateNot(value),
729  get_float_type(&mod->getContext()));
730 }
731 
732 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
733  void LLVMVisitor::bvisit(const Class &x) \
734  { \
735  llvm::Value *left = apply(*x.get_arg1()); \
736  llvm::Value *right = apply(*x.get_arg2()); \
737  result_ = builder->method(left, right); \
738  result_ = builder->CreateUIToFP(result_, \
739  get_float_type(&mod->getContext())); \
740  }
741 
742 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
743 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
744 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
745 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
746 
// Functions without a matching LLVM intrinsic are lowered to tail calls of an
// external C function named `ext` (double visitor) or `ext` with an "f"
// suffix (float visitor). get_external_function() declares the callee in the
// module on first use; the arguments are lowered afterwards, in order.
// NOTE(review): the callee is presumably resolved from the C math library
// when the JIT links the object -- confirm on non-glibc platforms.
#define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
    void LLVMDoubleVisitor::visit(const Class &x) \
    { \
        const vec_basic operands = x.get_args(); \
        llvm::Function *const callee \
            = get_external_function(#ext, operands.size()); \
        std::vector<llvm::Value *> lowered; \
        lowered.reserve(operands.size()); \
        for (const auto &operand : operands) { \
            lowered.push_back(apply(*operand)); \
        } \
        auto inst = builder->CreateCall(callee, lowered); \
        inst->setTailCall(true); \
        result_ = inst; \
    } \
    void LLVMFloatVisitor::visit(const Class &x) \
    { \
        const vec_basic operands = x.get_args(); \
        llvm::Function *const callee \
            = get_external_function(#ext "f", operands.size()); \
        std::vector<llvm::Value *> lowered; \
        lowered.reserve(operands.size()); \
        for (const auto &operand : operands) { \
            lowered.push_back(apply(*operand)); \
        } \
        auto inst = builder->CreateCall(callee, lowered); \
        inst->setTailCall(true); \
        result_ = inst; \
    }
773 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// With long double support, also define the long double visitor's overload,
// which calls the external function named `ext` with an "l" suffix.
#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
    _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
    void LLVMLongDoubleVisitor::visit(const Class &x) \
    { \
        const vec_basic operands = x.get_args(); \
        llvm::Function *const callee \
            = get_external_function(#ext "l", operands.size()); \
        std::vector<llvm::Value *> lowered; \
        lowered.reserve(operands.size()); \
        for (const auto &operand : operands) { \
            lowered.push_back(apply(*operand)); \
        } \
        auto inst = builder->CreateCall(callee, lowered); \
        inst->setTailCall(true); \
        result_ = inst; \
    }
#else
// Without long double support only the double and float overloads exist.
#define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
    _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
#endif
794 
// Each SymEngine function class below is lowered to a call of the external C
// function named in the second argument (with "f"/"l" suffixes for the float
// and long double visitors). Note Gamma maps to tgamma and LogGamma to lgamma.
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
810 
811 void LLVMVisitor::bvisit(const Abs &x)
812 {
813  std::vector<llvm::Value *> args;
814  llvm::Function *fun;
815  args.push_back(apply(*x.get_arg()));
816  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
817  llvm::Intrinsic::fabs, 1, mod);
818  auto r = builder->CreateCall(fun, args);
819  r->setTailCall(true);
820  result_ = r;
821 }
822 
823 void LLVMVisitor::bvisit(const Min &x)
824 {
825  llvm::Value *value = nullptr;
826  llvm::Function *fun;
827  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
828  llvm::Intrinsic::minnum, 1, mod);
829  for (auto &arg : x.get_vec()) {
830  if (value != nullptr) {
831  std::vector<llvm::Value *> args;
832  args.push_back(value);
833  args.push_back(apply(*arg));
834  auto r = builder->CreateCall(fun, args);
835  r->setTailCall(true);
836  value = r;
837  } else {
838  value = apply(*arg);
839  }
840  }
841  result_ = value;
842 }
843 
844 void LLVMVisitor::bvisit(const Max &x)
845 {
846  llvm::Value *value = nullptr;
847  llvm::Function *fun;
848  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
849  llvm::Intrinsic::maxnum, 1, mod);
850  for (auto &arg : x.get_vec()) {
851  if (value != nullptr) {
852  std::vector<llvm::Value *> args;
853  args.push_back(value);
854  args.push_back(apply(*arg));
855  auto r = builder->CreateCall(fun, args);
856  r->setTailCall(true);
857  value = r;
858  } else {
859  value = apply(*arg);
860  }
861  }
862  result_ = value;
863 }
864 
865 void LLVMVisitor::bvisit(const Symbol &x)
866 {
867  unsigned i = 0;
868  for (auto &symb : symbols) {
869  if (eq(x, *symb)) {
870  result_ = symbol_ptrs[i];
871  return;
872  }
873  ++i;
874  }
875  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
876  if (it != replacement_symbol_ptrs.end()) {
877  result_ = it->second;
878  return;
879  }
880 
881  throw SymEngineException("Symbol " + x.__str__()
882  + " not in the symbols vector.");
883 }
884 
885 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
886  size_t nargs)
887 {
888  std::vector<llvm::Type *> func_args(nargs,
889  get_float_type(&mod->getContext()));
890  llvm::FunctionType *func_type = llvm::FunctionType::get(
891  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
892 
893  llvm::Function *func = mod->getFunction(name);
894  if (!func) {
895  func = llvm::Function::Create(
896  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
897  func->setCallingConv(llvm::CallingConv::C);
898  }
899  func->addFnAttr(llvm::Attribute::NoUnwind);
900  return func;
901 }
902 
903 void LLVMVisitor::bvisit(const Constant &x)
904 {
905  set_double(eval_double(x));
906 }
907 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// The long double visitor evaluates named constants through MPFR instead of
// eval_double().
void LLVMLongDoubleVisitor::visit(const Constant &x)
{
    this->convert_from_mpfr(x);
}
#endif
914 
915 void LLVMVisitor::bvisit(const Basic &x)
916 {
917  throw NotImplementedError(x.__str__());
918 }
919 
920 const std::string &LLVMVisitor::dumps() const
921 {
922  return membuffer;
923 };
924 
925 void LLVMVisitor::loads(const std::string &s)
926 {
927  membuffer = s;
928  llvm::InitializeNativeTarget();
929  llvm::InitializeNativeTargetAsmPrinter();
930  llvm::InitializeNativeTargetAsmParser();
931  context = make_unique<llvm::LLVMContext>();
932 
933  // Create some module to put our function into it.
934  std::unique_ptr<llvm::Module> module
935  = make_unique<llvm::Module>("SymEngine", *context);
936  module->setDataLayout("");
937  mod = module.get();
938 
939  // Only defining the prototype for the function here.
940  // Since we know where the function is stored that's enough
941  // llvm::ObjectCache is designed for caching objects, but it
942  // is used here for loading one specific object.
943  auto F = get_function_type(context.get());
944 
945  std::string error;
946  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
947  llvm::EngineBuilder(std::move(module))
948  .setEngineKind(llvm::EngineKind::Kind::JIT)
949  .setOptLevel(CodeGenOptLevel::Aggressive)
950  .setErrorStr(&error)
951  .create());
952 
953  class MCJITObjectLoader : public llvm::ObjectCache
954  {
955  const std::string &s_;
956 
957  public:
958  MCJITObjectLoader(const std::string &s) : s_(s) {}
959  void notifyObjectCompiled(const llvm::Module *M,
960  llvm::MemoryBufferRef obj) override
961  {
962  }
963 
964  // No need to check M because there is only one function
965  // Return it after reading from the file.
966  std::unique_ptr<llvm::MemoryBuffer>
967  getObject(const llvm::Module *M) override
968  {
969  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
970  }
971  };
972 
973  MCJITObjectLoader loader(s);
974  executionengine->setObjectCache(&loader);
975  executionengine->finalizeObject();
976  // Set func to compiled function pointer
977  func = (intptr_t)executionengine->getPointerToFunction(F);
978 }
979 
980 void LLVMVisitor::bvisit(const Floor &x)
981 {
982  std::vector<llvm::Value *> args;
983  llvm::Function *fun;
984  args.push_back(apply(*x.get_arg()));
985  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
986  llvm::Intrinsic::floor, 1, mod);
987  auto r = builder->CreateCall(fun, args);
988  r->setTailCall(true);
989  result_ = r;
990 }
991 
992 void LLVMVisitor::bvisit(const Ceiling &x)
993 {
994  std::vector<llvm::Value *> args;
995  llvm::Function *fun;
996  args.push_back(apply(*x.get_arg()));
997  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
998  llvm::Intrinsic::ceil, 1, mod);
999  auto r = builder->CreateCall(fun, args);
1000  r->setTailCall(true);
1001  result_ = r;
1002 }
1003 
1004 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1005 {
1006  apply(*x.get_arg());
1007 }
1008 
1009 void LLVMVisitor::bvisit(const Truncate &x)
1010 {
1011  std::vector<llvm::Value *> args;
1012  llvm::Function *fun;
1013  args.push_back(apply(*x.get_arg()));
1014  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1015  llvm::Intrinsic::trunc, 1, mod);
1016  auto r = builder->CreateCall(fun, args);
1017  r->setTailCall(true);
1018  result_ = r;
1019 }
1020 
1021 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1022 {
1023  return llvm::Type::getDoubleTy(*context);
1024 }
1025 
1026 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1027 {
1028  return llvm::Type::getFloatTy(*context);
1029 }
1030 
1031 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1032 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1033 {
1034  return llvm::Type::getX86_FP80Ty(*context);
1035 }
1036 #endif
1037 
1038 } // namespace SymEngine
The lowest unit of symbolic representation.
Definition: basic.h:97
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
Modulo, with the quotient rounded toward zero.
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127