1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
11 std::string print_float_literal(
double d)
13 return print_double(d) +
"f";
17 normalize_code_printer_precision(CodePrinterPrecision precision)
20 case CodePrinterPrecision::Double:
21 case CodePrinterPrecision::Float:
24 throw SymEngineException(
"Unknown code printer precision");
30 CodePrinter::CodePrinter(CodePrinterPrecision precision)
31 : precision_{normalize_code_printer_precision(precision)}
35 std::string CodePrinter::print_scalar_literal(
double d)
const
37 if (precision_ == CodePrinterPrecision::Float) {
38 return print_float_literal(d);
40 return print_double(d);
43 std::string CodePrinter::print_math_function(
const std::string &name)
const
45 if (precision_ == CodePrinterPrecision::Float) {
51 void CodePrinter::bvisit(
const Basic &x)
53 throw SymEngineException(
"Not supported");
55 void CodePrinter::bvisit(
const Complex &x)
57 throw NotImplementedError(
"Not implemented");
59 void CodePrinter::bvisit(
const Dummy &x)
62 s << x.get_name() <<
'_' << x.get_index();
65 void CodePrinter::bvisit(
const Interval &x)
67 std::string var = str_;
69 bool is_inf =
eq(*x.get_start(), *NegInf);
72 if (x.get_left_open()) {
77 s << apply(x.get_start());
79 if (
neq(*x.get_end(), *Inf)) {
84 if (x.get_right_open()) {
89 s << apply(x.get_end());
93 void CodePrinter::bvisit(
const Contains &x)
95 x.get_expr()->accept(*
this);
96 x.get_set()->accept(*
this);
98 void CodePrinter::bvisit(
const Piecewise &x)
100 std::ostringstream s;
101 auto vec = x.get_vec();
102 for (
size_t i = 0;; ++i) {
103 if (i == vec.size() - 1) {
104 if (
neq(*vec[i].second, *boolTrue)) {
105 throw SymEngineException(
106 "Code generation requires a (Expr, True) at the end");
108 s <<
"(\n " << apply(vec[i].first) <<
"\n";
112 s << apply(vec[i].second);
114 s << apply(vec[i].first);
118 for (
size_t i = 0; i < vec.size(); i++) {
123 void CodePrinter::bvisit(
const BooleanAtom &x)
125 str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
127 void CodePrinter::bvisit(
const Integer &x)
129 if (precision_ == CodePrinterPrecision::Float) {
130 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
132 StrPrinter::bvisit(x);
135 void CodePrinter::bvisit(
const And &x)
137 std::ostringstream s;
138 const auto &container = x.get_container();
140 for (
auto it = container.begin(); it != container.end(); ++it) {
141 if (it != container.begin()) {
144 s <<
"(" << apply(*(*it)) <<
")";
149 void CodePrinter::bvisit(
const Or &x)
151 std::ostringstream s;
152 const auto &container = x.get_container();
154 for (
auto it = container.begin(); it != container.end(); ++it) {
155 if (it != container.begin()) {
158 s <<
"(" << apply(*(*it)) <<
")";
163 void CodePrinter::bvisit(
const Xor &x)
165 std::ostringstream s;
166 const auto &container = x.get_container();
168 for (
auto it = container.begin(); it != container.end(); ++it) {
169 if (it != container.begin()) {
172 s <<
"((" << apply(*(*it)) <<
") != 0)";
177 void CodePrinter::bvisit(
const Not &x)
179 std::ostringstream s;
180 s <<
"!(" << apply(*x.get_arg()) <<
")";
183 void CodePrinter::bvisit(
const Rational &x)
185 std::ostringstream o;
186 double n = mp_get_d(get_num(x.as_rational_class()));
187 double d = mp_get_d(get_den(x.as_rational_class()));
188 o << print_scalar_literal(n) <<
"/" << print_scalar_literal(d);
191 void CodePrinter::bvisit(
const Reals &x)
193 throw SymEngineException(
"Not supported");
195 void CodePrinter::bvisit(
const Rationals &x)
197 throw SymEngineException(
"Not supported");
199 void CodePrinter::bvisit(
const Integers &x)
201 throw SymEngineException(
"Not supported");
203 void CodePrinter::bvisit(
const EmptySet &x)
205 throw SymEngineException(
"Not supported");
207 void CodePrinter::bvisit(
const FiniteSet &x)
209 throw SymEngineException(
"Not supported");
211 void CodePrinter::bvisit(
const UniversalSet &x)
213 throw SymEngineException(
"Not supported");
215 void CodePrinter::bvisit(
const Abs &x)
217 std::ostringstream s;
218 s << print_math_function(
"fabs") <<
"(" << apply(x.get_arg()) <<
")";
221 void CodePrinter::bvisit(
const Ceiling &x)
223 std::ostringstream s;
224 s << print_math_function(
"ceil") <<
"(" << apply(x.get_arg()) <<
")";
227 void CodePrinter::bvisit(
const Truncate &x)
229 std::ostringstream s;
230 s << print_math_function(
"trunc") <<
"(" << apply(x.get_arg()) <<
")";
233 void CodePrinter::bvisit(
const Max &x)
235 std::ostringstream s;
236 const auto &args = x.get_args();
237 switch (args.size()) {
240 throw SymEngineException(
"Impossible");
242 s << print_math_function(
"fmax") <<
"(" << apply(args[0]) <<
", "
243 << apply(args[1]) <<
")";
246 vec_basic inner_args(args.begin() + 1, args.end());
247 auto inner =
max(inner_args);
248 s << print_math_function(
"fmax") <<
"(" << apply(args[0]) <<
", "
249 << apply(inner) <<
")";
255 void CodePrinter::bvisit(
const Min &x)
257 std::ostringstream s;
258 const auto &args = x.get_args();
259 switch (args.size()) {
262 throw SymEngineException(
"Impossible");
264 s << print_math_function(
"fmin") <<
"(" << apply(args[0]) <<
", "
265 << apply(args[1]) <<
")";
268 vec_basic inner_args(args.begin() + 1, args.end());
269 auto inner =
min(inner_args);
270 s << print_math_function(
"fmin") <<
"(" << apply(args[0]) <<
", "
271 << apply(inner) <<
")";
277 void CodePrinter::bvisit(
const Constant &x)
280 str_ = precision_ == CodePrinterPrecision::Float
281 ? print_math_function(
"exp") +
"("
282 + print_scalar_literal(1.0) +
")"
284 }
else if (
eq(x, *pi)) {
285 str_ = precision_ == CodePrinterPrecision::Float
286 ? print_math_function(
"acos") +
"("
287 + print_scalar_literal(-1.0) +
")"
293 void CodePrinter::bvisit(
const NaN &x)
295 std::ostringstream s;
299 void CodePrinter::bvisit(
const Equality &x)
301 std::ostringstream s;
302 s << apply(x.get_arg1()) <<
" == " << apply(x.get_arg2());
305 void CodePrinter::bvisit(
const Unequality &x)
307 std::ostringstream s;
308 s << apply(x.get_arg1()) <<
" != " << apply(x.get_arg2());
311 void CodePrinter::bvisit(
const LessThan &x)
313 std::ostringstream s;
314 s << apply(x.get_arg1()) <<
" <= " << apply(x.get_arg2());
317 void CodePrinter::bvisit(
const StrictLessThan &x)
319 std::ostringstream s;
320 s << apply(x.get_arg1()) <<
" < " << apply(x.get_arg2());
323 void CodePrinter::bvisit(
const Sign &x)
325 const std::string arg = apply(x.get_arg());
326 const std::string zero = print_scalar_literal(0.0);
327 const std::string one = print_scalar_literal(1.0);
328 const std::string minus_one = print_scalar_literal(-1.0);
329 std::ostringstream s;
330 s <<
"((" << arg <<
" == " << zero <<
") ? (" << zero <<
") : ((" << arg
331 <<
" < " << zero <<
") ? (" << minus_one <<
") : (" << one <<
")))";
334 void CodePrinter::bvisit(
const UnevaluatedExpr &x)
336 str_ = apply(x.get_arg());
338 void CodePrinter::bvisit(
const UnivariateSeries &x)
340 throw SymEngineException(
"Not supported");
342 void CodePrinter::bvisit(
const Derivative &x)
344 throw SymEngineException(
"Not supported");
346 void CodePrinter::bvisit(
const Subs &x)
348 throw SymEngineException(
"Not supported");
350 void CodePrinter::bvisit(
const GaloisField &x)
352 throw SymEngineException(
"Not supported");
355 void CodePrinter::bvisit(
const Function &x)
357 static const std::vector<std::string> names_ = init_str_printer_names();
358 std::ostringstream o;
359 o << print_math_function(names_[x.get_type_code()]);
360 vec_basic vec = x.get_args();
361 o << parenthesize(apply(vec));
365 void CodePrinter::bvisit(
const RealDouble &x)
367 if (precision_ == CodePrinterPrecision::Float) {
368 str_ = print_scalar_literal(x.i);
370 StrPrinter::bvisit(x);
374 #ifdef HAVE_SYMENGINE_MPFR
375 void CodePrinter::bvisit(
const RealMPFR &x)
377 StrPrinter::bvisit(x);
378 if (precision_ == CodePrinterPrecision::Float) {
384 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
385 : BaseVisitor<C89CodePrinter, CodePrinter>(precision)
389 void C89CodePrinter::bvisit(
const Infty &x)
391 std::ostringstream s;
392 if (x.is_negative_infinity())
394 else if (x.is_positive_infinity())
397 throw SymEngineException(
"Not supported");
400 void C89CodePrinter::_print_pow(std::ostringstream &o,
401 const RCP<const Basic> &a,
402 const RCP<const Basic> &b)
405 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
407 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
409 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
414 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
415 : BaseVisitor<C99CodePrinter, C89CodePrinter>(precision)
419 void C99CodePrinter::bvisit(
const Infty &x)
421 std::ostringstream s;
422 if (x.is_negative_infinity())
424 else if (x.is_positive_infinity())
427 throw SymEngineException(
"Not supported");
430 void C99CodePrinter::_print_pow(std::ostringstream &o,
431 const RCP<const Basic> &a,
432 const RCP<const Basic> &b)
435 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
437 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
439 o << print_math_function(
"cbrt") <<
"(" << apply(a) <<
")";
441 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
445 void C99CodePrinter::bvisit(
const Gamma &x)
447 std::ostringstream s;
448 s << print_math_function(
"tgamma") <<
"(" << apply(x.get_arg()) <<
")";
451 void C99CodePrinter::bvisit(
const LogGamma &x)
453 std::ostringstream s;
454 s << print_math_function(
"lgamma") <<
"(" << apply(x.get_arg()) <<
")";
458 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
459 : BaseVisitor<CudaCodePrinter, C99CodePrinter>(precision)
463 void CudaCodePrinter::bvisit(
const Integer &x)
465 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
467 void CudaCodePrinter::bvisit(
const Constant &x)
470 str_ = print_math_function(
"exp") +
"(" + print_scalar_literal(1.0)
472 }
else if (
eq(x, *pi)) {
473 str_ = print_math_function(
"acos") +
"(" + print_scalar_literal(-1.0)
480 void CudaCodePrinter::bvisit(
const NaN &x)
482 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_NAN_F"
486 void CudaCodePrinter::bvisit(
const Infty &x)
488 if (x.is_negative_infinity())
489 str_ = precision_ == CodePrinterPrecision::Float ?
"-CUDART_INF_F"
491 else if (x.is_positive_infinity())
492 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_INF_F"
495 throw SymEngineException(
"Not supported");
498 void JSCodePrinter::bvisit(
const Constant &x)
502 }
else if (
eq(x, *pi)) {
508 void JSCodePrinter::_print_pow(std::ostringstream &o,
const RCP<const Basic> &a,
509 const RCP<const Basic> &b)
512 o <<
"Math.exp(" << apply(b) <<
")";
514 o <<
"Math.sqrt(" << apply(a) <<
")";
516 o <<
"Math.cbrt(" << apply(a) <<
")";
518 o <<
"Math.pow(" << apply(a) <<
", " << apply(b) <<
")";
521 void JSCodePrinter::bvisit(
const Abs &x)
523 std::ostringstream s;
524 s <<
"Math.abs(" << apply(x.get_arg()) <<
")";
527 void JSCodePrinter::bvisit(
const Sin &x)
529 std::ostringstream s;
530 s <<
"Math.sin(" << apply(x.get_arg()) <<
")";
533 void JSCodePrinter::bvisit(
const Cos &x)
535 std::ostringstream s;
536 s <<
"Math.cos(" << apply(x.get_arg()) <<
")";
539 void JSCodePrinter::bvisit(
const Max &x)
541 const auto &args = x.get_args();
542 std::ostringstream s;
544 for (
size_t i = 0; i < args.size(); ++i) {
546 s << ((i == args.size() - 1) ?
")" :
", ");
550 void JSCodePrinter::bvisit(
const Min &x)
552 const auto &args = x.get_args();
553 std::ostringstream s;
555 for (
size_t i = 0; i < args.size(); ++i) {
557 s << ((i == args.size() - 1) ?
")" :
", ");
562 std::string ccode(
const Basic &x, CodePrinterPrecision precision)
564 C99CodePrinter c(precision);
568 std::string cudacode(
const Basic &x, CodePrinterPrecision precision)
570 CudaCodePrinter p(precision);
574 std::string jscode(
const Basic &x)
580 std::string
inline c89code(
const Basic &x)
586 std::string
inline c99code(
const Basic &x)
Main namespace for SymEngine package.
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
RCP< const Number > rational(long n, long d)
convenience creator from two longs