llvm_double.h
1 #ifndef SYMENGINE_LLVM_DOUBLE_H
2 #define SYMENGINE_LLVM_DOUBLE_H
3 
4 #include <symengine/basic.h>
5 #include <symengine/visitor.h>
6 #include <float.h>
7 
8 #ifdef HAVE_SYMENGINE_LLVM
9 
10 // Forward declare llvm types
11 namespace llvm
12 {
13 class Module;
14 class Value;
15 class Type;
16 class Function;
17 class ExecutionEngine;
18 class MemoryBufferRef;
19 class LLVMContext;
20 } // namespace llvm
21 
22 namespace SymEngine
23 {
24 
25 class IRBuilder;
26 
27 class LLVMVisitor : public BaseVisitor<LLVMVisitor>
28 {
29 protected:
30  vec_basic symbols;
31  std::vector<llvm::Value *> symbol_ptrs;
32  std::map<RCP<const Basic>, llvm::Value *, RCPBasicKeyLess>
33  replacement_symbol_ptrs;
34  llvm::Value *result_;
37 
38  intptr_t func;
39 
40  // Following are invalid after the init call.
41  IRBuilder *builder;
42  llvm::Module *mod;
43  std::string membuffer;
44  llvm::Function *get_function_type(llvm::LLVMContext *);
45  virtual llvm::Type *get_float_type(llvm::LLVMContext *) = 0;
46 
47 public:
48  LLVMVisitor();
49  ~LLVMVisitor() override;
50  llvm::Value *apply(const Basic &b);
51  void init(const vec_basic &x, const Basic &b,
52  const bool symbolic_cse = false, unsigned opt_level = 3);
53  void init(const vec_basic &inputs, const vec_basic &outputs,
54  const bool symbolic_cse = false, unsigned opt_level = 3);
55 
56  // Helper functions
57  void set_double(double d);
58  llvm::Function *get_external_function(const std::string &name,
59  size_t nargs = 1);
60  llvm::Function *get_powi();
61 
62  void bvisit(const Integer &x);
63  void bvisit(const Rational &x);
64  void bvisit(const RealDouble &x);
65 #ifdef HAVE_SYMENGINE_MPFR
66  void bvisit(const RealMPFR &x);
67 #endif
68  void bvisit(const Add &x);
69  void bvisit(const Mul &x);
70  void bvisit(const Pow &x);
71  void bvisit(const Log &x);
72  void bvisit(const Abs &x);
73  void bvisit(const Symbol &x);
74  void bvisit(const Constant &x);
75  void bvisit(const Basic &);
76  void bvisit(const Sin &x);
77  void bvisit(const Cos &x);
78  void bvisit(const Piecewise &x);
79  void bvisit(const BooleanAtom &x);
80  void bvisit(const And &x);
81  void bvisit(const Or &x);
82  void bvisit(const Xor &x);
83  void bvisit(const Not &x);
84  void bvisit(const Equality &x);
85  void bvisit(const Unequality &x);
86  void bvisit(const LessThan &x);
87  void bvisit(const StrictLessThan &x);
88  void bvisit(const Max &x);
89  void bvisit(const Min &x);
90  void bvisit(const Contains &x);
91  void bvisit(const Infty &x);
92  void bvisit(const NaN &x);
93  void bvisit(const Floor &x);
94  void bvisit(const Ceiling &x);
95  void bvisit(const Truncate &x);
96  void bvisit(const Sign &x);
97  // Return the compiled function as a binary string which can be loaded using
98  // `load`
99  const std::string &dumps() const;
100  // Load a previously compiled function from a string
101  void loads(const std::string &s);
102  void bvisit(const UnevaluatedExpr &x);
103 };
104 
105 class LLVMDoubleVisitor : public LLVMVisitor
106 {
107 public:
108  LLVMDoubleVisitor();
109  ~LLVMDoubleVisitor() override;
110  double call(const std::vector<double> &vec) const;
111  void call(double *outs, const double *inps) const;
112  llvm::Type *get_float_type(llvm::LLVMContext *) override;
113  void visit(const Tan &x) override;
114  void visit(const ASin &x) override;
115  void visit(const ACos &x) override;
116  void visit(const ATan &x) override;
117  void visit(const ATan2 &x) override;
118  void visit(const Sinh &x) override;
119  void visit(const Cosh &x) override;
120  void visit(const Tanh &x) override;
121  void visit(const ASinh &x) override;
122  void visit(const ACosh &x) override;
123  void visit(const ATanh &x) override;
124  void visit(const Gamma &x) override;
125  void visit(const LogGamma &x) override;
126  void visit(const Erf &x) override;
127  void visit(const Erfc &x) override;
128 };
129 
130 class LLVMFloatVisitor : public LLVMVisitor
131 {
132 public:
133  LLVMFloatVisitor();
134  ~LLVMFloatVisitor() override;
135  float call(const std::vector<float> &vec) const;
136  void call(float *outs, const float *inps) const;
137  llvm::Type *get_float_type(llvm::LLVMContext *) override;
138  void visit(const Tan &x) override;
139  void visit(const ASin &x) override;
140  void visit(const ACos &x) override;
141  void visit(const ATan &x) override;
142  void visit(const ATan2 &x) override;
143  void visit(const Sinh &x) override;
144  void visit(const Cosh &x) override;
145  void visit(const Tanh &x) override;
146  void visit(const ASinh &x) override;
147  void visit(const ACosh &x) override;
148  void visit(const ATanh &x) override;
149  void visit(const Gamma &x) override;
150  void visit(const LogGamma &x) override;
151  void visit(const Erf &x) override;
152  void visit(const Erfc &x) override;
153 };
154 
155 #if SYMENGINE_SIZEOF_LONG_DOUBLE > 8 && defined(__x86_64__) || defined(__i386__)
156 #define SYMENGINE_HAVE_LLVM_LONG_DOUBLE 1
157 class LLVMLongDoubleVisitor : public LLVMVisitor
158 {
159 public:
160  LLVMLongDoubleVisitor();
161  ~LLVMLongDoubleVisitor() override;
162  long double call(const std::vector<long double> &vec) const;
163  void call(long double *outs, const long double *inps) const;
164  llvm::Type *get_float_type(llvm::LLVMContext *) override;
165  void visit(const Tan &x) override;
166  void visit(const ASin &x) override;
167  void visit(const ACos &x) override;
168  void visit(const ATan &x) override;
169  void visit(const ATan2 &x) override;
170  void visit(const Sinh &x) override;
171  void visit(const Cosh &x) override;
172  void visit(const Tanh &x) override;
173  void visit(const ASinh &x) override;
174  void visit(const ACosh &x) override;
175  void visit(const ATanh &x) override;
176  void visit(const Gamma &x) override;
177  void visit(const LogGamma &x) override;
178  void visit(const Erf &x) override;
179  void visit(const Erfc &x) override;
180  void visit(const Integer &x) override;
181  void visit(const Rational &x) override;
182  void convert_from_mpfr(const Basic &x);
183  void visit(const Constant &x) override;
184 #ifdef HAVE_SYMENGINE_MPFR
185  void visit(const RealMPFR &x) override;
186 #endif
187 };
188 #endif
189 
190 } // namespace SymEngine
191 #endif
192 #endif // SYMENGINE_LAMBDA_DOUBLE_H
The base class for SymEngine.
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:67