codegen.cpp
1 #include <symengine/printers/codegen.h>
2 #include <symengine/constants.h>
3 #include <symengine/mul.h>
4 #include <symengine/visitor.h>
5 #include <symengine/printers.h>
6 #include <symengine/symengine_exception.h>
7 
8 namespace SymEngine
9 {
10 
11 namespace
12 {
13 
14 const char *print_precision_suffix(CodePrinterPrecision precision)
15 {
16  switch (precision) {
17  case CodePrinterPrecision::Double:
18  return "";
19  case CodePrinterPrecision::Float:
20  return "f";
21  case CodePrinterPrecision::Half:
22  return "h";
23  default:
24  throw SymEngineException("Unknown code printer precision");
25  }
26 }
27 
28 CodePrinterPrecision
29 normalize_code_printer_precision(CodePrinterPrecision precision)
30 {
31  switch (precision) {
32  case CodePrinterPrecision::Double:
33  case CodePrinterPrecision::Float:
34  case CodePrinterPrecision::Half:
35  return precision;
36  default:
37  throw SymEngineException("Unknown code printer precision");
38  }
39 }
40 
41 } // namespace
42 
43 CodePrinter::CodePrinter(CodePrinterPrecision precision)
44  : precision_{normalize_code_printer_precision(precision)}
45 {
46 }
47 
48 std::string CodePrinter::print_scalar_literal(double d) const
49 {
50  return print_double(d) + print_precision_suffix(precision_);
51 }
52 
53 std::string CodePrinter::print_math_function(const std::string &name) const
54 {
55  if (precision_ == CodePrinterPrecision::Float) {
56  return name + "f";
57  }
58  return name;
59 }
60 
61 std::string
62 CodePrinter::format_codegen_function_name(const std::string &name) const
63 {
64  return print_math_function(name);
65 }
66 
67 std::string CodePrinter::print_binary_reduction(const vec_basic &args,
68  const std::string &func_name)
69 {
70  if (args.size() < 2) {
71  throw SymEngineException("Impossible");
72  }
73 
74  return print_binary_reduction_impl(args.begin(), args.end(), func_name);
75 }
76 
77 std::string
78 CodePrinter::print_binary_reduction_impl(vec_basic::const_iterator begin,
79  vec_basic::const_iterator end,
80  const std::string &func_name)
81 {
82  const auto size = static_cast<std::size_t>(std::distance(begin, end));
83  if (size == 1) {
84  return apply(*begin);
85  }
86 
87  const auto mid = begin + size / 2;
88  std::ostringstream s;
89  s << func_name << "(" << print_binary_reduction_impl(begin, mid, func_name)
90  << ", " << print_binary_reduction_impl(mid, end, func_name) << ")";
91  return s.str();
92 }
93 
94 void CodePrinter::bvisit(const Basic &x)
95 {
96  throw SymEngineException("Not supported");
97 }
98 void CodePrinter::bvisit(const Complex &x)
99 {
100  throw NotImplementedError("Not implemented");
101 }
102 void CodePrinter::bvisit(const Dummy &x)
103 {
104  std::ostringstream s;
105  s << x.get_name() << '_' << x.get_index();
106  str_ = s.str();
107 }
108 void CodePrinter::bvisit(const Interval &x)
109 {
110  std::string var = str_;
111  std::ostringstream s;
112  bool is_inf = eq(*x.get_start(), *NegInf);
113  if (not is_inf) {
114  s << var;
115  if (x.get_left_open()) {
116  s << " > ";
117  } else {
118  s << " >= ";
119  }
120  s << apply(x.get_start());
121  }
122  if (neq(*x.get_end(), *Inf)) {
123  if (not is_inf) {
124  s << " && ";
125  }
126  s << var;
127  if (x.get_right_open()) {
128  s << " < ";
129  } else {
130  s << " <= ";
131  }
132  s << apply(x.get_end());
133  }
134  str_ = s.str();
135 }
136 void CodePrinter::bvisit(const Contains &x)
137 {
138  x.get_expr()->accept(*this);
139  x.get_set()->accept(*this);
140 }
141 void CodePrinter::bvisit(const Piecewise &x)
142 {
143  std::ostringstream s;
144  auto vec = x.get_vec();
145  for (size_t i = 0;; ++i) {
146  if (i == vec.size() - 1) {
147  if (neq(*vec[i].second, *boolTrue)) {
148  throw SymEngineException(
149  "Code generation requires a (Expr, True) at the end");
150  }
151  s << "(\n " << apply(vec[i].first) << "\n";
152  break;
153  } else {
154  s << "((";
155  s << apply(vec[i].second);
156  s << ") ? (\n ";
157  s << apply(vec[i].first);
158  s << "\n)\n: ";
159  }
160  }
161  for (size_t i = 0; i < vec.size(); i++) {
162  s << ")";
163  }
164  str_ = s.str();
165 }
166 void CodePrinter::bvisit(const BooleanAtom &x)
167 {
168  str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
169 }
170 void CodePrinter::bvisit(const Integer &x)
171 {
172  if (precision_ != CodePrinterPrecision::Double) {
173  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
174  } else {
175  StrPrinter::bvisit(x);
176  }
177 }
178 void CodePrinter::bvisit(const And &x)
179 {
180  std::ostringstream s;
181  const auto &container = x.get_container();
182  s << "(";
183  for (auto it = container.begin(); it != container.end(); ++it) {
184  if (it != container.begin()) {
185  s << " && ";
186  }
187  s << "(" << apply(*(*it)) << ")";
188  }
189  s << ")";
190  str_ = s.str();
191 }
192 void CodePrinter::bvisit(const Or &x)
193 {
194  std::ostringstream s;
195  const auto &container = x.get_container();
196  s << "(";
197  for (auto it = container.begin(); it != container.end(); ++it) {
198  if (it != container.begin()) {
199  s << " || ";
200  }
201  s << "(" << apply(*(*it)) << ")";
202  }
203  s << ")";
204  str_ = s.str();
205 }
206 void CodePrinter::bvisit(const Xor &x)
207 {
208  std::ostringstream s;
209  const auto &container = x.get_container();
210  s << "(";
211  for (auto it = container.begin(); it != container.end(); ++it) {
212  if (it != container.begin()) {
213  s << " != ";
214  }
215  s << "((" << apply(*(*it)) << ") != 0)";
216  }
217  s << ")";
218  str_ = s.str();
219 }
220 void CodePrinter::bvisit(const Not &x)
221 {
222  std::ostringstream s;
223  s << "!(" << apply(*x.get_arg()) << ")";
224  str_ = s.str();
225 }
226 void CodePrinter::bvisit(const Rational &x)
227 {
228  std::ostringstream o;
229  double n = mp_get_d(get_num(x.as_rational_class()));
230  double d = mp_get_d(get_den(x.as_rational_class()));
231  o << print_scalar_literal(n) << "/" << print_scalar_literal(d);
232  str_ = o.str();
233 }
234 void CodePrinter::bvisit(const Reals &x)
235 {
236  throw SymEngineException("Not supported");
237 }
238 void CodePrinter::bvisit(const Rationals &x)
239 {
240  throw SymEngineException("Not supported");
241 }
242 void CodePrinter::bvisit(const Integers &x)
243 {
244  throw SymEngineException("Not supported");
245 }
246 void CodePrinter::bvisit(const EmptySet &x)
247 {
248  throw SymEngineException("Not supported");
249 }
250 void CodePrinter::bvisit(const FiniteSet &x)
251 {
252  throw SymEngineException("Not supported");
253 }
254 void CodePrinter::bvisit(const UniversalSet &x)
255 {
256  throw SymEngineException("Not supported");
257 }
258 void CodePrinter::bvisit(const Abs &x)
259 {
260  std::ostringstream s;
261  s << print_math_function("fabs") << "(" << apply(x.get_arg()) << ")";
262  str_ = s.str();
263 }
264 void CodePrinter::bvisit(const Ceiling &x)
265 {
266  std::ostringstream s;
267  s << print_math_function("ceil") << "(" << apply(x.get_arg()) << ")";
268  str_ = s.str();
269 }
270 void CodePrinter::bvisit(const Truncate &x)
271 {
272  std::ostringstream s;
273  s << print_math_function("trunc") << "(" << apply(x.get_arg()) << ")";
274  str_ = s.str();
275 }
276 void CodePrinter::bvisit(const Max &x)
277 {
278  str_ = print_binary_reduction(x.get_args(), print_math_function("fmax"));
279 }
280 void CodePrinter::bvisit(const Min &x)
281 {
282  str_ = print_binary_reduction(x.get_args(), print_math_function("fmin"));
283 }
284 void CodePrinter::bvisit(const Constant &x)
285 {
286  if (eq(x, *E)) {
287  str_ = precision_ != CodePrinterPrecision::Double
288  ? print_math_function("exp") + "("
289  + print_scalar_literal(1.0) + ")"
290  : "exp(1)";
291  } else if (eq(x, *pi)) {
292  str_ = precision_ != CodePrinterPrecision::Double
293  ? print_math_function("acos") + "("
294  + print_scalar_literal(-1.0) + ")"
295  : "acos(-1)";
296  } else {
297  str_ = x.get_name();
298  }
299 }
300 void CodePrinter::bvisit(const NaN &x)
301 {
302  std::ostringstream s;
303  s << "NAN";
304  str_ = s.str();
305 }
306 void CodePrinter::bvisit(const Equality &x)
307 {
308  std::ostringstream s;
309  s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
310  str_ = s.str();
311 }
312 void CodePrinter::bvisit(const Unequality &x)
313 {
314  std::ostringstream s;
315  s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
316  str_ = s.str();
317 }
318 void CodePrinter::bvisit(const LessThan &x)
319 {
320  std::ostringstream s;
321  s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
322  str_ = s.str();
323 }
324 void CodePrinter::bvisit(const StrictLessThan &x)
325 {
326  std::ostringstream s;
327  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
328  str_ = s.str();
329 }
330 void CodePrinter::bvisit(const Sign &x)
331 {
332  const std::string arg = apply(x.get_arg());
333  const std::string zero = print_scalar_literal(0.0);
334  const std::string one = print_scalar_literal(1.0);
335  const std::string minus_one = print_scalar_literal(-1.0);
336  std::ostringstream s;
337  s << "((" << arg << " == " << zero << ") ? (" << zero << ") : ((" << arg
338  << " < " << zero << ") ? (" << minus_one << ") : (" << one << ")))";
339  str_ = s.str();
340 }
341 void CodePrinter::bvisit(const UnevaluatedExpr &x)
342 {
343  str_ = apply(x.get_arg());
344 }
345 void CodePrinter::bvisit(const UnivariateSeries &x)
346 {
347  throw SymEngineException("Not supported");
348 }
349 void CodePrinter::bvisit(const Derivative &x)
350 {
351  throw SymEngineException("Not supported");
352 }
353 void CodePrinter::bvisit(const Subs &x)
354 {
355  throw SymEngineException("Not supported");
356 }
357 void CodePrinter::bvisit(const GaloisField &x)
358 {
359  throw SymEngineException("Not supported");
360 }
361 
362 void CodePrinter::bvisit(const Function &x)
363 {
364  static const std::vector<std::string> names_ = init_str_printer_names();
365  std::ostringstream o;
366  o << format_codegen_function_name(names_[x.get_type_code()]);
367  vec_basic vec = x.get_args();
368  o << parenthesize(apply(vec));
369  str_ = o.str();
370 }
371 
372 void CodePrinter::bvisit(const RealDouble &x)
373 {
374  if (precision_ != CodePrinterPrecision::Double) {
375  str_ = print_scalar_literal(x.i);
376  } else {
377  StrPrinter::bvisit(x);
378  }
379 }
380 
381 #ifdef HAVE_SYMENGINE_MPFR
382 void CodePrinter::bvisit(const RealMPFR &x)
383 {
384  StrPrinter::bvisit(x);
385  if (precision_ != CodePrinterPrecision::Double) {
386  str_ += print_precision_suffix(precision_);
387  }
388 }
389 #endif
390 
391 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
392  : RewriteTrigVisitor<C89CodePrinter, CodePrinter>(precision)
393 {
394  if (precision_ == CodePrinterPrecision::Half) {
395  throw SymEngineException(
396  "C-family code printers do not support half precision");
397  }
398 }
399 
400 void C89CodePrinter::bvisit(const Infty &x)
401 {
402  std::ostringstream s;
403  if (x.is_negative_infinity())
404  s << "-HUGE_VAL";
405  else if (x.is_positive_infinity())
406  s << "HUGE_VAL";
407  else
408  throw SymEngineException("Not supported");
409  str_ = s.str();
410 }
411 void C89CodePrinter::_print_pow(std::ostringstream &o,
412  const RCP<const Basic> &a,
413  const RCP<const Basic> &b)
414 {
415  if (eq(*a, *E)) {
416  o << print_math_function("exp") << "(" << apply(b) << ")";
417  } else if (eq(*b, *minus_one)) {
418  o << apply(*one) << "/" << parenthesizeLE(a, PrecedenceEnum::Mul);
419  } else if (eq(*b, *rational(1, 2))) {
420  o << print_math_function("sqrt") << "(" << apply(a) << ")";
421  } else {
422  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
423  << ")";
424  }
425 }
426 
427 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
428  : RewriteTrigVisitor<C99CodePrinter, C89CodePrinter>(precision)
429 {
430 }
431 
432 void C99CodePrinter::bvisit(const Infty &x)
433 {
434  std::ostringstream s;
435  if (x.is_negative_infinity())
436  s << "-INFINITY";
437  else if (x.is_positive_infinity())
438  s << "INFINITY";
439  else
440  throw SymEngineException("Not supported");
441  str_ = s.str();
442 }
443 void C99CodePrinter::_print_pow(std::ostringstream &o,
444  const RCP<const Basic> &a,
445  const RCP<const Basic> &b)
446 {
447  if (eq(*a, *E)) {
448  o << print_math_function("exp") << "(" << apply(b) << ")";
449  } else if (eq(*b, *minus_one)) {
450  o << apply(*one) << "/" << parenthesizeLE(a, PrecedenceEnum::Mul);
451  } else if (eq(*b, *rational(1, 2))) {
452  o << print_math_function("sqrt") << "(" << apply(a) << ")";
453  } else if (eq(*b, *rational(1, 3))) {
454  o << print_math_function("cbrt") << "(" << apply(a) << ")";
455  } else {
456  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
457  << ")";
458  }
459 }
460 void C99CodePrinter::bvisit(const Gamma &x)
461 {
462  std::ostringstream s;
463  s << print_math_function("tgamma") << "(" << apply(x.get_arg()) << ")";
464  str_ = s.str();
465 }
466 void C99CodePrinter::bvisit(const LogGamma &x)
467 {
468  std::ostringstream s;
469  s << print_math_function("lgamma") << "(" << apply(x.get_arg()) << ")";
470  str_ = s.str();
471 }
472 
473 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
474  : RewriteTrigVisitor<CudaCodePrinter, C99CodePrinter>(precision)
475 {
476 }
477 
478 void CudaCodePrinter::bvisit(const Integer &x)
479 {
480  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
481 }
482 void CudaCodePrinter::bvisit(const Constant &x)
483 {
484  if (eq(x, *E)) {
485  str_ = print_math_function("exp") + "(" + print_scalar_literal(1.0)
486  + ")";
487  } else if (eq(x, *pi)) {
488  str_ = print_math_function("acos") + "(" + print_scalar_literal(-1.0)
489  + ")";
490  } else {
491  str_ = x.get_name();
492  }
493 }
494 
495 void CudaCodePrinter::bvisit(const NaN &x)
496 {
497  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_NAN_F"
498  : "CUDART_NAN";
499 }
500 
501 void CudaCodePrinter::bvisit(const Infty &x)
502 {
503  if (x.is_negative_infinity())
504  str_ = precision_ == CodePrinterPrecision::Float ? "-CUDART_INF_F"
505  : "-CUDART_INF";
506  else if (x.is_positive_infinity())
507  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_INF_F"
508  : "CUDART_INF";
509  else
510  throw SymEngineException("Not supported");
511 }
512 
513 MetalCodePrinter::MetalCodePrinter(CodePrinterPrecision precision)
514  : RewriteTrigVisitor<MetalCodePrinter, CodePrinter>(precision)
515 {
516  if (precision_ == CodePrinterPrecision::Double) {
517  throw SymEngineException(
518  "Metal code printer currently only supports float and half "
519  "precision");
520  }
521 }
522 
523 void MetalCodePrinter::bvisit(const Constant &x)
524 {
525  if (eq(x, *E)) {
526  str_ = "exp(" + print_scalar_literal(1.0) + ")";
527  } else if (eq(x, *pi)) {
528  str_ = "acos(" + print_scalar_literal(-1.0) + ")";
529  } else {
530  str_ = x.get_name();
531  }
532 }
533 
534 void MetalCodePrinter::bvisit(const NaN &x)
535 {
536  if (precision_ == CodePrinterPrecision::Half) {
537  str_ = "half(NAN)";
538  } else {
539  str_ = "NAN";
540  }
541 }
542 
543 void MetalCodePrinter::bvisit(const Infty &x)
544 {
545  if (x.is_negative_infinity()) {
546  str_ = precision_ == CodePrinterPrecision::Half ? "-HUGE_VALH"
547  : "-INFINITY";
548  } else if (x.is_positive_infinity()) {
549  str_ = precision_ == CodePrinterPrecision::Half ? "HUGE_VALH"
550  : "INFINITY";
551  } else {
552  throw SymEngineException("Not supported");
553  }
554 }
555 
556 void MetalCodePrinter::bvisit(const Abs &x)
557 {
558  std::ostringstream s;
559  s << "fabs(" << apply(x.get_arg()) << ")";
560  str_ = s.str();
561 }
562 
563 void MetalCodePrinter::bvisit(const Ceiling &x)
564 {
565  std::ostringstream s;
566  s << "ceil(" << apply(x.get_arg()) << ")";
567  str_ = s.str();
568 }
569 
570 void MetalCodePrinter::bvisit(const Truncate &x)
571 {
572  std::ostringstream s;
573  s << "trunc(" << apply(x.get_arg()) << ")";
574  str_ = s.str();
575 }
576 
577 void MetalCodePrinter::bvisit(const Max &x)
578 {
579  str_ = print_binary_reduction(x.get_args(), "fmax");
580 }
581 
582 void MetalCodePrinter::bvisit(const Min &x)
583 {
584  str_ = print_binary_reduction(x.get_args(), "fmin");
585 }
586 
587 std::string
588 MetalCodePrinter::format_codegen_function_name(const std::string &name) const
589 {
590  return name;
591 }
592 
593 void MetalCodePrinter::_print_pow(std::ostringstream &o,
594  const RCP<const Basic> &a,
595  const RCP<const Basic> &b)
596 {
597  if (eq(*a, *E)) {
598  o << "exp(" << apply(b) << ")";
599  } else if (eq(*b, *minus_one)) {
600  o << apply(*one) << "/" << parenthesizeLE(a, PrecedenceEnum::Mul);
601  } else if (eq(*b, *rational(1, 2))) {
602  o << "sqrt(" << apply(a) << ")";
603  } else {
604  o << "pow(" << apply(a) << ", " << apply(b) << ")";
605  }
606 }
607 
608 void JSCodePrinter::bvisit(const Constant &x)
609 {
610  if (eq(x, *E)) {
611  str_ = "Math.E";
612  } else if (eq(x, *pi)) {
613  str_ = "Math.PI";
614  } else {
615  str_ = x.get_name();
616  }
617 }
618 void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
619  const RCP<const Basic> &b)
620 {
621  if (eq(*a, *E)) {
622  o << "Math.exp(" << apply(b) << ")";
623  } else if (eq(*b, *minus_one)) {
624  o << apply(*one) << "/" << parenthesizeLE(a, PrecedenceEnum::Mul);
625  } else if (eq(*b, *rational(1, 2))) {
626  o << "Math.sqrt(" << apply(a) << ")";
627  } else if (eq(*b, *rational(1, 3))) {
628  o << "Math.cbrt(" << apply(a) << ")";
629  } else {
630  o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
631  }
632 }
633 void JSCodePrinter::bvisit(const Abs &x)
634 {
635  std::ostringstream s;
636  s << "Math.abs(" << apply(x.get_arg()) << ")";
637  str_ = s.str();
638 }
639 void JSCodePrinter::bvisit(const Sin &x)
640 {
641  std::ostringstream s;
642  s << "Math.sin(" << apply(x.get_arg()) << ")";
643  str_ = s.str();
644 }
645 void JSCodePrinter::bvisit(const Cos &x)
646 {
647  std::ostringstream s;
648  s << "Math.cos(" << apply(x.get_arg()) << ")";
649  str_ = s.str();
650 }
651 void JSCodePrinter::bvisit(const Max &x)
652 {
653  const auto &args = x.get_args();
654  std::ostringstream s;
655  s << "Math.max(";
656  for (size_t i = 0; i < args.size(); ++i) {
657  s << apply(args[i]);
658  s << ((i == args.size() - 1) ? ")" : ", ");
659  }
660  str_ = s.str();
661 }
662 void JSCodePrinter::bvisit(const Min &x)
663 {
664  const auto &args = x.get_args();
665  std::ostringstream s;
666  s << "Math.min(";
667  for (size_t i = 0; i < args.size(); ++i) {
668  s << apply(args[i]);
669  s << ((i == args.size() - 1) ? ")" : ", ");
670  }
671  str_ = s.str();
672 }
673 
674 std::string
675 JSCodePrinter::format_codegen_function_name(const std::string &name) const
676 {
677  return "Math." + name;
678 }
679 
680 std::string ccode(const Basic &x, CodePrinterPrecision precision)
681 {
682  C99CodePrinter c(precision);
683  return c.apply(x);
684 }
685 
686 std::string cudacode(const Basic &x, CodePrinterPrecision precision)
687 {
688  CudaCodePrinter p(precision);
689  return p.apply(x);
690 }
691 
692 std::string metalcode(const Basic &x, CodePrinterPrecision precision)
693 {
694  MetalCodePrinter p(precision);
695  return p.apply(x);
696 }
697 
698 std::string jscode(const Basic &x)
699 {
700  JSCodePrinter p;
701  return p.apply(x);
702 }
703 
704 std::string inline c89code(const Basic &x)
705 {
706  C89CodePrinter p;
707  return p.apply(x);
708 }
709 
710 std::string inline c99code(const Basic &x)
711 {
712  C99CodePrinter p;
713  return p.apply(x);
714 }
715 
716 } // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328