1 #include <symengine/printers/codegen.h>
5 #include <symengine/printers.h>
6 #include <symengine/symengine_exception.h>
14 const char *print_precision_suffix(CodePrinterPrecision precision)
17 case CodePrinterPrecision::Double:
19 case CodePrinterPrecision::Float:
21 case CodePrinterPrecision::Half:
24 throw SymEngineException(
"Unknown code printer precision");
29 normalize_code_printer_precision(CodePrinterPrecision precision)
32 case CodePrinterPrecision::Double:
33 case CodePrinterPrecision::Float:
34 case CodePrinterPrecision::Half:
37 throw SymEngineException(
"Unknown code printer precision");
43 CodePrinter::CodePrinter(CodePrinterPrecision precision)
44 : precision_{normalize_code_printer_precision(precision)}
48 std::string CodePrinter::print_scalar_literal(
double d)
const
50 return print_double(d) + print_precision_suffix(precision_);
53 std::string CodePrinter::print_math_function(
const std::string &name)
const
55 if (precision_ == CodePrinterPrecision::Float) {
62 CodePrinter::format_codegen_function_name(
const std::string &name)
const
64 return print_math_function(name);
67 std::string CodePrinter::print_binary_reduction(
const vec_basic &args,
68 const std::string &func_name)
70 if (args.size() < 2) {
71 throw SymEngineException(
"Impossible");
74 return print_binary_reduction_impl(args.begin(), args.end(), func_name);
78 CodePrinter::print_binary_reduction_impl(vec_basic::const_iterator begin,
79 vec_basic::const_iterator end,
80 const std::string &func_name)
82 const auto size =
static_cast<std::size_t
>(std::distance(begin, end));
87 const auto mid = begin + size / 2;
89 s << func_name <<
"(" << print_binary_reduction_impl(begin, mid, func_name)
90 <<
", " << print_binary_reduction_impl(mid, end, func_name) <<
")";
94 void CodePrinter::bvisit(
const Basic &x)
96 throw SymEngineException(
"Not supported");
98 void CodePrinter::bvisit(
const Complex &x)
100 throw NotImplementedError(
"Not implemented");
102 void CodePrinter::bvisit(
const Dummy &x)
104 std::ostringstream s;
105 s << x.get_name() <<
'_' << x.get_index();
108 void CodePrinter::bvisit(
const Interval &x)
110 std::string var = str_;
111 std::ostringstream s;
112 bool is_inf =
eq(*x.get_start(), *NegInf);
115 if (x.get_left_open()) {
120 s << apply(x.get_start());
122 if (
neq(*x.get_end(), *Inf)) {
127 if (x.get_right_open()) {
132 s << apply(x.get_end());
136 void CodePrinter::bvisit(
const Contains &x)
138 x.get_expr()->accept(*
this);
139 x.get_set()->accept(*
this);
141 void CodePrinter::bvisit(
const Piecewise &x)
143 std::ostringstream s;
144 auto vec = x.get_vec();
145 for (
size_t i = 0;; ++i) {
146 if (i == vec.size() - 1) {
147 if (
neq(*vec[i].second, *boolTrue)) {
148 throw SymEngineException(
149 "Code generation requires a (Expr, True) at the end");
151 s <<
"(\n " << apply(vec[i].first) <<
"\n";
155 s << apply(vec[i].second);
157 s << apply(vec[i].first);
161 for (
size_t i = 0; i < vec.size(); i++) {
166 void CodePrinter::bvisit(
const BooleanAtom &x)
168 str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
170 void CodePrinter::bvisit(
const Integer &x)
172 if (precision_ != CodePrinterPrecision::Double) {
173 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
175 StrPrinter::bvisit(x);
178 void CodePrinter::bvisit(
const And &x)
180 std::ostringstream s;
181 const auto &container = x.get_container();
183 for (
auto it = container.begin(); it != container.end(); ++it) {
184 if (it != container.begin()) {
187 s <<
"(" << apply(*(*it)) <<
")";
192 void CodePrinter::bvisit(
const Or &x)
194 std::ostringstream s;
195 const auto &container = x.get_container();
197 for (
auto it = container.begin(); it != container.end(); ++it) {
198 if (it != container.begin()) {
201 s <<
"(" << apply(*(*it)) <<
")";
206 void CodePrinter::bvisit(
const Xor &x)
208 std::ostringstream s;
209 const auto &container = x.get_container();
211 for (
auto it = container.begin(); it != container.end(); ++it) {
212 if (it != container.begin()) {
215 s <<
"((" << apply(*(*it)) <<
") != 0)";
220 void CodePrinter::bvisit(
const Not &x)
222 std::ostringstream s;
223 s <<
"!(" << apply(*x.get_arg()) <<
")";
226 void CodePrinter::bvisit(
const Rational &x)
228 std::ostringstream o;
229 double n = mp_get_d(get_num(x.as_rational_class()));
230 double d = mp_get_d(get_den(x.as_rational_class()));
231 o << print_scalar_literal(n) <<
"/" << print_scalar_literal(d);
234 void CodePrinter::bvisit(
const Reals &x)
236 throw SymEngineException(
"Not supported");
238 void CodePrinter::bvisit(
const Rationals &x)
240 throw SymEngineException(
"Not supported");
242 void CodePrinter::bvisit(
const Integers &x)
244 throw SymEngineException(
"Not supported");
246 void CodePrinter::bvisit(
const EmptySet &x)
248 throw SymEngineException(
"Not supported");
250 void CodePrinter::bvisit(
const FiniteSet &x)
252 throw SymEngineException(
"Not supported");
254 void CodePrinter::bvisit(
const UniversalSet &x)
256 throw SymEngineException(
"Not supported");
258 void CodePrinter::bvisit(
const Abs &x)
260 std::ostringstream s;
261 s << print_math_function(
"fabs") <<
"(" << apply(x.get_arg()) <<
")";
264 void CodePrinter::bvisit(
const Ceiling &x)
266 std::ostringstream s;
267 s << print_math_function(
"ceil") <<
"(" << apply(x.get_arg()) <<
")";
270 void CodePrinter::bvisit(
const Truncate &x)
272 std::ostringstream s;
273 s << print_math_function(
"trunc") <<
"(" << apply(x.get_arg()) <<
")";
276 void CodePrinter::bvisit(
const Max &x)
278 str_ = print_binary_reduction(x.get_args(), print_math_function(
"fmax"));
280 void CodePrinter::bvisit(
const Min &x)
282 str_ = print_binary_reduction(x.get_args(), print_math_function(
"fmin"));
284 void CodePrinter::bvisit(
const Constant &x)
287 str_ = precision_ != CodePrinterPrecision::Double
288 ? print_math_function(
"exp") +
"("
289 + print_scalar_literal(1.0) +
")"
291 }
else if (
eq(x, *pi)) {
292 str_ = precision_ != CodePrinterPrecision::Double
293 ? print_math_function(
"acos") +
"("
294 + print_scalar_literal(-1.0) +
")"
300 void CodePrinter::bvisit(
const NaN &x)
302 std::ostringstream s;
306 void CodePrinter::bvisit(
const Equality &x)
308 std::ostringstream s;
309 s << apply(x.get_arg1()) <<
" == " << apply(x.get_arg2());
312 void CodePrinter::bvisit(
const Unequality &x)
314 std::ostringstream s;
315 s << apply(x.get_arg1()) <<
" != " << apply(x.get_arg2());
318 void CodePrinter::bvisit(
const LessThan &x)
320 std::ostringstream s;
321 s << apply(x.get_arg1()) <<
" <= " << apply(x.get_arg2());
324 void CodePrinter::bvisit(
const StrictLessThan &x)
326 std::ostringstream s;
327 s << apply(x.get_arg1()) <<
" < " << apply(x.get_arg2());
330 void CodePrinter::bvisit(
const Sign &x)
332 const std::string arg = apply(x.get_arg());
333 const std::string zero = print_scalar_literal(0.0);
334 const std::string one = print_scalar_literal(1.0);
335 const std::string minus_one = print_scalar_literal(-1.0);
336 std::ostringstream s;
337 s <<
"((" << arg <<
" == " << zero <<
") ? (" << zero <<
") : ((" << arg
338 <<
" < " << zero <<
") ? (" << minus_one <<
") : (" << one <<
")))";
341 void CodePrinter::bvisit(
const UnevaluatedExpr &x)
343 str_ = apply(x.get_arg());
345 void CodePrinter::bvisit(
const UnivariateSeries &x)
347 throw SymEngineException(
"Not supported");
349 void CodePrinter::bvisit(
const Derivative &x)
351 throw SymEngineException(
"Not supported");
353 void CodePrinter::bvisit(
const Subs &x)
355 throw SymEngineException(
"Not supported");
357 void CodePrinter::bvisit(
const GaloisField &x)
359 throw SymEngineException(
"Not supported");
362 void CodePrinter::bvisit(
const Function &x)
364 static const std::vector<std::string> names_ = init_str_printer_names();
365 std::ostringstream o;
366 o << format_codegen_function_name(names_[x.get_type_code()]);
367 vec_basic vec = x.get_args();
368 o << parenthesize(apply(vec));
372 void CodePrinter::bvisit(
const RealDouble &x)
374 if (precision_ != CodePrinterPrecision::Double) {
375 str_ = print_scalar_literal(x.i);
377 StrPrinter::bvisit(x);
381 #ifdef HAVE_SYMENGINE_MPFR
382 void CodePrinter::bvisit(
const RealMPFR &x)
384 StrPrinter::bvisit(x);
385 if (precision_ != CodePrinterPrecision::Double) {
386 str_ += print_precision_suffix(precision_);
391 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
392 : RewriteTrigVisitor<C89CodePrinter, CodePrinter>(precision)
394 if (precision_ == CodePrinterPrecision::Half) {
395 throw SymEngineException(
396 "C-family code printers do not support half precision");
400 void C89CodePrinter::bvisit(
const Infty &x)
402 std::ostringstream s;
403 if (x.is_negative_infinity())
405 else if (x.is_positive_infinity())
408 throw SymEngineException(
"Not supported");
411 void C89CodePrinter::_print_pow(std::ostringstream &o,
412 const RCP<const Basic> &a,
413 const RCP<const Basic> &b)
416 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
417 }
else if (
eq(*b, *minus_one)) {
418 o << apply(*one) <<
"/" << parenthesizeLE(a, PrecedenceEnum::Mul);
420 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
422 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
427 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
428 : RewriteTrigVisitor<C99CodePrinter, C89CodePrinter>(precision)
432 void C99CodePrinter::bvisit(
const Infty &x)
434 std::ostringstream s;
435 if (x.is_negative_infinity())
437 else if (x.is_positive_infinity())
440 throw SymEngineException(
"Not supported");
443 void C99CodePrinter::_print_pow(std::ostringstream &o,
444 const RCP<const Basic> &a,
445 const RCP<const Basic> &b)
448 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
449 }
else if (
eq(*b, *minus_one)) {
450 o << apply(*one) <<
"/" << parenthesizeLE(a, PrecedenceEnum::Mul);
452 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
454 o << print_math_function(
"cbrt") <<
"(" << apply(a) <<
")";
456 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
460 void C99CodePrinter::bvisit(
const Gamma &x)
462 std::ostringstream s;
463 s << print_math_function(
"tgamma") <<
"(" << apply(x.get_arg()) <<
")";
466 void C99CodePrinter::bvisit(
const LogGamma &x)
468 std::ostringstream s;
469 s << print_math_function(
"lgamma") <<
"(" << apply(x.get_arg()) <<
")";
473 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
474 : RewriteTrigVisitor<CudaCodePrinter, C99CodePrinter>(precision)
478 void CudaCodePrinter::bvisit(
const Integer &x)
480 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
482 void CudaCodePrinter::bvisit(
const Constant &x)
485 str_ = print_math_function(
"exp") +
"(" + print_scalar_literal(1.0)
487 }
else if (
eq(x, *pi)) {
488 str_ = print_math_function(
"acos") +
"(" + print_scalar_literal(-1.0)
495 void CudaCodePrinter::bvisit(
const NaN &x)
497 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_NAN_F"
501 void CudaCodePrinter::bvisit(
const Infty &x)
503 if (x.is_negative_infinity())
504 str_ = precision_ == CodePrinterPrecision::Float ?
"-CUDART_INF_F"
506 else if (x.is_positive_infinity())
507 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_INF_F"
510 throw SymEngineException(
"Not supported");
513 MetalCodePrinter::MetalCodePrinter(CodePrinterPrecision precision)
514 : RewriteTrigVisitor<MetalCodePrinter, CodePrinter>(precision)
516 if (precision_ == CodePrinterPrecision::Double) {
517 throw SymEngineException(
518 "Metal code printer currently only supports float and half "
523 void MetalCodePrinter::bvisit(
const Constant &x)
526 str_ =
"exp(" + print_scalar_literal(1.0) +
")";
527 }
else if (
eq(x, *pi)) {
528 str_ =
"acos(" + print_scalar_literal(-1.0) +
")";
534 void MetalCodePrinter::bvisit(
const NaN &x)
536 if (precision_ == CodePrinterPrecision::Half) {
543 void MetalCodePrinter::bvisit(
const Infty &x)
545 if (x.is_negative_infinity()) {
546 str_ = precision_ == CodePrinterPrecision::Half ?
"-HUGE_VALH"
548 }
else if (x.is_positive_infinity()) {
549 str_ = precision_ == CodePrinterPrecision::Half ?
"HUGE_VALH"
552 throw SymEngineException(
"Not supported");
556 void MetalCodePrinter::bvisit(
const Abs &x)
558 std::ostringstream s;
559 s <<
"fabs(" << apply(x.get_arg()) <<
")";
563 void MetalCodePrinter::bvisit(
const Ceiling &x)
565 std::ostringstream s;
566 s <<
"ceil(" << apply(x.get_arg()) <<
")";
570 void MetalCodePrinter::bvisit(
const Truncate &x)
572 std::ostringstream s;
573 s <<
"trunc(" << apply(x.get_arg()) <<
")";
577 void MetalCodePrinter::bvisit(
const Max &x)
579 str_ = print_binary_reduction(x.get_args(),
"fmax");
582 void MetalCodePrinter::bvisit(
const Min &x)
584 str_ = print_binary_reduction(x.get_args(),
"fmin");
588 MetalCodePrinter::format_codegen_function_name(
const std::string &name)
const
593 void MetalCodePrinter::_print_pow(std::ostringstream &o,
594 const RCP<const Basic> &a,
595 const RCP<const Basic> &b)
598 o <<
"exp(" << apply(b) <<
")";
599 }
else if (
eq(*b, *minus_one)) {
600 o << apply(*one) <<
"/" << parenthesizeLE(a, PrecedenceEnum::Mul);
602 o <<
"sqrt(" << apply(a) <<
")";
604 o <<
"pow(" << apply(a) <<
", " << apply(b) <<
")";
608 void JSCodePrinter::bvisit(
const Constant &x)
612 }
else if (
eq(x, *pi)) {
618 void JSCodePrinter::_print_pow(std::ostringstream &o,
const RCP<const Basic> &a,
619 const RCP<const Basic> &b)
622 o <<
"Math.exp(" << apply(b) <<
")";
623 }
else if (
eq(*b, *minus_one)) {
624 o << apply(*one) <<
"/" << parenthesizeLE(a, PrecedenceEnum::Mul);
626 o <<
"Math.sqrt(" << apply(a) <<
")";
628 o <<
"Math.cbrt(" << apply(a) <<
")";
630 o <<
"Math.pow(" << apply(a) <<
", " << apply(b) <<
")";
633 void JSCodePrinter::bvisit(
const Abs &x)
635 std::ostringstream s;
636 s <<
"Math.abs(" << apply(x.get_arg()) <<
")";
639 void JSCodePrinter::bvisit(
const Sin &x)
641 std::ostringstream s;
642 s <<
"Math.sin(" << apply(x.get_arg()) <<
")";
645 void JSCodePrinter::bvisit(
const Cos &x)
647 std::ostringstream s;
648 s <<
"Math.cos(" << apply(x.get_arg()) <<
")";
651 void JSCodePrinter::bvisit(
const Max &x)
653 const auto &args = x.get_args();
654 std::ostringstream s;
656 for (
size_t i = 0; i < args.size(); ++i) {
658 s << ((i == args.size() - 1) ?
")" :
", ");
662 void JSCodePrinter::bvisit(
const Min &x)
664 const auto &args = x.get_args();
665 std::ostringstream s;
667 for (
size_t i = 0; i < args.size(); ++i) {
669 s << ((i == args.size() - 1) ?
")" :
", ");
675 JSCodePrinter::format_codegen_function_name(
const std::string &name)
const
677 return "Math." + name;
680 std::string ccode(
const Basic &x, CodePrinterPrecision precision)
682 C99CodePrinter c(precision);
686 std::string cudacode(
const Basic &x, CodePrinterPrecision precision)
688 CudaCodePrinter p(precision);
692 std::string metalcode(
const Basic &x, CodePrinterPrecision precision)
694 MetalCodePrinter p(precision);
698 std::string jscode(
const Basic &x)
704 std::string
inline c89code(
const Basic &x)
710 std::string
inline c99code(
const Basic &x)
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Number > rational(long n, long d)
convenience creator from two longs