llvm_double.h
1 #ifndef SYMENGINE_LLVM_DOUBLE_H
2 #define SYMENGINE_LLVM_DOUBLE_H
3 
4 #include <symengine/basic.h>
5 #include <symengine/visitor.h>
6 #include <float.h>
7 
8 #ifdef HAVE_SYMENGINE_LLVM
9 
10 // Forward declare llvm types
11 namespace llvm
12 {
13 class Module;
14 class Value;
15 class Type;
16 class Function;
17 class ExecutionEngine;
18 class MemoryBufferRef;
19 class LLVMContext;
20 class Pass;
21 namespace legacy
22 {
23 class FunctionPassManager;
24 }
25 } // namespace llvm
26 
27 namespace SymEngine
28 {
29 
30 class IRBuilder;
31 
32 class LLVMVisitor : public BaseVisitor<LLVMVisitor>
33 {
34 protected:
35  vec_basic symbols;
36  std::vector<llvm::Value *> symbol_ptrs;
37  std::map<RCP<const Basic>, llvm::Value *, RCPBasicKeyLess>
38  replacement_symbol_ptrs;
39  llvm::Value *result_;
43  intptr_t func;
44 
45  // Following are invalid after the init call.
46  IRBuilder *builder;
47  llvm::Module *mod;
48  std::string membuffer;
49  llvm::Function *get_function_type(llvm::LLVMContext *);
50  virtual llvm::Type *get_float_type(llvm::LLVMContext *) = 0;
51 
52 public:
53  llvm::Value *apply(const Basic &b);
54  void init(const vec_basic &x, const Basic &b,
55  const bool symbolic_cse = false, unsigned opt_level = 3);
56  void init(const vec_basic &x, const Basic &b, const bool symbolic_cse,
57  const std::vector<llvm::Pass *> &passes, unsigned opt_level = 3);
58  void init(const vec_basic &inputs, const vec_basic &outputs,
59  const bool symbolic_cse = false, unsigned opt_level = 3);
60  void init(const vec_basic &inputs, const vec_basic &outputs,
61  const bool symbolic_cse, const std::vector<llvm::Pass *> &passes,
62  unsigned opt_level = 3);
63 
64  static std::vector<llvm::Pass *> create_default_passes(int optlevel);
65 
66  // Helper functions
67  void set_double(double d);
68  llvm::Function *get_external_function(const std::string &name,
69  size_t nargs = 1);
70  llvm::Function *get_powi();
71 
72  void bvisit(const Integer &x);
73  void bvisit(const Rational &x);
74  void bvisit(const RealDouble &x);
75 #ifdef HAVE_SYMENGINE_MPFR
76  void bvisit(const RealMPFR &x);
77 #endif
78  void bvisit(const Add &x);
79  void bvisit(const Mul &x);
80  void bvisit(const Pow &x);
81  void bvisit(const Log &x);
82  void bvisit(const Abs &x);
83  void bvisit(const Symbol &x);
84  void bvisit(const Constant &x);
85  void bvisit(const Basic &);
86  void bvisit(const Sin &x);
87  void bvisit(const Cos &x);
88  void bvisit(const Piecewise &x);
89  void bvisit(const BooleanAtom &x);
90  void bvisit(const And &x);
91  void bvisit(const Or &x);
92  void bvisit(const Xor &x);
93  void bvisit(const Not &x);
94  void bvisit(const Equality &x);
95  void bvisit(const Unequality &x);
96  void bvisit(const LessThan &x);
97  void bvisit(const StrictLessThan &x);
98  void bvisit(const Max &x);
99  void bvisit(const Min &x);
100  void bvisit(const Contains &x);
101  void bvisit(const Infty &x);
102  void bvisit(const Floor &x);
103  void bvisit(const Ceiling &x);
104  void bvisit(const Truncate &x);
105  void bvisit(const Sign &x);
106  // Return the compiled function as a binary string which can be loaded using
107  // `load`
108  const std::string &dumps() const;
109  // Load a previously compiled function from a string
110  void loads(const std::string &s);
111  void bvisit(const UnevaluatedExpr &x);
112 };
113 
114 class LLVMDoubleVisitor : public LLVMVisitor
115 {
116 public:
117  double call(const std::vector<double> &vec) const;
118  void call(double *outs, const double *inps) const;
119  llvm::Type *get_float_type(llvm::LLVMContext *) override;
120  void visit(const Tan &x) override;
121  void visit(const ASin &x) override;
122  void visit(const ACos &x) override;
123  void visit(const ATan &x) override;
124  void visit(const ATan2 &x) override;
125  void visit(const Sinh &x) override;
126  void visit(const Cosh &x) override;
127  void visit(const Tanh &x) override;
128  void visit(const ASinh &x) override;
129  void visit(const ACosh &x) override;
130  void visit(const ATanh &x) override;
131  void visit(const Gamma &x) override;
132  void visit(const LogGamma &x) override;
133  void visit(const Erf &x) override;
134  void visit(const Erfc &x) override;
135 };
136 
137 class LLVMFloatVisitor : public LLVMVisitor
138 {
139 public:
140  float call(const std::vector<float> &vec) const;
141  void call(float *outs, const float *inps) const;
142  llvm::Type *get_float_type(llvm::LLVMContext *) override;
143  void visit(const Tan &x) override;
144  void visit(const ASin &x) override;
145  void visit(const ACos &x) override;
146  void visit(const ATan &x) override;
147  void visit(const ATan2 &x) override;
148  void visit(const Sinh &x) override;
149  void visit(const Cosh &x) override;
150  void visit(const Tanh &x) override;
151  void visit(const ASinh &x) override;
152  void visit(const ACosh &x) override;
153  void visit(const ATanh &x) override;
154  void visit(const Gamma &x) override;
155  void visit(const LogGamma &x) override;
156  void visit(const Erf &x) override;
157  void visit(const Erfc &x) override;
158 };
159 
160 #if SYMENGINE_SIZEOF_LONG_DOUBLE > 8 && defined(__x86_64__) || defined(__i386__)
161 #define SYMENGINE_HAVE_LLVM_LONG_DOUBLE 1
162 class LLVMLongDoubleVisitor : public LLVMVisitor
163 {
164 public:
165  long double call(const std::vector<long double> &vec) const;
166  void call(long double *outs, const long double *inps) const;
167  llvm::Type *get_float_type(llvm::LLVMContext *) override;
168  void visit(const Tan &x) override;
169  void visit(const ASin &x) override;
170  void visit(const ACos &x) override;
171  void visit(const ATan &x) override;
172  void visit(const ATan2 &x) override;
173  void visit(const Sinh &x) override;
174  void visit(const Cosh &x) override;
175  void visit(const Tanh &x) override;
176  void visit(const ASinh &x) override;
177  void visit(const ACosh &x) override;
178  void visit(const ATanh &x) override;
179  void visit(const Gamma &x) override;
180  void visit(const LogGamma &x) override;
181  void visit(const Erf &x) override;
182  void visit(const Erfc &x) override;
183  void visit(const Integer &x) override;
184  void visit(const Rational &x) override;
185  void convert_from_mpfr(const Basic &x);
186  void visit(const Constant &x) override;
187 #ifdef HAVE_SYMENGINE_MPFR
188  void visit(const RealMPFR &x) override;
189 #endif
190 };
191 #endif
192 
193 } // namespace SymEngine
194 #endif
195 #endif // SYMENGINE_LAMBDA_DOUBLE_H
The base class for SymEngine.
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:64