1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
11 const char *print_precision_suffix(CodePrinterPrecision precision)
14 case CodePrinterPrecision::Double:
16 case CodePrinterPrecision::Float:
18 case CodePrinterPrecision::Half:
21 throw SymEngineException(
"Unknown code printer precision");
26 normalize_code_printer_precision(CodePrinterPrecision precision)
29 case CodePrinterPrecision::Double:
30 case CodePrinterPrecision::Float:
31 case CodePrinterPrecision::Half:
34 throw SymEngineException(
"Unknown code printer precision");
40 CodePrinter::CodePrinter(CodePrinterPrecision precision)
41 : precision_{normalize_code_printer_precision(precision)}
45 std::string CodePrinter::print_scalar_literal(
double d)
const
47 return print_double(d) + print_precision_suffix(precision_);
50 std::string CodePrinter::print_math_function(
const std::string &name)
const
52 if (precision_ == CodePrinterPrecision::Float) {
58 std::string CodePrinter::print_binary_reduction(
const vec_basic &args,
59 const std::string &func_name)
61 if (args.size() < 2) {
62 throw SymEngineException(
"Impossible");
65 return print_binary_reduction_impl(args.begin(), args.end(), func_name);
69 CodePrinter::print_binary_reduction_impl(vec_basic::const_iterator begin,
70 vec_basic::const_iterator end,
71 const std::string &func_name)
73 const auto size =
static_cast<std::size_t
>(std::distance(begin, end));
78 const auto mid = begin + size / 2;
80 s << func_name <<
"(" << print_binary_reduction_impl(begin, mid, func_name)
81 <<
", " << print_binary_reduction_impl(mid, end, func_name) <<
")";
85 void CodePrinter::bvisit(
const Basic &x)
87 throw SymEngineException(
"Not supported");
89 void CodePrinter::bvisit(
const Complex &x)
91 throw NotImplementedError(
"Not implemented");
93 void CodePrinter::bvisit(
const Dummy &x)
96 s << x.get_name() <<
'_' << x.get_index();
99 void CodePrinter::bvisit(
const Interval &x)
101 std::string var = str_;
102 std::ostringstream s;
103 bool is_inf =
eq(*x.get_start(), *NegInf);
106 if (x.get_left_open()) {
111 s << apply(x.get_start());
113 if (
neq(*x.get_end(), *Inf)) {
118 if (x.get_right_open()) {
123 s << apply(x.get_end());
127 void CodePrinter::bvisit(
const Contains &x)
129 x.get_expr()->accept(*
this);
130 x.get_set()->accept(*
this);
132 void CodePrinter::bvisit(
const Piecewise &x)
134 std::ostringstream s;
135 auto vec = x.get_vec();
136 for (
size_t i = 0;; ++i) {
137 if (i == vec.size() - 1) {
138 if (
neq(*vec[i].second, *boolTrue)) {
139 throw SymEngineException(
140 "Code generation requires a (Expr, True) at the end");
142 s <<
"(\n " << apply(vec[i].first) <<
"\n";
146 s << apply(vec[i].second);
148 s << apply(vec[i].first);
152 for (
size_t i = 0; i < vec.size(); i++) {
157 void CodePrinter::bvisit(
const BooleanAtom &x)
159 str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
161 void CodePrinter::bvisit(
const Integer &x)
163 if (precision_ != CodePrinterPrecision::Double) {
164 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
166 StrPrinter::bvisit(x);
169 void CodePrinter::bvisit(
const And &x)
171 std::ostringstream s;
172 const auto &container = x.get_container();
174 for (
auto it = container.begin(); it != container.end(); ++it) {
175 if (it != container.begin()) {
178 s <<
"(" << apply(*(*it)) <<
")";
183 void CodePrinter::bvisit(
const Or &x)
185 std::ostringstream s;
186 const auto &container = x.get_container();
188 for (
auto it = container.begin(); it != container.end(); ++it) {
189 if (it != container.begin()) {
192 s <<
"(" << apply(*(*it)) <<
")";
197 void CodePrinter::bvisit(
const Xor &x)
199 std::ostringstream s;
200 const auto &container = x.get_container();
202 for (
auto it = container.begin(); it != container.end(); ++it) {
203 if (it != container.begin()) {
206 s <<
"((" << apply(*(*it)) <<
") != 0)";
211 void CodePrinter::bvisit(
const Not &x)
213 std::ostringstream s;
214 s <<
"!(" << apply(*x.get_arg()) <<
")";
217 void CodePrinter::bvisit(
const Rational &x)
219 std::ostringstream o;
220 double n = mp_get_d(get_num(x.as_rational_class()));
221 double d = mp_get_d(get_den(x.as_rational_class()));
222 o << print_scalar_literal(n) <<
"/" << print_scalar_literal(d);
225 void CodePrinter::bvisit(
const Reals &x)
227 throw SymEngineException(
"Not supported");
229 void CodePrinter::bvisit(
const Rationals &x)
231 throw SymEngineException(
"Not supported");
233 void CodePrinter::bvisit(
const Integers &x)
235 throw SymEngineException(
"Not supported");
237 void CodePrinter::bvisit(
const EmptySet &x)
239 throw SymEngineException(
"Not supported");
241 void CodePrinter::bvisit(
const FiniteSet &x)
243 throw SymEngineException(
"Not supported");
245 void CodePrinter::bvisit(
const UniversalSet &x)
247 throw SymEngineException(
"Not supported");
249 void CodePrinter::bvisit(
const Abs &x)
251 std::ostringstream s;
252 s << print_math_function(
"fabs") <<
"(" << apply(x.get_arg()) <<
")";
255 void CodePrinter::bvisit(
const Ceiling &x)
257 std::ostringstream s;
258 s << print_math_function(
"ceil") <<
"(" << apply(x.get_arg()) <<
")";
261 void CodePrinter::bvisit(
const Truncate &x)
263 std::ostringstream s;
264 s << print_math_function(
"trunc") <<
"(" << apply(x.get_arg()) <<
")";
267 void CodePrinter::bvisit(
const Max &x)
269 str_ = print_binary_reduction(x.get_args(), print_math_function(
"fmax"));
271 void CodePrinter::bvisit(
const Min &x)
273 str_ = print_binary_reduction(x.get_args(), print_math_function(
"fmin"));
275 void CodePrinter::bvisit(
const Constant &x)
278 str_ = precision_ != CodePrinterPrecision::Double
279 ? print_math_function(
"exp") +
"("
280 + print_scalar_literal(1.0) +
")"
282 }
else if (
eq(x, *pi)) {
283 str_ = precision_ != CodePrinterPrecision::Double
284 ? print_math_function(
"acos") +
"("
285 + print_scalar_literal(-1.0) +
")"
291 void CodePrinter::bvisit(
const NaN &x)
293 std::ostringstream s;
297 void CodePrinter::bvisit(
const Equality &x)
299 std::ostringstream s;
300 s << apply(x.get_arg1()) <<
" == " << apply(x.get_arg2());
303 void CodePrinter::bvisit(
const Unequality &x)
305 std::ostringstream s;
306 s << apply(x.get_arg1()) <<
" != " << apply(x.get_arg2());
309 void CodePrinter::bvisit(
const LessThan &x)
311 std::ostringstream s;
312 s << apply(x.get_arg1()) <<
" <= " << apply(x.get_arg2());
315 void CodePrinter::bvisit(
const StrictLessThan &x)
317 std::ostringstream s;
318 s << apply(x.get_arg1()) <<
" < " << apply(x.get_arg2());
321 void CodePrinter::bvisit(
const Sign &x)
323 const std::string arg = apply(x.get_arg());
324 const std::string zero = print_scalar_literal(0.0);
325 const std::string one = print_scalar_literal(1.0);
326 const std::string minus_one = print_scalar_literal(-1.0);
327 std::ostringstream s;
328 s <<
"((" << arg <<
" == " << zero <<
") ? (" << zero <<
") : ((" << arg
329 <<
" < " << zero <<
") ? (" << minus_one <<
") : (" << one <<
")))";
332 void CodePrinter::bvisit(
const UnevaluatedExpr &x)
334 str_ = apply(x.get_arg());
336 void CodePrinter::bvisit(
const UnivariateSeries &x)
338 throw SymEngineException(
"Not supported");
340 void CodePrinter::bvisit(
const Derivative &x)
342 throw SymEngineException(
"Not supported");
344 void CodePrinter::bvisit(
const Subs &x)
346 throw SymEngineException(
"Not supported");
348 void CodePrinter::bvisit(
const GaloisField &x)
350 throw SymEngineException(
"Not supported");
353 void CodePrinter::bvisit(
const Function &x)
355 static const std::vector<std::string> names_ = init_str_printer_names();
356 std::ostringstream o;
357 o << print_math_function(names_[x.get_type_code()]);
358 vec_basic vec = x.get_args();
359 o << parenthesize(apply(vec));
363 void CodePrinter::bvisit(
const RealDouble &x)
365 if (precision_ != CodePrinterPrecision::Double) {
366 str_ = print_scalar_literal(x.i);
368 StrPrinter::bvisit(x);
372 #ifdef HAVE_SYMENGINE_MPFR
373 void CodePrinter::bvisit(
const RealMPFR &x)
375 StrPrinter::bvisit(x);
376 if (precision_ != CodePrinterPrecision::Double) {
377 str_ += print_precision_suffix(precision_);
382 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
383 : BaseVisitor<C89CodePrinter, CodePrinter>(precision)
385 if (precision_ == CodePrinterPrecision::Half) {
386 throw SymEngineException(
387 "C-family code printers do not support half precision");
391 void C89CodePrinter::bvisit(
const Infty &x)
393 std::ostringstream s;
394 if (x.is_negative_infinity())
396 else if (x.is_positive_infinity())
399 throw SymEngineException(
"Not supported");
402 void C89CodePrinter::_print_pow(std::ostringstream &o,
403 const RCP<const Basic> &a,
404 const RCP<const Basic> &b)
407 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
409 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
411 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
416 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
417 : BaseVisitor<C99CodePrinter, C89CodePrinter>(precision)
421 void C99CodePrinter::bvisit(
const Infty &x)
423 std::ostringstream s;
424 if (x.is_negative_infinity())
426 else if (x.is_positive_infinity())
429 throw SymEngineException(
"Not supported");
432 void C99CodePrinter::_print_pow(std::ostringstream &o,
433 const RCP<const Basic> &a,
434 const RCP<const Basic> &b)
437 o << print_math_function(
"exp") <<
"(" << apply(b) <<
")";
439 o << print_math_function(
"sqrt") <<
"(" << apply(a) <<
")";
441 o << print_math_function(
"cbrt") <<
"(" << apply(a) <<
")";
443 o << print_math_function(
"pow") <<
"(" << apply(a) <<
", " << apply(b)
447 void C99CodePrinter::bvisit(
const Gamma &x)
449 std::ostringstream s;
450 s << print_math_function(
"tgamma") <<
"(" << apply(x.get_arg()) <<
")";
453 void C99CodePrinter::bvisit(
const LogGamma &x)
455 std::ostringstream s;
456 s << print_math_function(
"lgamma") <<
"(" << apply(x.get_arg()) <<
")";
460 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
461 : BaseVisitor<CudaCodePrinter, C99CodePrinter>(precision)
465 void CudaCodePrinter::bvisit(
const Integer &x)
467 str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
469 void CudaCodePrinter::bvisit(
const Constant &x)
472 str_ = print_math_function(
"exp") +
"(" + print_scalar_literal(1.0)
474 }
else if (
eq(x, *pi)) {
475 str_ = print_math_function(
"acos") +
"(" + print_scalar_literal(-1.0)
482 void CudaCodePrinter::bvisit(
const NaN &x)
484 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_NAN_F"
488 void CudaCodePrinter::bvisit(
const Infty &x)
490 if (x.is_negative_infinity())
491 str_ = precision_ == CodePrinterPrecision::Float ?
"-CUDART_INF_F"
493 else if (x.is_positive_infinity())
494 str_ = precision_ == CodePrinterPrecision::Float ?
"CUDART_INF_F"
497 throw SymEngineException(
"Not supported");
500 MetalCodePrinter::MetalCodePrinter(CodePrinterPrecision precision)
501 : BaseVisitor<MetalCodePrinter, CodePrinter>(precision)
503 if (precision_ == CodePrinterPrecision::Double) {
504 throw SymEngineException(
505 "Metal code printer currently only supports float and half "
510 void MetalCodePrinter::bvisit(
const Constant &x)
513 str_ =
"exp(" + print_scalar_literal(1.0) +
")";
514 }
else if (
eq(x, *pi)) {
515 str_ =
"acos(" + print_scalar_literal(-1.0) +
")";
521 void MetalCodePrinter::bvisit(
const NaN &x)
523 if (precision_ == CodePrinterPrecision::Half) {
530 void MetalCodePrinter::bvisit(
const Infty &x)
532 if (x.is_negative_infinity()) {
533 str_ = precision_ == CodePrinterPrecision::Half ?
"-HUGE_VALH"
535 }
else if (x.is_positive_infinity()) {
536 str_ = precision_ == CodePrinterPrecision::Half ?
"HUGE_VALH"
539 throw SymEngineException(
"Not supported");
543 void MetalCodePrinter::bvisit(
const Abs &x)
545 std::ostringstream s;
546 s <<
"fabs(" << apply(x.get_arg()) <<
")";
550 void MetalCodePrinter::bvisit(
const Ceiling &x)
552 std::ostringstream s;
553 s <<
"ceil(" << apply(x.get_arg()) <<
")";
557 void MetalCodePrinter::bvisit(
const Truncate &x)
559 std::ostringstream s;
560 s <<
"trunc(" << apply(x.get_arg()) <<
")";
564 void MetalCodePrinter::bvisit(
const Max &x)
566 str_ = print_binary_reduction(x.get_args(),
"fmax");
569 void MetalCodePrinter::bvisit(
const Min &x)
571 str_ = print_binary_reduction(x.get_args(),
"fmin");
574 void MetalCodePrinter::bvisit(
const Function &x)
576 static const std::vector<std::string> names_ = init_str_printer_names();
577 std::ostringstream o;
578 o << names_[x.get_type_code()];
579 vec_basic vec = x.get_args();
580 o << parenthesize(apply(vec));
584 void MetalCodePrinter::_print_pow(std::ostringstream &o,
585 const RCP<const Basic> &a,
586 const RCP<const Basic> &b)
589 o <<
"exp(" << apply(b) <<
")";
591 o <<
"sqrt(" << apply(a) <<
")";
593 o <<
"pow(" << apply(a) <<
", " << apply(b) <<
")";
597 void JSCodePrinter::bvisit(
const Constant &x)
601 }
else if (
eq(x, *pi)) {
607 void JSCodePrinter::_print_pow(std::ostringstream &o,
const RCP<const Basic> &a,
608 const RCP<const Basic> &b)
611 o <<
"Math.exp(" << apply(b) <<
")";
613 o <<
"Math.sqrt(" << apply(a) <<
")";
615 o <<
"Math.cbrt(" << apply(a) <<
")";
617 o <<
"Math.pow(" << apply(a) <<
", " << apply(b) <<
")";
620 void JSCodePrinter::bvisit(
const Abs &x)
622 std::ostringstream s;
623 s <<
"Math.abs(" << apply(x.get_arg()) <<
")";
626 void JSCodePrinter::bvisit(
const Sin &x)
628 std::ostringstream s;
629 s <<
"Math.sin(" << apply(x.get_arg()) <<
")";
632 void JSCodePrinter::bvisit(
const Cos &x)
634 std::ostringstream s;
635 s <<
"Math.cos(" << apply(x.get_arg()) <<
")";
638 void JSCodePrinter::bvisit(
const Max &x)
640 const auto &args = x.get_args();
641 std::ostringstream s;
643 for (
size_t i = 0; i < args.size(); ++i) {
645 s << ((i == args.size() - 1) ?
")" :
", ");
649 void JSCodePrinter::bvisit(
const Min &x)
651 const auto &args = x.get_args();
652 std::ostringstream s;
654 for (
size_t i = 0; i < args.size(); ++i) {
656 s << ((i == args.size() - 1) ?
")" :
", ");
661 std::string ccode(
const Basic &x, CodePrinterPrecision precision)
663 C99CodePrinter c(precision);
667 std::string cudacode(
const Basic &x, CodePrinterPrecision precision)
669 CudaCodePrinter p(precision);
673 std::string metalcode(
const Basic &x, CodePrinterPrecision precision)
675 MetalCodePrinter p(precision);
679 std::string jscode(
const Basic &x)
685 std::string
inline c89code(
const Basic &x)
691 std::string
inline c99code(
const Basic &x)
Main namespace for SymEngine package.
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Number > rational(long n, long d)
convenience creator from two longs