Loading...
Searching...
No Matches
codegen.cpp
1#include <symengine/printers/codegen.h>
2#include <symengine/printers.h>
3#include <symengine/symengine_exception.h>
4
5namespace SymEngine
6{
7
8void CodePrinter::bvisit(const Basic &x)
9{
10 throw SymEngineException("Not supported");
11}
12void CodePrinter::bvisit(const Complex &x)
13{
14 throw NotImplementedError("Not implemented");
15}
16void CodePrinter::bvisit(const Interval &x)
17{
18 std::string var = str_;
20 bool is_inf = eq(*x.get_start(), *NegInf);
21 if (not is_inf) {
22 s << var;
23 if (x.get_left_open()) {
24 s << " > ";
25 } else {
26 s << " >= ";
27 }
28 s << apply(x.get_start());
29 }
30 if (neq(*x.get_end(), *Inf)) {
31 if (not is_inf) {
32 s << " && ";
33 }
34 s << var;
35 if (x.get_right_open()) {
36 s << " < ";
37 } else {
38 s << " <= ";
39 }
40 s << apply(x.get_end());
41 }
42 str_ = s.str();
43}
44void CodePrinter::bvisit(const Contains &x)
45{
46 x.get_expr()->accept(*this);
47 x.get_set()->accept(*this);
48}
49void CodePrinter::bvisit(const Piecewise &x)
50{
52 auto vec = x.get_vec();
53 for (size_t i = 0;; ++i) {
54 if (i == vec.size() - 1) {
55 if (neq(*vec[i].second, *boolTrue)) {
56 throw SymEngineException(
57 "Code generation requires a (Expr, True) at the end");
58 }
59 s << "(\n " << apply(vec[i].first) << "\n";
60 break;
61 } else {
62 s << "((";
63 s << apply(vec[i].second);
64 s << ") ? (\n ";
65 s << apply(vec[i].first);
66 s << "\n)\n: ";
67 }
68 }
69 for (size_t i = 0; i < vec.size(); i++) {
70 s << ")";
71 }
72 str_ = s.str();
73}
74void CodePrinter::bvisit(const Rational &x)
75{
77 double n = mp_get_d(get_num(x.as_rational_class()));
78 double d = mp_get_d(get_den(x.as_rational_class()));
79 o << print_double(n) << "/" << print_double(d);
80 str_ = o.str();
81}
82void CodePrinter::bvisit(const Reals &x)
83{
84 throw SymEngineException("Not supported");
85}
86void CodePrinter::bvisit(const Rationals &x)
87{
88 throw SymEngineException("Not supported");
89}
90void CodePrinter::bvisit(const Integers &x)
91{
92 throw SymEngineException("Not supported");
93}
94void CodePrinter::bvisit(const EmptySet &x)
95{
96 throw SymEngineException("Not supported");
97}
98void CodePrinter::bvisit(const FiniteSet &x)
99{
100 throw SymEngineException("Not supported");
101}
102void CodePrinter::bvisit(const UniversalSet &x)
103{
104 throw SymEngineException("Not supported");
105}
106void CodePrinter::bvisit(const Abs &x)
107{
109 s << "fabs(" << apply(x.get_arg()) << ")";
110 str_ = s.str();
111}
112void CodePrinter::bvisit(const Ceiling &x)
113{
115 s << "ceil(" << apply(x.get_arg()) << ")";
116 str_ = s.str();
117}
118void CodePrinter::bvisit(const Truncate &x)
119{
121 s << "trunc(" << apply(x.get_arg()) << ")";
122 str_ = s.str();
123}
124void CodePrinter::bvisit(const Max &x)
125{
127 const auto &args = x.get_args();
128 switch (args.size()) {
129 case 0:
130 case 1:
131 throw SymEngineException("Impossible");
132 case 2:
133 s << "fmax(" << apply(args[0]) << ", " << apply(args[1]) << ")";
134 break;
135 default: {
136 vec_basic inner_args(args.begin() + 1, args.end());
137 auto inner = max(inner_args);
138 s << "fmax(" << apply(args[0]) << ", " << apply(inner) << ")";
139 break;
140 }
141 }
142 str_ = s.str();
143}
144void CodePrinter::bvisit(const Min &x)
145{
147 const auto &args = x.get_args();
148 switch (args.size()) {
149 case 0:
150 case 1:
151 throw SymEngineException("Impossible");
152 case 2:
153 s << "fmin(" << apply(args[0]) << ", " << apply(args[1]) << ")";
154 break;
155 default: {
156 vec_basic inner_args(args.begin() + 1, args.end());
157 auto inner = min(inner_args);
158 s << "fmin(" << apply(args[0]) << ", " << apply(inner) << ")";
159 break;
160 }
161 }
162 str_ = s.str();
163}
164void CodePrinter::bvisit(const Constant &x)
165{
166 if (eq(x, *E)) {
167 str_ = "exp(1)";
168 } else if (eq(x, *pi)) {
169 str_ = "acos(-1)";
170 } else {
171 str_ = x.get_name();
172 }
173}
174void CodePrinter::bvisit(const NaN &x)
175{
177 s << "NAN";
178 str_ = s.str();
179}
180void CodePrinter::bvisit(const Equality &x)
181{
183 s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
184 str_ = s.str();
185}
186void CodePrinter::bvisit(const Unequality &x)
187{
189 s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
190 str_ = s.str();
191}
192void CodePrinter::bvisit(const LessThan &x)
193{
195 s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
196 str_ = s.str();
197}
198void CodePrinter::bvisit(const StrictLessThan &x)
199{
201 s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
202 str_ = s.str();
203}
204void CodePrinter::bvisit(const UnivariateSeries &x)
205{
206 throw SymEngineException("Not supported");
207}
208void CodePrinter::bvisit(const Derivative &x)
209{
210 throw SymEngineException("Not supported");
211}
212void CodePrinter::bvisit(const Subs &x)
213{
214 throw SymEngineException("Not supported");
215}
216void CodePrinter::bvisit(const GaloisField &x)
217{
218 throw SymEngineException("Not supported");
219}
220
221void C89CodePrinter::bvisit(const Infty &x)
222{
224 if (x.is_negative_infinity())
225 s << "-HUGE_VAL";
226 else if (x.is_positive_infinity())
227 s << "HUGE_VAL";
228 else
229 throw SymEngineException("Not supported");
230 str_ = s.str();
231}
232void C89CodePrinter::_print_pow(std::ostringstream &o,
233 const RCP<const Basic> &a,
234 const RCP<const Basic> &b)
235{
236 if (eq(*a, *E)) {
237 o << "exp(" << apply(b) << ")";
238 } else if (eq(*b, *rational(1, 2))) {
239 o << "sqrt(" << apply(a) << ")";
240 } else {
241 o << "pow(" << apply(a) << ", " << apply(b) << ")";
242 }
243}
244
245void C99CodePrinter::bvisit(const Infty &x)
246{
248 if (x.is_negative_infinity())
249 s << "-INFINITY";
250 else if (x.is_positive_infinity())
251 s << "INFINITY";
252 else
253 throw SymEngineException("Not supported");
254 str_ = s.str();
255}
256void C99CodePrinter::_print_pow(std::ostringstream &o,
257 const RCP<const Basic> &a,
258 const RCP<const Basic> &b)
259{
260 if (eq(*a, *E)) {
261 o << "exp(" << apply(b) << ")";
262 } else if (eq(*b, *rational(1, 2))) {
263 o << "sqrt(" << apply(a) << ")";
264 } else if (eq(*b, *rational(1, 3))) {
265 o << "cbrt(" << apply(a) << ")";
266 } else {
267 o << "pow(" << apply(a) << ", " << apply(b) << ")";
268 }
269}
270void C99CodePrinter::bvisit(const Gamma &x)
271{
273 s << "tgamma(" << apply(x.get_arg()) << ")";
274 str_ = s.str();
275}
276void C99CodePrinter::bvisit(const LogGamma &x)
277{
279 s << "lgamma(" << apply(x.get_arg()) << ")";
280 str_ = s.str();
281}
282
283void JSCodePrinter::bvisit(const Constant &x)
284{
285 if (eq(x, *E)) {
286 str_ = "Math.E";
287 } else if (eq(x, *pi)) {
288 str_ = "Math.PI";
289 } else {
290 str_ = x.get_name();
291 }
292}
293void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
294 const RCP<const Basic> &b)
295{
296 if (eq(*a, *E)) {
297 o << "Math.exp(" << apply(b) << ")";
298 } else if (eq(*b, *rational(1, 2))) {
299 o << "Math.sqrt(" << apply(a) << ")";
300 } else if (eq(*b, *rational(1, 3))) {
301 o << "Math.cbrt(" << apply(a) << ")";
302 } else {
303 o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
304 }
305}
306void JSCodePrinter::bvisit(const Abs &x)
307{
309 s << "Math.abs(" << apply(x.get_arg()) << ")";
310 str_ = s.str();
311}
312void JSCodePrinter::bvisit(const Sin &x)
313{
315 s << "Math.sin(" << apply(x.get_arg()) << ")";
316 str_ = s.str();
317}
318void JSCodePrinter::bvisit(const Cos &x)
319{
321 s << "Math.cos(" << apply(x.get_arg()) << ")";
322 str_ = s.str();
323}
324void JSCodePrinter::bvisit(const Max &x)
325{
326 const auto &args = x.get_args();
328 s << "Math.max(";
329 for (size_t i = 0; i < args.size(); ++i) {
330 s << apply(args[i]);
331 s << ((i == args.size() - 1) ? ")" : ", ");
332 }
333 str_ = s.str();
334}
335void JSCodePrinter::bvisit(const Min &x)
336{
337 const auto &args = x.get_args();
339 s << "Math.min(";
340 for (size_t i = 0; i < args.size(); ++i) {
341 s << apply(args[i]);
342 s << ((i == args.size() - 1) ? ")" : ", ");
343 }
344 str_ = s.str();
345}
346
347std::string ccode(const Basic &x)
348{
349 C99CodePrinter c;
350 return c.apply(x);
351}
352
353std::string jscode(const Basic &x)
354{
355 JSCodePrinter p;
356 return p.apply(x);
357}
358
359std::string inline c89code(const Basic &x)
360{
361 C89CodePrinter p;
362 return p.apply(x);
363}
364
365std::string inline c99code(const Basic &x)
366{
367 C99CodePrinter p;
368 return p.apply(x);
369}
370
371} // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
Definition: functions.cpp:3555
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
Definition: functions.cpp:3659
T str(T... args)