Loading...
Searching...
No Matches
llvm_double.h
1#ifndef SYMENGINE_LLVM_DOUBLE_H
2#define SYMENGINE_LLVM_DOUBLE_H
3
4#include <symengine/basic.h>
5#include <symengine/visitor.h>
6#include <float.h>
7
8#ifdef HAVE_SYMENGINE_LLVM
9
10// Forward declare llvm types
11namespace llvm
12{
13class Module;
14class Value;
15class Type;
16class Function;
17class ExecutionEngine;
18class MemoryBufferRef;
19class LLVMContext;
20class Pass;
21namespace legacy
22{
23class FunctionPassManager;
24}
25} // namespace llvm
26
27namespace SymEngine
28{
29
30class IRBuilder;
31
32class LLVMVisitor : public BaseVisitor<LLVMVisitor>
33{
34protected:
35 vec_basic symbols;
37 std::map<RCP<const Basic>, llvm::Value *, RCPBasicKeyLess>
38 replacement_symbol_ptrs;
39 llvm::Value *result_;
43 intptr_t func;
44
45 // Following are invalid after the init call.
46 IRBuilder *builder;
47 llvm::Module *mod;
48 std::string membuffer;
49 llvm::Function *get_function_type(llvm::LLVMContext *);
50 virtual llvm::Type *get_float_type(llvm::LLVMContext *) = 0;
51
52public:
53 llvm::Value *apply(const Basic &b);
54 void init(const vec_basic &x, const Basic &b,
55 const bool symbolic_cse = false, unsigned opt_level = 3);
56 void init(const vec_basic &x, const Basic &b, const bool symbolic_cse,
57 const std::vector<llvm::Pass *> &passes, unsigned opt_level = 3);
58 void init(const vec_basic &inputs, const vec_basic &outputs,
59 const bool symbolic_cse = false, unsigned opt_level = 3);
60 void init(const vec_basic &inputs, const vec_basic &outputs,
61 const bool symbolic_cse, const std::vector<llvm::Pass *> &passes,
62 unsigned opt_level = 3);
63
64 static std::vector<llvm::Pass *> create_default_passes(int optlevel);
65
66 // Helper functions
67 void set_double(double d);
68 llvm::Function *get_external_function(const std::string &name,
69 size_t nargs = 1);
70 llvm::Function *get_powi();
71
72 void bvisit(const Integer &x);
73 void bvisit(const Rational &x);
74 void bvisit(const RealDouble &x);
75#ifdef HAVE_SYMENGINE_MPFR
76 void bvisit(const RealMPFR &x);
77#endif
78 void bvisit(const Add &x);
79 void bvisit(const Mul &x);
80 void bvisit(const Pow &x);
81 void bvisit(const Log &x);
82 void bvisit(const Abs &x);
83 void bvisit(const Symbol &x);
84 void bvisit(const Constant &x);
85 void bvisit(const Basic &);
86 void bvisit(const Sin &x);
87 void bvisit(const Cos &x);
88 void bvisit(const Piecewise &x);
89 void bvisit(const BooleanAtom &x);
90 void bvisit(const And &x);
91 void bvisit(const Or &x);
92 void bvisit(const Xor &x);
93 void bvisit(const Not &x);
94 void bvisit(const Equality &x);
95 void bvisit(const Unequality &x);
96 void bvisit(const LessThan &x);
97 void bvisit(const StrictLessThan &x);
98 void bvisit(const Max &x);
99 void bvisit(const Min &x);
100 void bvisit(const Contains &x);
101 void bvisit(const Infty &x);
102 void bvisit(const NaN &x);
103 void bvisit(const Floor &x);
104 void bvisit(const Ceiling &x);
105 void bvisit(const Truncate &x);
106 void bvisit(const Sign &x);
107 // Return the compiled function as a binary string which can be loaded using
108 // `load`
109 const std::string &dumps() const;
110 // Load a previously compiled function from a string
111 void loads(const std::string &s);
112 void bvisit(const UnevaluatedExpr &x);
113};
114
115class LLVMDoubleVisitor : public LLVMVisitor
116{
117public:
118 double call(const std::vector<double> &vec) const;
119 void call(double *outs, const double *inps) const;
120 llvm::Type *get_float_type(llvm::LLVMContext *) override;
121 void visit(const Tan &x) override;
122 void visit(const ASin &x) override;
123 void visit(const ACos &x) override;
124 void visit(const ATan &x) override;
125 void visit(const ATan2 &x) override;
126 void visit(const Sinh &x) override;
127 void visit(const Cosh &x) override;
128 void visit(const Tanh &x) override;
129 void visit(const ASinh &x) override;
130 void visit(const ACosh &x) override;
131 void visit(const ATanh &x) override;
132 void visit(const Gamma &x) override;
133 void visit(const LogGamma &x) override;
134 void visit(const Erf &x) override;
135 void visit(const Erfc &x) override;
136};
137
138class LLVMFloatVisitor : public LLVMVisitor
139{
140public:
141 float call(const std::vector<float> &vec) const;
142 void call(float *outs, const float *inps) const;
143 llvm::Type *get_float_type(llvm::LLVMContext *) override;
144 void visit(const Tan &x) override;
145 void visit(const ASin &x) override;
146 void visit(const ACos &x) override;
147 void visit(const ATan &x) override;
148 void visit(const ATan2 &x) override;
149 void visit(const Sinh &x) override;
150 void visit(const Cosh &x) override;
151 void visit(const Tanh &x) override;
152 void visit(const ASinh &x) override;
153 void visit(const ACosh &x) override;
154 void visit(const ATanh &x) override;
155 void visit(const Gamma &x) override;
156 void visit(const LogGamma &x) override;
157 void visit(const Erf &x) override;
158 void visit(const Erfc &x) override;
159};
160
161#if SYMENGINE_SIZEOF_LONG_DOUBLE > 8 && defined(__x86_64__) || defined(__i386__)
162#define SYMENGINE_HAVE_LLVM_LONG_DOUBLE 1
163class LLVMLongDoubleVisitor : public LLVMVisitor
164{
165public:
166 long double call(const std::vector<long double> &vec) const;
167 void call(long double *outs, const long double *inps) const;
168 llvm::Type *get_float_type(llvm::LLVMContext *) override;
169 void visit(const Tan &x) override;
170 void visit(const ASin &x) override;
171 void visit(const ACos &x) override;
172 void visit(const ATan &x) override;
173 void visit(const ATan2 &x) override;
174 void visit(const Sinh &x) override;
175 void visit(const Cosh &x) override;
176 void visit(const Tanh &x) override;
177 void visit(const ASinh &x) override;
178 void visit(const ACosh &x) override;
179 void visit(const ATanh &x) override;
180 void visit(const Gamma &x) override;
181 void visit(const LogGamma &x) override;
182 void visit(const Erf &x) override;
183 void visit(const Erfc &x) override;
184 void visit(const Integer &x) override;
185 void visit(const Rational &x) override;
186 void convert_from_mpfr(const Basic &x);
187 void visit(const Constant &x) override;
188#ifdef HAVE_SYMENGINE_MPFR
189 void visit(const RealMPFR &x) override;
190#endif
191};
192#endif
193
194} // namespace SymEngine
195#endif
196#endif // SYMENGINE_LAMBDA_DOUBLE_H
The base class for SymEngine.
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Integer > mod(const Integer &n, const Integer &d)
modulo round toward zero
Definition: ntheory.cpp:66