llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
3 #include "llvm/ExecutionEngine/GenericValue.h"
4 #include "llvm/ExecutionEngine/MCJIT.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/Analysis/Passes.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/ExecutionEngine/ObjectCache.h"
29 #include "llvm/Support/FileSystem.h"
30 #include "llvm/Support/Path.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <memory>
34 #include <vector>
35 #include <fstream>
36 
37 #include <symengine/llvm_double.h>
38 #include <symengine/eval_double.h>
39 #include <symengine/eval.h>
40 
41 namespace SymEngine
42 {
43 
44 #if (LLVM_VERSION_MAJOR >= 10)
45 using std::make_unique;
46 #else
47 using llvm::make_unique;
48 #endif
49 #if (LLVM_VERSION_MAJOR < 18)
50 using CodeGenOptLevel = llvm::CodeGenOpt::Level;
51 #else
52 using CodeGenOptLevel = llvm::CodeGenOptLevel;
53 #endif
54 
// Member-less subclass of llvm::IRBuilder<>. It adds no state of its own.
// NOTE(review): presumably this exists so the public header can
// forward-declare `IRBuilder` without including LLVM headers — confirm
// against llvm_double.h. LLVMVisitor::init() reinterpret_casts a plain
// llvm::IRBuilder<> to this type, which relies on no members being added.
class IRBuilder : public llvm::IRBuilder<>
{
};
58 
59 LLVMVisitor::LLVMVisitor() = default;
60 LLVMVisitor::~LLVMVisitor() = default;
61 
62 llvm::Value *LLVMVisitor::apply(const Basic &b)
63 {
64  b.accept(*this);
65  return result_;
66 }
67 
68 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
69  unsigned opt_level)
70 {
71  init(x, {b.rcp_from_this()}, symbolic_cse, opt_level);
72 }
73 
74 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
75 {
77  for (int i = 0; i < 2; i++) {
78  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
79  }
80  llvm::FunctionType *function_type = llvm::FunctionType::get(
81  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
82  auto F = llvm::Function::Create(function_type,
83  llvm::Function::InternalLinkage, "", mod);
84  F->setCallingConv(llvm::CallingConv::C);
85 #if (LLVM_VERSION_MAJOR < 5)
86  {
87  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
88  llvm::AttributeSet PAS;
89  {
90  llvm::AttrBuilder B;
91  B.addAttribute(llvm::Attribute::ReadOnly);
92  B.addAttribute(llvm::Attribute::NoCapture);
93  PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
94  }
95 
96  attrs.push_back(PAS);
97  {
98  llvm::AttrBuilder B;
99  B.addAttribute(llvm::Attribute::NoCapture);
100  PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
101  }
102 
103  attrs.push_back(PAS);
104  {
105  llvm::AttrBuilder B;
106  B.addAttribute(llvm::Attribute::NoUnwind);
107  B.addAttribute(llvm::Attribute::UWTable);
108  PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
109  }
110 
111  attrs.push_back(PAS);
112 
113  F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
114  }
115 #else
116  F->addParamAttr(0, llvm::Attribute::ReadOnly);
117  F->addParamAttr(0, llvm::Attribute::NoCapture);
118  F->addParamAttr(1, llvm::Attribute::NoCapture);
119  F->addFnAttr(llvm::Attribute::NoUnwind);
120 #if (LLVM_VERSION_MAJOR < 15)
121  F->addFnAttr(llvm::Attribute::UWTable);
122 #else
123  F->addFnAttr(llvm::Attribute::getWithUWTableKind(
124  *context, llvm::UWTableKind::Default));
125 #endif
126 #endif
127  return F;
128 }
129 
130 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
131  const bool symbolic_cse, unsigned opt_level)
132 {
133  executionengine.reset();
134  llvm::InitializeNativeTarget();
135  llvm::InitializeNativeTargetAsmPrinter();
136  llvm::InitializeNativeTargetAsmParser();
137  context = make_unique<llvm::LLVMContext>();
138  symbols = inputs;
139 
140  // Create some module to put our function into it.
142  = make_unique<llvm::Module>("SymEngine", *context.get());
143  module->setDataLayout("");
144  mod = module.get();
145 
146  auto F = get_function_type(context.get());
147 
148  // Add a basic block to the function. As before, it automatically
149  // inserts
150  // because of the last argument.
151  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
152 
153  // Create a basic block builder with default parameters. The builder
154  // will
155  // automatically append instructions to the basic block `BB'.
156  llvm::IRBuilder<> _builder(BB);
157  builder = reinterpret_cast<IRBuilder *>(&_builder);
158  builder->SetInsertPoint(BB);
159  auto fmf = llvm::FastMathFlags();
160  // fmf.setUnsafeAlgebra();
161  builder->setFastMathFlags(fmf);
162 
163  // Load all the symbols and create references
164  auto input_arg = &(*(F->args().begin()));
165  for (unsigned i = 0; i < inputs.size(); i++) {
166  if (not is_a<Symbol>(*inputs[i])) {
167  throw SymEngineException("Input contains a non-symbol.");
168  }
169  auto index
170  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
171  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
172  index);
173  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
174  symbol_ptrs.push_back(result_);
175  }
176 
177  auto it = F->args().begin();
178 #if (LLVM_VERSION_MAJOR < 5)
179  auto out = &(*(++it));
180 #else
181  auto out = &(*(it + 1));
182 #endif
183  std::vector<llvm::Value *> output_vals;
184 
185  if (symbolic_cse) {
186  vec_basic reduced_exprs;
187  vec_pair replacements;
188  // cse the outputs
189  SymEngine::cse(replacements, reduced_exprs, outputs);
190  for (auto &rep : replacements) {
191  // Store the replacement symbol values in a dictionary
192  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
193  }
194  // Generate IR for all the reduced exprs and save references
195  for (unsigned i = 0; i < outputs.size(); i++) {
196  output_vals.push_back(apply(*reduced_exprs[i]));
197  }
198  } else {
199  // Generate IR for all the output exprs and save references
200  for (unsigned i = 0; i < outputs.size(); i++) {
201  output_vals.push_back(apply(*outputs[i]));
202  }
203  }
204 
205  // Store all the output exprs at the end
206  for (unsigned i = 0; i < outputs.size(); i++) {
207  auto index
208  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
209  auto ptr
210  = builder->CreateGEP(get_float_type(context.get()), out, index);
211  builder->CreateStore(output_vals[i], ptr);
212  }
213 
214  // Create the return instruction and add it to the basic block
215  builder->CreateRetVoid();
216 
217  // Validate the generated code, checking for consistency.
218  llvm::verifyFunction(*F, &llvm::outs());
219 
220  // std::cout << "LLVM IR" << std::endl;
221  // #if (LLVM_VERSION_MAJOR < 5)
222  // module->dump();
223  // #else
224  // module->print(llvm::errs(), nullptr);
225  // #endif
226 
227  // module->print(llvm::errs(), nullptr);
228 
229  // Optimize the function using default passes from PassBuilder
230  // FunctionSimplificationPipeline for the opt_level
231 #if (LLVM_VERSION_MAJOR < 14)
232  using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
233 #else
234  using OptimizationLevel = llvm::OptimizationLevel;
235 #endif
236  llvm::PassBuilder PB;
237  llvm::ModuleAnalysisManager MAM;
238  llvm::CGSCCAnalysisManager CGAM;
239  llvm::FunctionAnalysisManager FAM;
240  llvm::LoopAnalysisManager LAM;
241  PB.registerModuleAnalyses(MAM);
242  PB.registerCGSCCAnalyses(CGAM);
243  PB.registerFunctionAnalyses(FAM);
244  PB.registerLoopAnalyses(LAM);
245  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
246  llvm::FunctionPassManager FPM;
247  OptimizationLevel pb_opt_level{OptimizationLevel::O3};
248  if (opt_level == 0) {
249  pb_opt_level = OptimizationLevel::O0;
250  } else if (opt_level == 1) {
251  pb_opt_level = OptimizationLevel::O1;
252  } else if (opt_level == 2) {
253  pb_opt_level = OptimizationLevel::O2;
254  }
255 #if (LLVM_VERSION_MAJOR < 6)
256  FPM = PB.buildFunctionSimplificationPipeline(pb_opt_level);
257 #elif (LLVM_VERSION_MAJOR < 12)
258  FPM = PB.buildFunctionSimplificationPipeline(
259  pb_opt_level, llvm::PassBuilder::ThinLTOPhase::None);
260 #else
261  FPM = PB.buildFunctionSimplificationPipeline(
262  pb_opt_level, llvm::ThinOrFullLTOPhase::None);
263 #endif
264  FPM.run(*F, FAM);
265 
266  // std::cout << "Optimized LLVM IR" << std::endl;
267  // module->print(llvm::errs(), nullptr);
268 
269  // Now we create the JIT.
270  std::string error;
271  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
272  llvm::EngineBuilder(std::move(module))
273  .setEngineKind(llvm::EngineKind::Kind::JIT)
274  .setOptLevel(static_cast<CodeGenOptLevel>(opt_level))
275  .setErrorStr(&error)
276  .create());
277 
278  // This is a hack to get the MemoryBuffer of a compiled object.
279  class MemoryBufferRefCallback : public llvm::ObjectCache
280  {
281  public:
282  std::string &ss_;
283  MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
284 
285  void notifyObjectCompiled(const llvm::Module *M,
286  llvm::MemoryBufferRef obj) override
287  {
288  const char *c = obj.getBufferStart();
289  // Saving the object code in a std::string
290  ss_.assign(c, obj.getBufferSize());
291  }
292 
294  getObject(const llvm::Module *M) override
295  {
296  return NULL;
297  }
298  };
299 
300  MemoryBufferRefCallback callback(membuffer);
301  executionengine->setObjectCache(&callback);
302  // std::cout << error << std::endl;
303  executionengine->finalizeObject();
304 
305  // Get the symbol's address
306  func = (intptr_t)executionengine->getPointerToFunction(F);
307  symbol_ptrs.clear();
308  replacement_symbol_ptrs.clear();
309  symbols.clear();
310 }
311 
312 LLVMDoubleVisitor::LLVMDoubleVisitor() = default;
313 LLVMDoubleVisitor::~LLVMDoubleVisitor() = default;
314 
315 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
316 {
317  double ret;
318  ((double (*)(const double *, double *))func)(vec.data(), &ret);
319  return ret;
320 }
321 
322 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
323 {
324  ((double (*)(const double *, double *))func)(inps, outs);
325 }
326 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Run the compiled function on `vec` and return the value written to out[0].
// The generated function returns void, so it is called through a matching
// void-returning pointer type; a `long double (*)(...)` cast would be UB.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    long double ret;
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        vec.data(), &ret);
    return ret;
}

// Run the compiled function reading `inps` and writing its results to `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    reinterpret_cast<void (*)(const long double *, long double *)>(func)(
        inps, outs);
}
#endif
343 
344 LLVMFloatVisitor::LLVMFloatVisitor() = default;
345 LLVMFloatVisitor::~LLVMFloatVisitor() = default;
346 
347 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
348 {
349  float ret;
350  ((float (*)(const float *, float *))func)(vec.data(), &ret);
351  return ret;
352 }
353 
354 void LLVMFloatVisitor::call(float *outs, const float *inps) const
355 {
356  ((float (*)(const float *, float *))func)(inps, outs);
357 }
358 
359 void LLVMVisitor::set_double(double d)
360 {
361  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
362 }
363 
364 void LLVMVisitor::bvisit(const Integer &x)
365 {
366  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
367  mp_get_d(x.as_integer_class()));
368 }
369 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE

LLVMLongDoubleVisitor::LLVMLongDoubleVisitor() = default;
LLVMLongDoubleVisitor::~LLVMLongDoubleVisitor() = default;

// Emit `x` as a long-double constant. x is evaluated with evalf (precision
// argument 128, real domain), and LLVM then parses the decimal string
// produced by __str__(). Without MPFR support this is not implemented.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    RCP<const Basic> m = evalf(x, 128, EvalfDomain::Real);
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, m->__str__());
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integers are emitted by letting LLVM parse their exact decimal text rather
// than going through an intermediate double.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, x.__str__());
}
#endif
392 
393 void LLVMVisitor::bvisit(const Rational &x)
394 {
395  set_double(mp_get_d(x.as_rational_class()));
396 }
397 
398 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
399 void LLVMLongDoubleVisitor::visit(const Rational &x)
400 {
401  convert_from_mpfr(x);
402 }
403 #endif
404 
405 void LLVMVisitor::bvisit(const RealDouble &x)
406 {
407  set_double(x.i);
408 }
409 
410 #ifdef HAVE_SYMENGINE_MPFR
411 void LLVMVisitor::bvisit(const RealMPFR &x)
412 {
413  set_double(mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN));
414 }
415 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
416 void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
417 {
418  convert_from_mpfr(x);
419 }
420 #endif
421 #endif
422 
423 void LLVMVisitor::bvisit(const Add &x)
424 {
425  llvm::Value *tmp, *tmp1, *tmp2;
426  auto it = x.get_dict().begin();
427 
428  if (eq(*x.get_coef(), *zero)) {
429  // `x + 0.0` is not optimized out
430  if (eq(*one, *(it->second))) {
431  tmp = apply(*(it->first));
432  } else {
433  tmp1 = apply(*(it->first));
434  tmp2 = apply(*(it->second));
435  tmp = builder->CreateFMul(tmp1, tmp2);
436  }
437  ++it;
438  } else {
439  tmp = apply(*x.get_coef());
440  }
441 
442  for (; it != x.get_dict().end(); ++it) {
443  if (eq(*one, *(it->second))) {
444  tmp1 = apply(*(it->first));
445  tmp = builder->CreateFAdd(tmp, tmp1);
446  } else {
447  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
448  // tmp =
449  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
450  // llvm::Intrinsic::fma,
451  // 3, context), args);
452  tmp1 = apply(*(it->first));
453  tmp2 = apply(*(it->second));
454  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
455  }
456  }
457  result_ = tmp;
458 }
459 
460 void LLVMVisitor::bvisit(const Mul &x)
461 {
462  llvm::Value *tmp = nullptr;
463  bool first = true;
464  for (const auto &p : x.get_args()) {
465  if (first) {
466  tmp = apply(*p);
467  } else {
468  tmp = builder->CreateFMul(tmp, apply(*p));
469  }
470  first = false;
471  }
472  result_ = tmp;
473 }
474 
475 llvm::Function *LLVMVisitor::get_powi()
476 {
477  std::vector<llvm::Type *> arg_type;
478  arg_type.push_back(get_float_type(&mod->getContext()));
479 #if (LLVM_VERSION_MAJOR > 12)
480  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
481 #endif
482  return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
483  arg_type);
484 }
485 
486 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
487  unsigned n, llvm::Module *mod)
488 {
489  std::vector<llvm::Type *> arg_type(n, type);
490  return llvm::Intrinsic::getDeclaration(mod, id, arg_type);
491 }
492 
493 void LLVMVisitor::bvisit(const Pow &x)
494 {
496  llvm::Function *fun;
497  if (eq(*(x.get_base()), *E)) {
498  args.push_back(apply(*x.get_exp()));
499  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
500  llvm::Intrinsic::exp, 1, mod);
501 
502  } else if (eq(*(x.get_base()), *integer(2))) {
503  args.push_back(apply(*x.get_exp()));
504  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
505  llvm::Intrinsic::exp2, 1, mod);
506 
507  } else {
508  if (is_a<Integer>(*x.get_exp())) {
509  if (eq(*x.get_exp(), *integer(2))) {
510  llvm::Value *tmp = apply(*x.get_base());
511  result_ = builder->CreateFMul(tmp, tmp);
512  return;
513  } else {
514  args.push_back(apply(*x.get_base()));
515  int d = numeric_cast<int>(
516  mp_get_si(static_cast<const Integer &>(*x.get_exp())
517  .as_integer_class()));
518  result_ = llvm::ConstantInt::get(
519  llvm::Type::getInt32Ty(mod->getContext()), d, true);
520  args.push_back(result_);
521  fun = get_powi();
522  }
523  } else {
524  args.push_back(apply(*x.get_base()));
525  args.push_back(apply(*x.get_exp()));
526  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
527  llvm::Intrinsic::pow, 1, mod);
528  }
529  }
530  auto r = builder->CreateCall(fun, args);
531  r->setTailCall(true);
532  result_ = r;
533 }
534 
535 void LLVMVisitor::bvisit(const Sin &x)
536 {
538  llvm::Function *fun;
539  args.push_back(apply(*x.get_arg()));
540  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
541  llvm::Intrinsic::sin, 1, mod);
542  auto r = builder->CreateCall(fun, args);
543  r->setTailCall(true);
544  result_ = r;
545 }
546 
547 void LLVMVisitor::bvisit(const Cos &x)
548 {
550  llvm::Function *fun;
551  args.push_back(apply(*x.get_arg()));
552  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
553  llvm::Intrinsic::cos, 1, mod);
554  auto r = builder->CreateCall(fun, args);
555  r->setTailCall(true);
556  result_ = r;
557 }
558 
559 void LLVMVisitor::bvisit(const Piecewise &x)
560 {
562 
563  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
564 
565  if (neq(*pw->get_vec().back().second, *boolTrue)) {
566  throw SymEngineException(
567  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
568  }
569 
570  if (pw->get_vec().size() > 2) {
571  PiecewiseVec rest = pw->get_vec();
572  rest.erase(rest.begin());
573  auto rest_pw = piecewise(std::move(rest));
574  PiecewiseVec new_pw;
575  new_pw.push_back(*pw->get_vec().begin());
576  new_pw.push_back({rest_pw, pw->get_vec().back().second});
577  pw = piecewise(std::move(new_pw))
578  ->rcp_from_this_cast<const Piecewise>();
579  } else if (pw->get_vec().size() < 2) {
580  throw SymEngineException("Invalid Piecewise object");
581  }
582 
583  auto cond_basic = pw->get_vec().front().second;
584  llvm::Value *cond = apply(*cond_basic);
585  // check if cond != 0.0
586  cond = builder->CreateFCmpONE(
587  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
588  "ifcond");
589  llvm::Function *function = builder->GetInsertBlock()->getParent();
590  // Create blocks for the then and else cases. Insert the 'then' block at
591  // the
592  // end of the function.
593  llvm::BasicBlock *then_bb
594  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
595  llvm::BasicBlock *else_bb
596  = llvm::BasicBlock::Create(mod->getContext(), "else");
597  llvm::BasicBlock *merge_bb
598  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
599  builder->CreateCondBr(cond, then_bb, else_bb);
600 
601  // Emit then value.
602  builder->SetInsertPoint(then_bb);
603  llvm::Value *then_value = apply(*pw->get_vec().front().first);
604  builder->CreateBr(merge_bb);
605 
606  // Codegen of 'then_value' can change the current block, update then_bb for
607  // the PHI.
608  then_bb = builder->GetInsertBlock();
609 
610  // Emit else block.
611 #if (LLVM_VERSION_MAJOR < 16)
612  function->getBasicBlockList().push_back(else_bb);
613 #else
614  function->insert(function->end(), else_bb);
615 #endif
616  builder->SetInsertPoint(else_bb);
617  llvm::Value *else_value = apply(*pw->get_vec().back().first);
618  builder->CreateBr(merge_bb);
619 
620  // Codegen of 'else_value' can change the current block, update else_bb for
621  // the PHI.
622  else_bb = builder->GetInsertBlock();
623 
624  // Emit merge block.
625 #if (LLVM_VERSION_MAJOR < 16)
626  function->getBasicBlockList().push_back(merge_bb);
627 #else
628  function->insert(function->end(), merge_bb);
629 #endif
630  builder->SetInsertPoint(merge_bb);
631  llvm::PHINode *phi_node
632  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
633 
634  phi_node->addIncoming(then_value, then_bb);
635  phi_node->addIncoming(else_value, else_bb);
636  result_ = phi_node;
637 }
638 
639 void LLVMVisitor::bvisit(const Sign &x)
640 {
641  const auto x2 = x.get_arg();
642  PiecewiseVec new_pw;
643  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
644  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
645  new_pw.push_back({real_double(1.0), boolTrue});
646  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
647  bvisit(*pw);
648 }
649 
650 void LLVMVisitor::bvisit(const Contains &cts)
651 {
652  llvm::Value *expr = apply(*cts.get_expr());
653  const auto set = cts.get_set();
654  if (is_a<Interval>(*set)) {
655  const auto &interv = down_cast<const Interval &>(*set);
656  llvm::Value *start = apply(*interv.get_start());
657  llvm::Value *end = apply(*interv.get_end());
658  const bool left_open = interv.get_left_open();
659  const bool right_open = interv.get_right_open();
660  llvm::Value *left_ok;
661  llvm::Value *right_ok;
662  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
663  : builder->CreateFCmpOLE(start, expr);
664  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
665  : builder->CreateFCmpOLE(expr, end);
666  result_ = builder->CreateAnd(left_ok, right_ok);
667  result_ = builder->CreateUIToFP(result_,
668  get_float_type(&mod->getContext()));
669  } else {
670  throw SymEngineException("LLVMVisitor: only ``Interval`` "
671  "implemented for ``Contains``.");
672  }
673 }
674 
675 void LLVMVisitor::bvisit(const Infty &x)
676 {
677  if (x.is_negative_infinity()) {
678  result_ = llvm::ConstantFP::getInfinity(
679  get_float_type(&mod->getContext()), true);
680  } else if (x.is_positive_infinity()) {
681  result_ = llvm::ConstantFP::getInfinity(
682  get_float_type(&mod->getContext()), false);
683  } else {
684  throw SymEngineException(
685  "LLVMDouble can only represent real valued infinity");
686  }
687 }
688 
689 void LLVMVisitor::bvisit(const NaN &x)
690 {
691  result_ = llvm::ConstantFP::getNaN(get_float_type(&mod->getContext()),
692  /*negative=*/false, /*payload=*/0);
693 }
694 
695 void LLVMVisitor::bvisit(const BooleanAtom &x)
696 {
697  const bool val = x.get_val();
698  set_double(val ? 1.0 : 0.0);
699 }
700 
701 void LLVMVisitor::bvisit(const Log &x)
702 {
704  llvm::Function *fun;
705  args.push_back(apply(*x.get_arg()));
706  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
707  llvm::Intrinsic::log, 1, mod);
708  auto r = builder->CreateCall(fun, args);
709  r->setTailCall(true);
710  result_ = r;
711 }
712 
713 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
714  void LLVMVisitor::bvisit(const Class &x) \
715  { \
716  llvm::Value *value = nullptr; \
717  llvm::Value *tmp; \
718  set_double(0.0); \
719  llvm::Value *zero_val = result_; \
720  for (auto &p : x.get_container()) { \
721  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
722  if (value == nullptr) { \
723  value = tmp; \
724  } else { \
725  value = builder->method(value, tmp); \
726  } \
727  } \
728  result_ = builder->CreateUIToFP(value, \
729  get_float_type(&mod->getContext())); \
730  }
731 
732 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
733 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
734 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
735 
736 void LLVMVisitor::bvisit(const Not &x)
737 {
738  builder->CreateNot(apply(*x.get_arg()));
739 }
740 
741 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
742  void LLVMVisitor::bvisit(const Class &x) \
743  { \
744  llvm::Value *left = apply(*x.get_arg1()); \
745  llvm::Value *right = apply(*x.get_arg2()); \
746  result_ = builder->method(left, right); \
747  result_ = builder->CreateUIToFP(result_, \
748  get_float_type(&mod->getContext())); \
749  }
750 
751 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
752 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
753 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
754 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
755 
756 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
757  void LLVMDoubleVisitor::visit(const Class &x) \
758  { \
759  vec_basic basic_args = x.get_args(); \
760  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
761  std::vector<llvm::Value *> args; \
762  for (const auto &arg : basic_args) { \
763  args.push_back(apply(*arg)); \
764  } \
765  auto r = builder->CreateCall(func, args); \
766  r->setTailCall(true); \
767  result_ = r; \
768  } \
769  void LLVMFloatVisitor::visit(const Class &x) \
770  { \
771  vec_basic basic_args = x.get_args(); \
772  llvm::Function *func = get_external_function(#ext + std::string("f"), \
773  basic_args.size()); \
774  std::vector<llvm::Value *> args; \
775  for (const auto &arg : basic_args) { \
776  args.push_back(apply(*arg)); \
777  } \
778  auto r = builder->CreateCall(func, args); \
779  r->setTailCall(true); \
780  result_ = r; \
781  }
782 
783 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
784 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
785  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
786  void LLVMLongDoubleVisitor::visit(const Class &x) \
787  { \
788  vec_basic basic_args = x.get_args(); \
789  llvm::Function *func = get_external_function(#ext + std::string("l"), \
790  basic_args.size()); \
791  std::vector<llvm::Value *> args; \
792  for (const auto &arg : basic_args) { \
793  args.push_back(apply(*arg)); \
794  } \
795  auto r = builder->CreateCall(func, args); \
796  r->setTailCall(true); \
797  result_ = r; \
798  }
799 #else
800 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
801  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
802 #endif
803 
804 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
805 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
806 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
807 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
808 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
809 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
810 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
811 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
812 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
813 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
814 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
815 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
816 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
817 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
818 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
819 
820 void LLVMVisitor::bvisit(const Abs &x)
821 {
823  llvm::Function *fun;
824  args.push_back(apply(*x.get_arg()));
825  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
826  llvm::Intrinsic::fabs, 1, mod);
827  auto r = builder->CreateCall(fun, args);
828  r->setTailCall(true);
829  result_ = r;
830 }
831 
832 void LLVMVisitor::bvisit(const Min &x)
833 {
834  llvm::Value *value = nullptr;
835  llvm::Function *fun;
836  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
837  llvm::Intrinsic::minnum, 1, mod);
838  for (auto &arg : x.get_vec()) {
839  if (value != nullptr) {
841  args.push_back(value);
842  args.push_back(apply(*arg));
843  auto r = builder->CreateCall(fun, args);
844  r->setTailCall(true);
845  value = r;
846  } else {
847  value = apply(*arg);
848  }
849  }
850  result_ = value;
851 }
852 
853 void LLVMVisitor::bvisit(const Max &x)
854 {
855  llvm::Value *value = nullptr;
856  llvm::Function *fun;
857  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
858  llvm::Intrinsic::maxnum, 1, mod);
859  for (auto &arg : x.get_vec()) {
860  if (value != nullptr) {
862  args.push_back(value);
863  args.push_back(apply(*arg));
864  auto r = builder->CreateCall(fun, args);
865  r->setTailCall(true);
866  value = r;
867  } else {
868  value = apply(*arg);
869  }
870  }
871  result_ = value;
872 }
873 
874 void LLVMVisitor::bvisit(const Symbol &x)
875 {
876  unsigned i = 0;
877  for (auto &symb : symbols) {
878  if (eq(x, *symb)) {
879  result_ = symbol_ptrs[i];
880  return;
881  }
882  ++i;
883  }
884  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
885  if (it != replacement_symbol_ptrs.end()) {
886  result_ = it->second;
887  return;
888  }
889 
890  throw SymEngineException("Symbol " + x.__str__()
891  + " not in the symbols vector.");
892 }
893 
894 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
895  size_t nargs)
896 {
897  std::vector<llvm::Type *> func_args(nargs,
898  get_float_type(&mod->getContext()));
899  llvm::FunctionType *func_type = llvm::FunctionType::get(
900  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
901 
902  llvm::Function *func = mod->getFunction(name);
903  if (!func) {
904  func = llvm::Function::Create(
905  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
906  func->setCallingConv(llvm::CallingConv::C);
907  }
908 #if (LLVM_VERSION_MAJOR < 5)
909  llvm::AttributeSet func_attr_set;
910  {
911  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
912  llvm::AttributeSet attr_set;
913  {
914  llvm::AttrBuilder attr_builder;
915  attr_builder.addAttribute(llvm::Attribute::NoUnwind);
916  attr_set
917  = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
918  }
919 
920  attrs.push_back(attr_set);
921  func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
922  }
923  func->setAttributes(func_attr_set);
924 #else
925  func->addFnAttr(llvm::Attribute::NoUnwind);
926 #endif
927  return func;
928 }
929 
930 void LLVMVisitor::bvisit(const Constant &x)
931 {
932  set_double(eval_double(x));
933 }
934 
935 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
936 void LLVMLongDoubleVisitor::visit(const Constant &x)
937 {
938  convert_from_mpfr(x);
939 }
940 #endif
941 
942 void LLVMVisitor::bvisit(const Basic &x)
943 {
944  throw NotImplementedError(x.__str__());
945 }
946 
947 const std::string &LLVMVisitor::dumps() const
948 {
949  return membuffer;
950 };
951 
952 void LLVMVisitor::loads(const std::string &s)
953 {
954  membuffer = s;
955  llvm::InitializeNativeTarget();
956  llvm::InitializeNativeTargetAsmPrinter();
957  llvm::InitializeNativeTargetAsmParser();
958  context = make_unique<llvm::LLVMContext>();
959 
960  // Create some module to put our function into it.
962  = make_unique<llvm::Module>("SymEngine", *context);
963  module->setDataLayout("");
964  mod = module.get();
965 
966  // Only defining the prototype for the function here.
967  // Since we know where the function is stored that's enough
968  // llvm::ObjectCache is designed for caching objects, but it
969  // is used here for loading one specific object.
970  auto F = get_function_type(context.get());
971 
972  std::string error;
973  executionengine = std::unique_ptr<llvm::ExecutionEngine>(
974  llvm::EngineBuilder(std::move(module))
975  .setEngineKind(llvm::EngineKind::Kind::JIT)
976  .setOptLevel(CodeGenOptLevel::Aggressive)
977  .setErrorStr(&error)
978  .create());
979 
980  class MCJITObjectLoader : public llvm::ObjectCache
981  {
982  const std::string &s_;
983 
984  public:
985  MCJITObjectLoader(const std::string &s) : s_(s) {}
986  void notifyObjectCompiled(const llvm::Module *M,
987  llvm::MemoryBufferRef obj) override
988  {
989  }
990 
991  // No need to check M because there is only one function
992  // Return it after reading from the file.
994  getObject(const llvm::Module *M) override
995  {
996  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
997  }
998  };
999 
1000  MCJITObjectLoader loader(s);
1001  executionengine->setObjectCache(&loader);
1002  executionengine->finalizeObject();
1003  // Set func to compiled function pointer
1004  func = (intptr_t)executionengine->getPointerToFunction(F);
1005 }
1006 
1007 void LLVMVisitor::bvisit(const Floor &x)
1008 {
1010  llvm::Function *fun;
1011  args.push_back(apply(*x.get_arg()));
1012  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1013  llvm::Intrinsic::floor, 1, mod);
1014  auto r = builder->CreateCall(fun, args);
1015  r->setTailCall(true);
1016  result_ = r;
1017 }
1018 
1019 void LLVMVisitor::bvisit(const Ceiling &x)
1020 {
1022  llvm::Function *fun;
1023  args.push_back(apply(*x.get_arg()));
1024  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1025  llvm::Intrinsic::ceil, 1, mod);
1026  auto r = builder->CreateCall(fun, args);
1027  r->setTailCall(true);
1028  result_ = r;
1029 }
1030 
1031 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1032 {
1033  apply(*x.get_arg());
1034 }
1035 
1036 void LLVMVisitor::bvisit(const Truncate &x)
1037 {
1039  llvm::Function *fun;
1040  args.push_back(apply(*x.get_arg()));
1041  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1042  llvm::Intrinsic::trunc, 1, mod);
1043  auto r = builder->CreateCall(fun, args);
1044  r->setTailCall(true);
1045  result_ = r;
1046 }
1047 
1048 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1049 {
1050  return llvm::Type::getDoubleTy(*context);
1051 }
1052 
1053 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1054 {
1055  return llvm::Type::getFloatTy(*context);
1056 }
1057 
1058 #if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
1059 llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1060 {
1061  return llvm::Type::getX86_FP80Ty(*context);
1062 }
1063 #endif
1064 
1065 } // namespace SymEngine
T assign(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
T data(T... args)
T end(T... args)
T erase(T... args)
T get(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1402
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:757
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1614
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1360
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:1007
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2212
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1524
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2376
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2290
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2494
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2927
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2461
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:642
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2891
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2127
T push_back(T... args)