llvm_double.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/Analysis/Passes.h"
3 #include "llvm/ExecutionEngine/ExecutionEngine.h"
4 #include "llvm/ExecutionEngine/GenericValue.h"
5 #include "llvm/ExecutionEngine/MCJIT.h"
6 #include "llvm/IR/Argument.h"
7 #include "llvm/IR/Attributes.h"
8 #include "llvm/IR/BasicBlock.h"
9 #include "llvm/IR/Constants.h"
10 #include "llvm/IR/DerivedTypes.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/IR/LegacyPassManager.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/IR/Verifier.h"
26 #include "llvm/Support/TargetSelect.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/Transforms/Scalar.h"
29 #include "llvm/Transforms/Vectorize.h"
30 #include "llvm/ExecutionEngine/ObjectCache.h"
31 #include "llvm/Support/FileSystem.h"
32 #include "llvm/Support/Path.h"
33 #include <algorithm>
34 #include <cassert>
35 #include <memory>
36 #include <vector>
37 #include <fstream>
38 
39 #if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9) \
40  || (LLVM_VERSION_MAJOR > 3)
41 #include <llvm/Transforms/Scalar/GVN.h>
42 #endif
43 
44 #if (LLVM_VERSION_MAJOR >= 7)
45 #include <llvm/Transforms/InstCombine/InstCombine.h>
46 #include <llvm/Transforms/Scalar/InstSimplifyPass.h>
47 #include <llvm/Transforms/Utils.h>
48 #endif
49 
50 #include <symengine/llvm_double.h>
51 #include <symengine/eval_double.h>
52 #include <symengine/eval.h>
53 
54 namespace SymEngine
55 {
56 
57 #if (LLVM_VERSION_MAJOR >= 10)
58 using std::make_unique;
59 #else
60 using llvm::make_unique;
61 #endif
62 
63 class IRBuilder : public llvm::IRBuilder<>
64 {
65 };
66 
67 llvm::Value *LLVMVisitor::apply(const Basic &b)
68 {
69  b.accept(*this);
70  return result_;
71 }
72 
73 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
74  unsigned opt_level)
75 {
76  init(x, b, symbolic_cse, LLVMVisitor::create_default_passes(opt_level),
77  opt_level);
78 }
79 
80 void LLVMVisitor::init(const vec_basic &x, const Basic &b, bool symbolic_cse,
81  const std::vector<llvm::Pass *> &passes,
82  unsigned opt_level)
83 {
84  init(x, {b.rcp_from_this()}, symbolic_cse, passes, opt_level);
85 }
86 
87 llvm::Function *LLVMVisitor::get_function_type(llvm::LLVMContext *context)
88 {
90  for (int i = 0; i < 2; i++) {
91  inp.push_back(llvm::PointerType::get(get_float_type(context), 0));
92  }
93  llvm::FunctionType *function_type = llvm::FunctionType::get(
94  llvm::Type::getVoidTy(*context), inp, /*isVarArgs=*/false);
95  auto F = llvm::Function::Create(function_type,
96  llvm::Function::InternalLinkage, "", mod);
97  F->setCallingConv(llvm::CallingConv::C);
98 #if (LLVM_VERSION_MAJOR < 5)
99  {
100  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
101  llvm::AttributeSet PAS;
102  {
103  llvm::AttrBuilder B;
104  B.addAttribute(llvm::Attribute::ReadOnly);
105  B.addAttribute(llvm::Attribute::NoCapture);
106  PAS = llvm::AttributeSet::get(mod->getContext(), 1U, B);
107  }
108 
109  attrs.push_back(PAS);
110  {
111  llvm::AttrBuilder B;
112  B.addAttribute(llvm::Attribute::NoCapture);
113  PAS = llvm::AttributeSet::get(mod->getContext(), 2U, B);
114  }
115 
116  attrs.push_back(PAS);
117  {
118  llvm::AttrBuilder B;
119  B.addAttribute(llvm::Attribute::NoUnwind);
120  B.addAttribute(llvm::Attribute::UWTable);
121  PAS = llvm::AttributeSet::get(mod->getContext(), ~0U, B);
122  }
123 
124  attrs.push_back(PAS);
125 
126  F->setAttributes(llvm::AttributeSet::get(mod->getContext(), attrs));
127  }
128 #else
129  F->addParamAttr(0, llvm::Attribute::ReadOnly);
130  F->addParamAttr(0, llvm::Attribute::NoCapture);
131  F->addParamAttr(1, llvm::Attribute::NoCapture);
132  F->addFnAttr(llvm::Attribute::NoUnwind);
133  F->addFnAttr(llvm::Attribute::UWTable);
134 #endif
135  return F;
136 }
137 
138 std::vector<llvm::Pass *> LLVMVisitor::create_default_passes(int optlevel)
139 {
141  if (optlevel == 0) {
142  return passes;
143  }
144 #if (LLVM_VERSION_MAJOR < 4)
145  passes.push_back(llvm::createInstructionCombiningPass());
146 #else
147  passes.push_back(llvm::createInstructionCombiningPass(optlevel > 1));
148 #endif
149  passes.push_back(llvm::createDeadCodeEliminationPass());
150  passes.push_back(llvm::createPromoteMemoryToRegisterPass());
151  passes.push_back(llvm::createReassociatePass());
152  passes.push_back(llvm::createGVNPass());
153  passes.push_back(llvm::createCFGSimplificationPass());
154  passes.push_back(llvm::createPartiallyInlineLibCallsPass());
155 #if (LLVM_VERSION_MAJOR < 5)
156  passes.push_back(llvm::createLoadCombinePass());
157 #endif
158 #if LLVM_VERSION_MAJOR >= 7
159  passes.push_back(llvm::createInstSimplifyLegacyPass());
160 #else
161  passes.push_back(llvm::createInstructionSimplifierPass());
162 #endif
163  passes.push_back(llvm::createMemCpyOptPass());
164  passes.push_back(llvm::createSROAPass());
165  passes.push_back(llvm::createMergedLoadStoreMotionPass());
166  passes.push_back(llvm::createBitTrackingDCEPass());
167  passes.push_back(llvm::createAggressiveDCEPass());
168  if (optlevel > 2) {
169  passes.push_back(llvm::createSLPVectorizerPass());
170 #if LLVM_VERSION_MAJOR >= 7
171  passes.push_back(llvm::createInstSimplifyLegacyPass());
172 #else
173  passes.push_back(llvm::createInstructionSimplifierPass());
174 #endif
175  }
176  return passes;
177 }
178 
179 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
180  const bool symbolic_cse, unsigned opt_level)
181 {
182  init(inputs, outputs, symbolic_cse,
183  LLVMVisitor::create_default_passes(opt_level), opt_level);
184 }
185 
186 void LLVMVisitor::init(const vec_basic &inputs, const vec_basic &outputs,
187  const bool symbolic_cse,
188  const std::vector<llvm::Pass *> &passes,
189  unsigned opt_level)
190 {
191  executionengine.reset();
192  llvm::InitializeNativeTarget();
193  llvm::InitializeNativeTargetAsmPrinter();
194  llvm::InitializeNativeTargetAsmParser();
195  context = std::make_shared<llvm::LLVMContext>();
196  symbols = inputs;
197 
198  // Create some module to put our function into it.
200  = make_unique<llvm::Module>("SymEngine", *context.get());
201  module->setDataLayout("");
202  mod = module.get();
203 
204  // Create a new pass manager attached to it.
205  fpm = std::make_shared<llvm::legacy::FunctionPassManager>(mod);
206  for (auto pass : passes) {
207  fpm->add(pass);
208  }
209  fpm->doInitialization();
210 
211  auto F = get_function_type(context.get());
212 
213  // Add a basic block to the function. As before, it automatically
214  // inserts
215  // because of the last argument.
216  llvm::BasicBlock *BB = llvm::BasicBlock::Create(*context, "EntryBlock", F);
217 
218  // Create a basic block builder with default parameters. The builder
219  // will
220  // automatically append instructions to the basic block `BB'.
221  llvm::IRBuilder<> _builder(BB);
222  builder = reinterpret_cast<IRBuilder *>(&_builder);
223  builder->SetInsertPoint(BB);
224  auto fmf = llvm::FastMathFlags();
225  // fmf.setUnsafeAlgebra();
226  builder->setFastMathFlags(fmf);
227 
228  // Load all the symbols and create references
229  auto input_arg = &(*(F->args().begin()));
230  for (unsigned i = 0; i < inputs.size(); i++) {
231  if (not is_a<Symbol>(*inputs[i])) {
232  throw SymEngineException("Input contains a non-symbol.");
233  }
234  auto index
235  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
236  auto ptr = builder->CreateGEP(get_float_type(context.get()), input_arg,
237  index);
238  result_ = builder->CreateLoad(get_float_type(context.get()), ptr);
239  symbol_ptrs.push_back(result_);
240  }
241 
242  auto it = F->args().begin();
243 #if (LLVM_VERSION_MAJOR < 5)
244  auto out = &(*(++it));
245 #else
246  auto out = &(*(it + 1));
247 #endif
248  std::vector<llvm::Value *> output_vals;
249 
250  if (symbolic_cse) {
251  vec_basic reduced_exprs;
252  vec_pair replacements;
253  // cse the outputs
254  SymEngine::cse(replacements, reduced_exprs, outputs);
255  for (auto &rep : replacements) {
256  // Store the replacement symbol values in a dictionary
257  replacement_symbol_ptrs[rep.first] = apply(*(rep.second));
258  }
259  // Generate IR for all the reduced exprs and save references
260  for (unsigned i = 0; i < outputs.size(); i++) {
261  output_vals.push_back(apply(*reduced_exprs[i]));
262  }
263  } else {
264  // Generate IR for all the output exprs and save references
265  for (unsigned i = 0; i < outputs.size(); i++) {
266  output_vals.push_back(apply(*outputs[i]));
267  }
268  }
269 
270  // Store all the output exprs at the end
271  for (unsigned i = 0; i < outputs.size(); i++) {
272  auto index
273  = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), i);
274  auto ptr
275  = builder->CreateGEP(get_float_type(context.get()), out, index);
276  builder->CreateStore(output_vals[i], ptr);
277  }
278 
279  // Create the return instruction and add it to the basic block
280  builder->CreateRetVoid();
281 
282  // Validate the generated code, checking for consistency.
283  llvm::verifyFunction(*F, &llvm::outs());
284 
285  // std::cout << "LLVM IR" << std::endl;
286  // #if (LLVM_VERSION_MAJOR < 5)
287  // module->dump();
288  // #else
289  // module->print(llvm::errs(), nullptr);
290  // #endif
291 
292  // Optimize the function.
293  fpm->run(*F);
294 
295  // std::cout << "Optimized LLVM IR" << std::endl;
296  // module->dump();
297 
298  // Now we create the JIT.
299  std::string error;
300  executionengine = std::shared_ptr<llvm::ExecutionEngine>(
301  llvm::EngineBuilder(std::move(module))
302  .setEngineKind(llvm::EngineKind::Kind::JIT)
303  .setOptLevel(static_cast<llvm::CodeGenOpt::Level>(opt_level))
304  .setErrorStr(&error)
305  .create());
306 
307  // This is a hack to get the MemoryBuffer of a compiled object.
308  class MemoryBufferRefCallback : public llvm::ObjectCache
309  {
310  public:
311  std::string &ss_;
312  MemoryBufferRefCallback(std::string &ss) : ss_(ss) {}
313 
314  virtual void notifyObjectCompiled(const llvm::Module *M,
315  llvm::MemoryBufferRef obj)
316  {
317  const char *c = obj.getBufferStart();
318  // Saving the object code in a std::string
319  ss_.assign(c, obj.getBufferSize());
320  }
321 
323  getObject(const llvm::Module *M)
324  {
325  return NULL;
326  }
327  };
328 
329  MemoryBufferRefCallback callback(membuffer);
330  executionengine->setObjectCache(&callback);
331  // std::cout << error << std::endl;
332  executionengine->finalizeObject();
333 
334  // Get the symbol's address
335  func = (intptr_t)executionengine->getPointerToFunction(F);
336  symbol_ptrs.clear();
337  replacement_symbol_ptrs.clear();
338  symbols.clear();
339 }
340 
341 double LLVMDoubleVisitor::call(const std::vector<double> &vec) const
342 {
343  double ret;
344  ((double (*)(const double *, double *))func)(vec.data(), &ret);
345  return ret;
346 }
347 
348 void LLVMDoubleVisitor::call(double *outs, const double *inps) const
349 {
350  ((double (*)(const double *, double *))func)(inps, outs);
351 }
352 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate the compiled function on `vec` and return the first output.
//
// The JIT-compiled function is created with a void return type (see
// get_function_type), so it is invoked through a void-returning pointer
// type rather than a long-double-returning one.
long double
LLVMLongDoubleVisitor::call(const std::vector<long double> &vec) const
{
    using compiled_fn = void (*)(const long double *, long double *);
    long double ret = 0.0L;
    reinterpret_cast<compiled_fn>(func)(vec.data(), &ret);
    return ret;
}

// Evaluate the compiled function, reading inputs from `inps` and writing
// every output to `outs`.
void LLVMLongDoubleVisitor::call(long double *outs,
                                 const long double *inps) const
{
    using compiled_fn = void (*)(const long double *, long double *);
    reinterpret_cast<compiled_fn>(func)(inps, outs);
}
#endif
369 
370 float LLVMFloatVisitor::call(const std::vector<float> &vec) const
371 {
372  float ret;
373  ((float (*)(const float *, float *))func)(vec.data(), &ret);
374  return ret;
375 }
376 
377 void LLVMFloatVisitor::call(float *outs, const float *inps) const
378 {
379  ((float (*)(const float *, float *))func)(inps, outs);
380 }
381 
382 void LLVMVisitor::set_double(double d)
383 {
384  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()), d);
385 }
386 
387 void LLVMVisitor::bvisit(const Integer &x)
388 {
389  result_ = llvm::ConstantFP::get(get_float_type(&mod->getContext()),
390  mp_get_d(x.as_integer_class()));
391 }
392 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Evaluate `x` numerically (evalf, 128 bits, real domain) and emit its
// string form as a constant of the visitor's float type. Without MPFR
// support this always throws NotImplementedError.
void LLVMLongDoubleVisitor::convert_from_mpfr(const Basic &x)
{
#ifdef HAVE_SYMENGINE_MPFR
    const RCP<const Basic> evaluated = evalf(x, 128, EvalfDomain::Real);
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, evaluated->__str__());
#else
    throw NotImplementedError("Cannot convert to long double without MPFR");
#endif
}

// Integer: emitted from its string form as a constant of the visitor's
// float type.
void LLVMLongDoubleVisitor::visit(const Integer &x)
{
    llvm::Type *float_type = get_float_type(&mod->getContext());
    result_ = llvm::ConstantFP::get(float_type, x.__str__());
}
#endif
411 
412 void LLVMVisitor::bvisit(const Rational &x)
413 {
414  set_double(mp_get_d(x.as_rational_class()));
415 }
416 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Rational: delegated to convert_from_mpfr.
void LLVMLongDoubleVisitor::visit(const Rational &x)
{
    const Basic &expr = x;
    convert_from_mpfr(expr);
}
#endif
423 
424 void LLVMVisitor::bvisit(const RealDouble &x)
425 {
426  set_double(x.i);
427 }
428 
#ifdef HAVE_SYMENGINE_MPFR
// RealMPFR: converted to double (MPFR_RNDN rounding mode) and emitted as a
// constant.
void LLVMVisitor::bvisit(const RealMPFR &x)
{
    const double as_double = mpfr_get_d(x.i.get_mpfr_t(), MPFR_RNDN);
    set_double(as_double);
}
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// RealMPFR for the long double visitor: delegated to convert_from_mpfr.
void LLVMLongDoubleVisitor::visit(const RealMPFR &x)
{
    const Basic &expr = x;
    convert_from_mpfr(expr);
}
#endif
#endif
441 
442 void LLVMVisitor::bvisit(const Add &x)
443 {
444  llvm::Value *tmp, *tmp1, *tmp2;
445  auto it = x.get_dict().begin();
446 
447  if (eq(*x.get_coef(), *zero)) {
448  // `x + 0.0` is not optimized out
449  if (eq(*one, *(it->second))) {
450  tmp = apply(*(it->first));
451  } else {
452  tmp1 = apply(*(it->first));
453  tmp2 = apply(*(it->second));
454  tmp = builder->CreateFMul(tmp1, tmp2);
455  }
456  ++it;
457  } else {
458  tmp = apply(*x.get_coef());
459  }
460 
461  for (; it != x.get_dict().end(); ++it) {
462  if (eq(*one, *(it->second))) {
463  tmp1 = apply(*(it->first));
464  tmp = builder->CreateFAdd(tmp, tmp1);
465  } else {
466  // std::vector<llvm::Value *> args({tmp1, tmp2, tmp});
467  // tmp =
468  // builder->CreateCall(get_float_intrinsic(get_float_type(&mod->getContext()),
469  // llvm::Intrinsic::fma,
470  // 3, context), args);
471  tmp1 = apply(*(it->first));
472  tmp2 = apply(*(it->second));
473  tmp = builder->CreateFAdd(tmp, builder->CreateFMul(tmp1, tmp2));
474  }
475  }
476  result_ = tmp;
477 }
478 
479 void LLVMVisitor::bvisit(const Mul &x)
480 {
481  llvm::Value *tmp = nullptr;
482  bool first = true;
483  for (const auto &p : x.get_args()) {
484  if (first) {
485  tmp = apply(*p);
486  } else {
487  tmp = builder->CreateFMul(tmp, apply(*p));
488  }
489  first = false;
490  }
491  result_ = tmp;
492 }
493 
494 llvm::Function *LLVMVisitor::get_powi()
495 {
496  std::vector<llvm::Type *> arg_type;
497  arg_type.push_back(get_float_type(&mod->getContext()));
498 #if (LLVM_VERSION_MAJOR > 12)
499  arg_type.push_back(llvm::Type::getInt32Ty(mod->getContext()));
500 #endif
501  return llvm::Intrinsic::getDeclaration(mod, llvm::Intrinsic::powi,
502  arg_type);
503 }
504 
505 llvm::Function *get_float_intrinsic(llvm::Type *type, llvm::Intrinsic::ID id,
506  unsigned n, llvm::Module *mod)
507 {
508  std::vector<llvm::Type *> arg_type(n, type);
509  return llvm::Intrinsic::getDeclaration(mod, id, arg_type);
510 }
511 
512 void LLVMVisitor::bvisit(const Pow &x)
513 {
515  llvm::Function *fun;
516  if (eq(*(x.get_base()), *E)) {
517  args.push_back(apply(*x.get_exp()));
518  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
519  llvm::Intrinsic::exp, 1, mod);
520 
521  } else if (eq(*(x.get_base()), *integer(2))) {
522  args.push_back(apply(*x.get_exp()));
523  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
524  llvm::Intrinsic::exp2, 1, mod);
525 
526  } else {
527  if (is_a<Integer>(*x.get_exp())) {
528  if (eq(*x.get_exp(), *integer(2))) {
529  llvm::Value *tmp = apply(*x.get_base());
530  result_ = builder->CreateFMul(tmp, tmp);
531  return;
532  } else {
533  args.push_back(apply(*x.get_base()));
534  int d = numeric_cast<int>(
535  mp_get_si(static_cast<const Integer &>(*x.get_exp())
536  .as_integer_class()));
537  result_ = llvm::ConstantInt::get(
538  llvm::Type::getInt32Ty(mod->getContext()), d, true);
539  args.push_back(result_);
540  fun = get_powi();
541  }
542  } else {
543  args.push_back(apply(*x.get_base()));
544  args.push_back(apply(*x.get_exp()));
545  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
546  llvm::Intrinsic::pow, 1, mod);
547  }
548  }
549  auto r = builder->CreateCall(fun, args);
550  r->setTailCall(true);
551  result_ = r;
552 }
553 
554 void LLVMVisitor::bvisit(const Sin &x)
555 {
557  llvm::Function *fun;
558  args.push_back(apply(*x.get_arg()));
559  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
560  llvm::Intrinsic::sin, 1, mod);
561  auto r = builder->CreateCall(fun, args);
562  r->setTailCall(true);
563  result_ = r;
564 }
565 
566 void LLVMVisitor::bvisit(const Cos &x)
567 {
569  llvm::Function *fun;
570  args.push_back(apply(*x.get_arg()));
571  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
572  llvm::Intrinsic::cos, 1, mod);
573  auto r = builder->CreateCall(fun, args);
574  r->setTailCall(true);
575  result_ = r;
576 }
577 
578 void LLVMVisitor::bvisit(const Piecewise &x)
579 {
581 
582  RCP<const Piecewise> pw = x.rcp_from_this_cast<const Piecewise>();
583 
584  if (neq(*pw->get_vec().back().second, *boolTrue)) {
585  throw SymEngineException(
586  "LLVMDouble requires a (Expr, True) at the end of Piecewise");
587  }
588 
589  if (pw->get_vec().size() > 2) {
590  PiecewiseVec rest = pw->get_vec();
591  rest.erase(rest.begin());
592  auto rest_pw = piecewise(std::move(rest));
593  PiecewiseVec new_pw;
594  new_pw.push_back(*pw->get_vec().begin());
595  new_pw.push_back({rest_pw, pw->get_vec().back().second});
596  pw = piecewise(std::move(new_pw))
597  ->rcp_from_this_cast<const Piecewise>();
598  } else if (pw->get_vec().size() < 2) {
599  throw SymEngineException("Invalid Piecewise object");
600  }
601 
602  auto cond_basic = pw->get_vec().front().second;
603  llvm::Value *cond = apply(*cond_basic);
604  // check if cond != 0.0
605  cond = builder->CreateFCmpONE(
606  cond, llvm::ConstantFP::get(get_float_type(&mod->getContext()), 0.0),
607  "ifcond");
608  llvm::Function *function = builder->GetInsertBlock()->getParent();
609  // Create blocks for the then and else cases. Insert the 'then' block at
610  // the
611  // end of the function.
612  llvm::BasicBlock *then_bb
613  = llvm::BasicBlock::Create(mod->getContext(), "then", function);
614  llvm::BasicBlock *else_bb
615  = llvm::BasicBlock::Create(mod->getContext(), "else");
616  llvm::BasicBlock *merge_bb
617  = llvm::BasicBlock::Create(mod->getContext(), "ifcont");
618  builder->CreateCondBr(cond, then_bb, else_bb);
619 
620  // Emit then value.
621  builder->SetInsertPoint(then_bb);
622  llvm::Value *then_value = apply(*pw->get_vec().front().first);
623  builder->CreateBr(merge_bb);
624 
625  // Codegen of 'then_value' can change the current block, update then_bb for
626  // the PHI.
627  then_bb = builder->GetInsertBlock();
628 
629  // Emit else block.
630  function->getBasicBlockList().push_back(else_bb);
631  builder->SetInsertPoint(else_bb);
632  llvm::Value *else_value = apply(*pw->get_vec().back().first);
633  builder->CreateBr(merge_bb);
634 
635  // Codegen of 'else_value' can change the current block, update else_bb for
636  // the PHI.
637  else_bb = builder->GetInsertBlock();
638 
639  // Emit merge block.
640  function->getBasicBlockList().push_back(merge_bb);
641  builder->SetInsertPoint(merge_bb);
642  llvm::PHINode *phi_node
643  = builder->CreatePHI(get_float_type(&mod->getContext()), 2);
644 
645  phi_node->addIncoming(then_value, then_bb);
646  phi_node->addIncoming(else_value, else_bb);
647  result_ = phi_node;
648 }
649 
650 void LLVMVisitor::bvisit(const Sign &x)
651 {
652  const auto x2 = x.get_arg();
653  PiecewiseVec new_pw;
654  new_pw.push_back({real_double(0.0), Eq(x2, real_double(0.0))});
655  new_pw.push_back({real_double(-1.0), Lt(x2, real_double(0.0))});
656  new_pw.push_back({real_double(1.0), boolTrue});
657  auto pw = rcp_static_cast<const Piecewise>(piecewise(std::move(new_pw)));
658  bvisit(*pw);
659 }
660 
661 void LLVMVisitor::bvisit(const Contains &cts)
662 {
663  llvm::Value *expr = apply(*cts.get_expr());
664  const auto set = cts.get_set();
665  if (is_a<Interval>(*set)) {
666  const auto &interv = down_cast<const Interval &>(*set);
667  llvm::Value *start = apply(*interv.get_start());
668  llvm::Value *end = apply(*interv.get_end());
669  const bool left_open = interv.get_left_open();
670  const bool right_open = interv.get_right_open();
671  llvm::Value *left_ok;
672  llvm::Value *right_ok;
673  left_ok = (left_open) ? builder->CreateFCmpOLT(start, expr)
674  : builder->CreateFCmpOLE(start, expr);
675  right_ok = (right_open) ? builder->CreateFCmpOLT(expr, end)
676  : builder->CreateFCmpOLE(expr, end);
677  result_ = builder->CreateAnd(left_ok, right_ok);
678  result_ = builder->CreateUIToFP(result_,
679  get_float_type(&mod->getContext()));
680  } else {
681  throw SymEngineException("LLVMVisitor: only ``Interval`` "
682  "implemented for ``Contains``.");
683  }
684 }
685 
686 void LLVMVisitor::bvisit(const Infty &x)
687 {
688  if (x.is_negative_infinity()) {
689  result_ = llvm::ConstantFP::getInfinity(
690  get_float_type(&mod->getContext()), true);
691  } else if (x.is_positive_infinity()) {
692  result_ = llvm::ConstantFP::getInfinity(
693  get_float_type(&mod->getContext()), false);
694  } else {
695  throw SymEngineException(
696  "LLVMDouble can only represent real valued infinity");
697  }
698 }
699 
700 void LLVMVisitor::bvisit(const BooleanAtom &x)
701 {
702  const bool val = x.get_val();
703  set_double(val ? 1.0 : 0.0);
704 }
705 
706 void LLVMVisitor::bvisit(const Log &x)
707 {
709  llvm::Function *fun;
710  args.push_back(apply(*x.get_arg()));
711  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
712  llvm::Intrinsic::log, 1, mod);
713  auto r = builder->CreateCall(fun, args);
714  r->setTailCall(true);
715  result_ = r;
716 }
717 
718 #define SYMENGINE_LOGIC_FUNCTION(Class, method) \
719  void LLVMVisitor::bvisit(const Class &x) \
720  { \
721  llvm::Value *value = nullptr; \
722  llvm::Value *tmp; \
723  set_double(0.0); \
724  llvm::Value *zero_val = result_; \
725  for (auto &p : x.get_container()) { \
726  tmp = builder->CreateFCmpONE(apply(*p), zero_val); \
727  if (value == nullptr) { \
728  value = tmp; \
729  } else { \
730  value = builder->method(value, tmp); \
731  } \
732  } \
733  result_ = builder->CreateUIToFP(value, \
734  get_float_type(&mod->getContext())); \
735  }
736 
737 SYMENGINE_LOGIC_FUNCTION(And, CreateAnd);
738 SYMENGINE_LOGIC_FUNCTION(Or, CreateOr);
739 SYMENGINE_LOGIC_FUNCTION(Xor, CreateXor);
740 
741 void LLVMVisitor::bvisit(const Not &x)
742 {
743  builder->CreateNot(apply(*x.get_arg()));
744 }
745 
746 #define SYMENGINE_RELATIONAL_FUNCTION(Class, method) \
747  void LLVMVisitor::bvisit(const Class &x) \
748  { \
749  llvm::Value *left = apply(*x.get_arg1()); \
750  llvm::Value *right = apply(*x.get_arg2()); \
751  result_ = builder->method(left, right); \
752  result_ = builder->CreateUIToFP(result_, \
753  get_float_type(&mod->getContext())); \
754  }
755 
756 SYMENGINE_RELATIONAL_FUNCTION(Equality, CreateFCmpOEQ);
757 SYMENGINE_RELATIONAL_FUNCTION(Unequality, CreateFCmpONE);
758 SYMENGINE_RELATIONAL_FUNCTION(LessThan, CreateFCmpOLE);
759 SYMENGINE_RELATIONAL_FUNCTION(StrictLessThan, CreateFCmpOLT);
760 
761 #define _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
762  void LLVMDoubleVisitor::visit(const Class &x) \
763  { \
764  vec_basic basic_args = x.get_args(); \
765  llvm::Function *func = get_external_function(#ext, basic_args.size()); \
766  std::vector<llvm::Value *> args; \
767  for (const auto &arg : basic_args) { \
768  args.push_back(apply(*arg)); \
769  } \
770  auto r = builder->CreateCall(func, args); \
771  r->setTailCall(true); \
772  result_ = r; \
773  } \
774  void LLVMFloatVisitor::visit(const Class &x) \
775  { \
776  vec_basic basic_args = x.get_args(); \
777  llvm::Function *func = get_external_function(#ext + std::string("f"), \
778  basic_args.size()); \
779  std::vector<llvm::Value *> args; \
780  for (const auto &arg : basic_args) { \
781  args.push_back(apply(*arg)); \
782  } \
783  auto r = builder->CreateCall(func, args); \
784  r->setTailCall(true); \
785  result_ = r; \
786  }
787 
788 #ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
789 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
790  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
791  void LLVMLongDoubleVisitor::visit(const Class &x) \
792  { \
793  vec_basic basic_args = x.get_args(); \
794  llvm::Function *func = get_external_function(#ext + std::string("l"), \
795  basic_args.size()); \
796  std::vector<llvm::Value *> args; \
797  for (const auto &arg : basic_args) { \
798  args.push_back(apply(*arg)); \
799  } \
800  auto r = builder->CreateCall(func, args); \
801  r->setTailCall(true); \
802  result_ = r; \
803  }
804 #else
805 #define SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext) \
806  _SYMENGINE_MACRO_EXTERNAL_FUNCTION(Class, ext)
807 #endif
808 
809 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tan, tan)
810 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASin, asin)
811 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACos, acos)
812 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan, atan)
813 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATan2, atan2)
814 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Sinh, sinh)
815 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Cosh, cosh)
816 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Tanh, tanh)
817 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ASinh, asinh)
818 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ACosh, acosh)
819 SYMENGINE_MACRO_EXTERNAL_FUNCTION(ATanh, atanh)
820 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Gamma, tgamma)
821 SYMENGINE_MACRO_EXTERNAL_FUNCTION(LogGamma, lgamma)
822 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erf, erf)
823 SYMENGINE_MACRO_EXTERNAL_FUNCTION(Erfc, erfc)
824 
825 void LLVMVisitor::bvisit(const Abs &x)
826 {
828  llvm::Function *fun;
829  args.push_back(apply(*x.get_arg()));
830  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
831  llvm::Intrinsic::fabs, 1, mod);
832  auto r = builder->CreateCall(fun, args);
833  r->setTailCall(true);
834  result_ = r;
835 }
836 
837 void LLVMVisitor::bvisit(const Min &x)
838 {
839  llvm::Value *value = nullptr;
840  llvm::Function *fun;
841  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
842  llvm::Intrinsic::minnum, 1, mod);
843  for (auto &arg : x.get_vec()) {
844  if (value != nullptr) {
846  args.push_back(value);
847  args.push_back(apply(*arg));
848  auto r = builder->CreateCall(fun, args);
849  r->setTailCall(true);
850  value = r;
851  } else {
852  value = apply(*arg);
853  }
854  }
855  result_ = value;
856 }
857 
858 void LLVMVisitor::bvisit(const Max &x)
859 {
860  llvm::Value *value = nullptr;
861  llvm::Function *fun;
862  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
863  llvm::Intrinsic::maxnum, 1, mod);
864  for (auto &arg : x.get_vec()) {
865  if (value != nullptr) {
867  args.push_back(value);
868  args.push_back(apply(*arg));
869  auto r = builder->CreateCall(fun, args);
870  r->setTailCall(true);
871  value = r;
872  } else {
873  value = apply(*arg);
874  }
875  }
876  result_ = value;
877 }
878 
879 void LLVMVisitor::bvisit(const Symbol &x)
880 {
881  unsigned i = 0;
882  for (auto &symb : symbols) {
883  if (eq(x, *symb)) {
884  result_ = symbol_ptrs[i];
885  return;
886  }
887  ++i;
888  }
889  auto it = replacement_symbol_ptrs.find(x.rcp_from_this());
890  if (it != replacement_symbol_ptrs.end()) {
891  result_ = it->second;
892  return;
893  }
894 
895  throw SymEngineException("Symbol " + x.__str__()
896  + " not in the symbols vector.");
897 }
898 
899 llvm::Function *LLVMVisitor::get_external_function(const std::string &name,
900  size_t nargs)
901 {
902  std::vector<llvm::Type *> func_args(nargs,
903  get_float_type(&mod->getContext()));
904  llvm::FunctionType *func_type = llvm::FunctionType::get(
905  get_float_type(&mod->getContext()), func_args, /*isVarArgs=*/false);
906 
907  llvm::Function *func = mod->getFunction(name);
908  if (!func) {
909  func = llvm::Function::Create(
910  func_type, llvm::GlobalValue::ExternalLinkage, name, mod);
911  func->setCallingConv(llvm::CallingConv::C);
912  }
913 #if (LLVM_VERSION_MAJOR < 5)
914  llvm::AttributeSet func_attr_set;
915  {
916  llvm::SmallVector<llvm::AttributeSet, 4> attrs;
917  llvm::AttributeSet attr_set;
918  {
919  llvm::AttrBuilder attr_builder;
920  attr_builder.addAttribute(llvm::Attribute::NoUnwind);
921  attr_set
922  = llvm::AttributeSet::get(mod->getContext(), ~0U, attr_builder);
923  }
924 
925  attrs.push_back(attr_set);
926  func_attr_set = llvm::AttributeSet::get(mod->getContext(), attrs);
927  }
928  func->setAttributes(func_attr_set);
929 #else
930  func->addFnAttr(llvm::Attribute::NoUnwind);
931 #endif
932  return func;
933 }
934 
935 void LLVMVisitor::bvisit(const Constant &x)
936 {
937  set_double(eval_double(x));
938 }
939 
#ifdef SYMENGINE_HAVE_LLVM_LONG_DOUBLE
// Constant for the long double visitor: delegated to convert_from_mpfr.
void LLVMLongDoubleVisitor::visit(const Constant &x)
{
    const Basic &expr = x;
    convert_from_mpfr(expr);
}
#endif
946 
947 void LLVMVisitor::bvisit(const Basic &x)
948 {
949  throw NotImplementedError(x.__str__());
950 }
951 
952 const std::string &LLVMVisitor::dumps() const
953 {
954  return membuffer;
955 };
956 
957 void LLVMVisitor::loads(const std::string &s)
958 {
959  membuffer = s;
960  llvm::InitializeNativeTarget();
961  llvm::InitializeNativeTargetAsmPrinter();
962  llvm::InitializeNativeTargetAsmParser();
963  context = std::make_shared<llvm::LLVMContext>();
964 
965  // Create some module to put our function into it.
967  = make_unique<llvm::Module>("SymEngine", *context);
968  module->setDataLayout("");
969  mod = module.get();
970 
971  // Only defining the prototype for the function here.
972  // Since we know where the function is stored that's enough
973  // llvm::ObjectCache is designed for caching objects, but it
974  // is used here for loading one specific object.
975  auto F = get_function_type(context.get());
976 
977  std::string error;
978  executionengine = std::shared_ptr<llvm::ExecutionEngine>(
979  llvm::EngineBuilder(std::move(module))
980  .setEngineKind(llvm::EngineKind::Kind::JIT)
981  .setOptLevel(llvm::CodeGenOpt::Level::Aggressive)
982  .setErrorStr(&error)
983  .create());
984 
985  class MCJITObjectLoader : public llvm::ObjectCache
986  {
987  const std::string &s_;
988 
989  public:
990  MCJITObjectLoader(const std::string &s) : s_(s) {}
991  virtual void notifyObjectCompiled(const llvm::Module *M,
992  llvm::MemoryBufferRef obj)
993  {
994  }
995 
996  // No need to check M because there is only one function
997  // Return it after reading from the file.
999  getObject(const llvm::Module *M)
1000  {
1001  return llvm::MemoryBuffer::getMemBufferCopy(llvm::StringRef(s_));
1002  }
1003  };
1004 
1005  MCJITObjectLoader loader(s);
1006  executionengine->setObjectCache(&loader);
1007  executionengine->finalizeObject();
1008  // Set func to compiled function pointer
1009  func = (intptr_t)executionengine->getPointerToFunction(F);
1010 }
1011 
1012 void LLVMVisitor::bvisit(const Floor &x)
1013 {
1015  llvm::Function *fun;
1016  args.push_back(apply(*x.get_arg()));
1017  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1018  llvm::Intrinsic::floor, 1, mod);
1019  auto r = builder->CreateCall(fun, args);
1020  r->setTailCall(true);
1021  result_ = r;
1022 }
1023 
1024 void LLVMVisitor::bvisit(const Ceiling &x)
1025 {
1027  llvm::Function *fun;
1028  args.push_back(apply(*x.get_arg()));
1029  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1030  llvm::Intrinsic::ceil, 1, mod);
1031  auto r = builder->CreateCall(fun, args);
1032  r->setTailCall(true);
1033  result_ = r;
1034 }
1035 
1036 void LLVMVisitor::bvisit(const UnevaluatedExpr &x)
1037 {
1038  apply(*x.get_arg());
1039 }
1040 
1041 void LLVMVisitor::bvisit(const Truncate &x)
1042 {
1044  llvm::Function *fun;
1045  args.push_back(apply(*x.get_arg()));
1046  fun = get_float_intrinsic(get_float_type(&mod->getContext()),
1047  llvm::Intrinsic::trunc, 1, mod);
1048  auto r = builder->CreateCall(fun, args);
1049  r->setTailCall(true);
1050  result_ = r;
1051 }
1052 
1053 llvm::Type *LLVMDoubleVisitor::get_float_type(llvm::LLVMContext *context)
1054 {
1055  return llvm::Type::getDoubleTy(*context);
1056 }
1057 
1058 llvm::Type *LLVMFloatVisitor::get_float_type(llvm::LLVMContext *context)
1059 {
1060  return llvm::Type::getFloatTy(*context);
1061 }
1062 
#if defined(SYMENGINE_HAVE_LLVM_LONG_DOUBLE)
// The long double visitor works with LLVM's `x86_fp80` type.
llvm::Type *LLVMLongDoubleVisitor::get_float_type(llvm::LLVMContext *context)
{
    llvm::LLVMContext &ctx = *context;
    return llvm::Type::getX86_FP80Ty(ctx);
}
#endif
1069 
1070 } // namespace SymEngine
T assign(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:95
T data(T... args)
T end(T... args)
T erase(T... args)
T get(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > acos(const RCP< const Basic > &arg)
Canonicalize ACos:
Definition: functions.cpp:1359
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:200
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Boolean > Lt(const RCP< const Basic > &lhs, const RCP< const Basic > &rhs)
Returns the canonicalized StrictLessThan object from the arguments.
Definition: logic.cpp:707
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:64
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
Definition: functions.cpp:1571
RCP< const Basic > asin(const RCP< const Basic > &arg)
Canonicalize ASin:
Definition: functions.cpp:1317
RCP< const Basic > tan(const RCP< const Basic > &arg)
Canonicalize Tan:
Definition: functions.cpp:964
RCP< const Basic > cosh(const RCP< const Basic > &arg)
Canonicalize Cosh:
Definition: functions.cpp:2169
RCP< const Basic > atan(const RCP< const Basic > &arg)
Canonicalize ATan:
Definition: functions.cpp:1481
RCP< const Basic > asinh(const RCP< const Basic > &arg)
Canonicalize ASinh:
Definition: functions.cpp:2333
RCP< const Basic > tanh(const RCP< const Basic > &arg)
Canonicalize Tanh:
Definition: functions.cpp:2247
RCP< const Basic > atanh(const RCP< const Basic > &arg)
Canonicalize ATanh:
Definition: functions.cpp:2451
RCP< const Basic > erfc(const RCP< const Basic > &arg)
Canonicalize Erfc:
Definition: functions.cpp:2884
RCP< const Basic > acosh(const RCP< const Basic > &arg)
Canonicalize ACosh:
Definition: functions.cpp:2418
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
Definition: logic.cpp:592
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > erf(const RCP< const Basic > &arg)
Canonicalize Erf:
Definition: functions.cpp:2848
RCP< const Basic > sinh(const RCP< const Basic > &arg)
Canonicalize Sinh:
Definition: functions.cpp:2084
T push_back(T... args)