codegen.cpp
1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
4 
5 namespace SymEngine
6 {
7 
8 namespace
9 {
10 
11 const char *print_precision_suffix(CodePrinterPrecision precision)
12 {
13  switch (precision) {
14  case CodePrinterPrecision::Double:
15  return "";
16  case CodePrinterPrecision::Float:
17  return "f";
18  case CodePrinterPrecision::Half:
19  return "h";
20  default:
21  throw SymEngineException("Unknown code printer precision");
22  }
23 }
24 
25 CodePrinterPrecision
26 normalize_code_printer_precision(CodePrinterPrecision precision)
27 {
28  switch (precision) {
29  case CodePrinterPrecision::Double:
30  case CodePrinterPrecision::Float:
31  case CodePrinterPrecision::Half:
32  return precision;
33  default:
34  throw SymEngineException("Unknown code printer precision");
35  }
36 }
37 
38 } // namespace
39 
40 CodePrinter::CodePrinter(CodePrinterPrecision precision)
41  : precision_{normalize_code_printer_precision(precision)}
42 {
43 }
44 
45 std::string CodePrinter::print_scalar_literal(double d) const
46 {
47  return print_double(d) + print_precision_suffix(precision_);
48 }
49 
50 std::string CodePrinter::print_math_function(const std::string &name) const
51 {
52  if (precision_ == CodePrinterPrecision::Float) {
53  return name + "f";
54  }
55  return name;
56 }
57 
58 std::string CodePrinter::print_binary_reduction(const vec_basic &args,
59  const std::string &func_name)
60 {
61  if (args.size() < 2) {
62  throw SymEngineException("Impossible");
63  }
64 
65  return print_binary_reduction_impl(args.begin(), args.end(), func_name);
66 }
67 
68 std::string
69 CodePrinter::print_binary_reduction_impl(vec_basic::const_iterator begin,
70  vec_basic::const_iterator end,
71  const std::string &func_name)
72 {
73  const auto size = static_cast<std::size_t>(std::distance(begin, end));
74  if (size == 1) {
75  return apply(*begin);
76  }
77 
78  const auto mid = begin + size / 2;
79  std::ostringstream s;
80  s << func_name << "(" << print_binary_reduction_impl(begin, mid, func_name)
81  << ", " << print_binary_reduction_impl(mid, end, func_name) << ")";
82  return s.str();
83 }
84 
85 void CodePrinter::bvisit(const Basic &x)
86 {
87  throw SymEngineException("Not supported");
88 }
89 void CodePrinter::bvisit(const Complex &x)
90 {
91  throw NotImplementedError("Not implemented");
92 }
93 void CodePrinter::bvisit(const Dummy &x)
94 {
95  std::ostringstream s;
96  s << x.get_name() << '_' << x.get_index();
97  str_ = s.str();
98 }
99 void CodePrinter::bvisit(const Interval &x)
100 {
101  std::string var = str_;
102  std::ostringstream s;
103  bool is_inf = eq(*x.get_start(), *NegInf);
104  if (not is_inf) {
105  s << var;
106  if (x.get_left_open()) {
107  s << " > ";
108  } else {
109  s << " >= ";
110  }
111  s << apply(x.get_start());
112  }
113  if (neq(*x.get_end(), *Inf)) {
114  if (not is_inf) {
115  s << " && ";
116  }
117  s << var;
118  if (x.get_right_open()) {
119  s << " < ";
120  } else {
121  s << " <= ";
122  }
123  s << apply(x.get_end());
124  }
125  str_ = s.str();
126 }
127 void CodePrinter::bvisit(const Contains &x)
128 {
129  x.get_expr()->accept(*this);
130  x.get_set()->accept(*this);
131 }
132 void CodePrinter::bvisit(const Piecewise &x)
133 {
134  std::ostringstream s;
135  auto vec = x.get_vec();
136  for (size_t i = 0;; ++i) {
137  if (i == vec.size() - 1) {
138  if (neq(*vec[i].second, *boolTrue)) {
139  throw SymEngineException(
140  "Code generation requires a (Expr, True) at the end");
141  }
142  s << "(\n " << apply(vec[i].first) << "\n";
143  break;
144  } else {
145  s << "((";
146  s << apply(vec[i].second);
147  s << ") ? (\n ";
148  s << apply(vec[i].first);
149  s << "\n)\n: ";
150  }
151  }
152  for (size_t i = 0; i < vec.size(); i++) {
153  s << ")";
154  }
155  str_ = s.str();
156 }
157 void CodePrinter::bvisit(const BooleanAtom &x)
158 {
159  str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
160 }
161 void CodePrinter::bvisit(const Integer &x)
162 {
163  if (precision_ != CodePrinterPrecision::Double) {
164  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
165  } else {
166  StrPrinter::bvisit(x);
167  }
168 }
169 void CodePrinter::bvisit(const And &x)
170 {
171  std::ostringstream s;
172  const auto &container = x.get_container();
173  s << "(";
174  for (auto it = container.begin(); it != container.end(); ++it) {
175  if (it != container.begin()) {
176  s << " && ";
177  }
178  s << "(" << apply(*(*it)) << ")";
179  }
180  s << ")";
181  str_ = s.str();
182 }
183 void CodePrinter::bvisit(const Or &x)
184 {
185  std::ostringstream s;
186  const auto &container = x.get_container();
187  s << "(";
188  for (auto it = container.begin(); it != container.end(); ++it) {
189  if (it != container.begin()) {
190  s << " || ";
191  }
192  s << "(" << apply(*(*it)) << ")";
193  }
194  s << ")";
195  str_ = s.str();
196 }
197 void CodePrinter::bvisit(const Xor &x)
198 {
199  std::ostringstream s;
200  const auto &container = x.get_container();
201  s << "(";
202  for (auto it = container.begin(); it != container.end(); ++it) {
203  if (it != container.begin()) {
204  s << " != ";
205  }
206  s << "((" << apply(*(*it)) << ") != 0)";
207  }
208  s << ")";
209  str_ = s.str();
210 }
211 void CodePrinter::bvisit(const Not &x)
212 {
213  std::ostringstream s;
214  s << "!(" << apply(*x.get_arg()) << ")";
215  str_ = s.str();
216 }
217 void CodePrinter::bvisit(const Rational &x)
218 {
219  std::ostringstream o;
220  double n = mp_get_d(get_num(x.as_rational_class()));
221  double d = mp_get_d(get_den(x.as_rational_class()));
222  o << print_scalar_literal(n) << "/" << print_scalar_literal(d);
223  str_ = o.str();
224 }
225 void CodePrinter::bvisit(const Reals &x)
226 {
227  throw SymEngineException("Not supported");
228 }
229 void CodePrinter::bvisit(const Rationals &x)
230 {
231  throw SymEngineException("Not supported");
232 }
233 void CodePrinter::bvisit(const Integers &x)
234 {
235  throw SymEngineException("Not supported");
236 }
237 void CodePrinter::bvisit(const EmptySet &x)
238 {
239  throw SymEngineException("Not supported");
240 }
241 void CodePrinter::bvisit(const FiniteSet &x)
242 {
243  throw SymEngineException("Not supported");
244 }
245 void CodePrinter::bvisit(const UniversalSet &x)
246 {
247  throw SymEngineException("Not supported");
248 }
249 void CodePrinter::bvisit(const Abs &x)
250 {
251  std::ostringstream s;
252  s << print_math_function("fabs") << "(" << apply(x.get_arg()) << ")";
253  str_ = s.str();
254 }
255 void CodePrinter::bvisit(const Ceiling &x)
256 {
257  std::ostringstream s;
258  s << print_math_function("ceil") << "(" << apply(x.get_arg()) << ")";
259  str_ = s.str();
260 }
261 void CodePrinter::bvisit(const Truncate &x)
262 {
263  std::ostringstream s;
264  s << print_math_function("trunc") << "(" << apply(x.get_arg()) << ")";
265  str_ = s.str();
266 }
267 void CodePrinter::bvisit(const Max &x)
268 {
269  str_ = print_binary_reduction(x.get_args(), print_math_function("fmax"));
270 }
271 void CodePrinter::bvisit(const Min &x)
272 {
273  str_ = print_binary_reduction(x.get_args(), print_math_function("fmin"));
274 }
275 void CodePrinter::bvisit(const Constant &x)
276 {
277  if (eq(x, *E)) {
278  str_ = precision_ != CodePrinterPrecision::Double
279  ? print_math_function("exp") + "("
280  + print_scalar_literal(1.0) + ")"
281  : "exp(1)";
282  } else if (eq(x, *pi)) {
283  str_ = precision_ != CodePrinterPrecision::Double
284  ? print_math_function("acos") + "("
285  + print_scalar_literal(-1.0) + ")"
286  : "acos(-1)";
287  } else {
288  str_ = x.get_name();
289  }
290 }
291 void CodePrinter::bvisit(const NaN &x)
292 {
293  std::ostringstream s;
294  s << "NAN";
295  str_ = s.str();
296 }
297 void CodePrinter::bvisit(const Equality &x)
298 {
299  std::ostringstream s;
300  s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
301  str_ = s.str();
302 }
303 void CodePrinter::bvisit(const Unequality &x)
304 {
305  std::ostringstream s;
306  s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
307  str_ = s.str();
308 }
309 void CodePrinter::bvisit(const LessThan &x)
310 {
311  std::ostringstream s;
312  s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
313  str_ = s.str();
314 }
315 void CodePrinter::bvisit(const StrictLessThan &x)
316 {
317  std::ostringstream s;
318  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
319  str_ = s.str();
320 }
321 void CodePrinter::bvisit(const Sign &x)
322 {
323  const std::string arg = apply(x.get_arg());
324  const std::string zero = print_scalar_literal(0.0);
325  const std::string one = print_scalar_literal(1.0);
326  const std::string minus_one = print_scalar_literal(-1.0);
327  std::ostringstream s;
328  s << "((" << arg << " == " << zero << ") ? (" << zero << ") : ((" << arg
329  << " < " << zero << ") ? (" << minus_one << ") : (" << one << ")))";
330  str_ = s.str();
331 }
332 void CodePrinter::bvisit(const UnevaluatedExpr &x)
333 {
334  str_ = apply(x.get_arg());
335 }
336 void CodePrinter::bvisit(const UnivariateSeries &x)
337 {
338  throw SymEngineException("Not supported");
339 }
340 void CodePrinter::bvisit(const Derivative &x)
341 {
342  throw SymEngineException("Not supported");
343 }
344 void CodePrinter::bvisit(const Subs &x)
345 {
346  throw SymEngineException("Not supported");
347 }
348 void CodePrinter::bvisit(const GaloisField &x)
349 {
350  throw SymEngineException("Not supported");
351 }
352 
353 void CodePrinter::bvisit(const Function &x)
354 {
355  static const std::vector<std::string> names_ = init_str_printer_names();
356  std::ostringstream o;
357  o << print_math_function(names_[x.get_type_code()]);
358  vec_basic vec = x.get_args();
359  o << parenthesize(apply(vec));
360  str_ = o.str();
361 }
362 
363 void CodePrinter::bvisit(const RealDouble &x)
364 {
365  if (precision_ != CodePrinterPrecision::Double) {
366  str_ = print_scalar_literal(x.i);
367  } else {
368  StrPrinter::bvisit(x);
369  }
370 }
371 
#ifdef HAVE_SYMENGINE_MPFR
// Arbitrary-precision reals keep StrPrinter's spelling. Outside double
// precision, the precision suffix is appended to it.
void CodePrinter::bvisit(const RealMPFR &x)
{
    StrPrinter::bvisit(x);
    if (precision_ == CodePrinterPrecision::Double) {
        return;
    }
    str_ += print_precision_suffix(precision_);
}
#endif
381 
382 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
383  : BaseVisitor<C89CodePrinter, CodePrinter>(precision)
384 {
385  if (precision_ == CodePrinterPrecision::Half) {
386  throw SymEngineException(
387  "C-family code printers do not support half precision");
388  }
389 }
390 
391 void C89CodePrinter::bvisit(const Infty &x)
392 {
393  std::ostringstream s;
394  if (x.is_negative_infinity())
395  s << "-HUGE_VAL";
396  else if (x.is_positive_infinity())
397  s << "HUGE_VAL";
398  else
399  throw SymEngineException("Not supported");
400  str_ = s.str();
401 }
402 void C89CodePrinter::_print_pow(std::ostringstream &o,
403  const RCP<const Basic> &a,
404  const RCP<const Basic> &b)
405 {
406  if (eq(*a, *E)) {
407  o << print_math_function("exp") << "(" << apply(b) << ")";
408  } else if (eq(*b, *rational(1, 2))) {
409  o << print_math_function("sqrt") << "(" << apply(a) << ")";
410  } else {
411  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
412  << ")";
413  }
414 }
415 
416 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
417  : BaseVisitor<C99CodePrinter, C89CodePrinter>(precision)
418 {
419 }
420 
421 void C99CodePrinter::bvisit(const Infty &x)
422 {
423  std::ostringstream s;
424  if (x.is_negative_infinity())
425  s << "-INFINITY";
426  else if (x.is_positive_infinity())
427  s << "INFINITY";
428  else
429  throw SymEngineException("Not supported");
430  str_ = s.str();
431 }
432 void C99CodePrinter::_print_pow(std::ostringstream &o,
433  const RCP<const Basic> &a,
434  const RCP<const Basic> &b)
435 {
436  if (eq(*a, *E)) {
437  o << print_math_function("exp") << "(" << apply(b) << ")";
438  } else if (eq(*b, *rational(1, 2))) {
439  o << print_math_function("sqrt") << "(" << apply(a) << ")";
440  } else if (eq(*b, *rational(1, 3))) {
441  o << print_math_function("cbrt") << "(" << apply(a) << ")";
442  } else {
443  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
444  << ")";
445  }
446 }
447 void C99CodePrinter::bvisit(const Gamma &x)
448 {
449  std::ostringstream s;
450  s << print_math_function("tgamma") << "(" << apply(x.get_arg()) << ")";
451  str_ = s.str();
452 }
453 void C99CodePrinter::bvisit(const LogGamma &x)
454 {
455  std::ostringstream s;
456  s << print_math_function("lgamma") << "(" << apply(x.get_arg()) << ")";
457  str_ = s.str();
458 }
459 
460 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
461  : BaseVisitor<CudaCodePrinter, C99CodePrinter>(precision)
462 {
463 }
464 
465 void CudaCodePrinter::bvisit(const Integer &x)
466 {
467  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
468 }
469 void CudaCodePrinter::bvisit(const Constant &x)
470 {
471  if (eq(x, *E)) {
472  str_ = print_math_function("exp") + "(" + print_scalar_literal(1.0)
473  + ")";
474  } else if (eq(x, *pi)) {
475  str_ = print_math_function("acos") + "(" + print_scalar_literal(-1.0)
476  + ")";
477  } else {
478  str_ = x.get_name();
479  }
480 }
481 
482 void CudaCodePrinter::bvisit(const NaN &x)
483 {
484  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_NAN_F"
485  : "CUDART_NAN";
486 }
487 
488 void CudaCodePrinter::bvisit(const Infty &x)
489 {
490  if (x.is_negative_infinity())
491  str_ = precision_ == CodePrinterPrecision::Float ? "-CUDART_INF_F"
492  : "-CUDART_INF";
493  else if (x.is_positive_infinity())
494  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_INF_F"
495  : "CUDART_INF";
496  else
497  throw SymEngineException("Not supported");
498 }
499 
500 MetalCodePrinter::MetalCodePrinter(CodePrinterPrecision precision)
501  : BaseVisitor<MetalCodePrinter, CodePrinter>(precision)
502 {
503  if (precision_ == CodePrinterPrecision::Double) {
504  throw SymEngineException(
505  "Metal code printer currently only supports float and half "
506  "precision");
507  }
508 }
509 
510 void MetalCodePrinter::bvisit(const Constant &x)
511 {
512  if (eq(x, *E)) {
513  str_ = "exp(" + print_scalar_literal(1.0) + ")";
514  } else if (eq(x, *pi)) {
515  str_ = "acos(" + print_scalar_literal(-1.0) + ")";
516  } else {
517  str_ = x.get_name();
518  }
519 }
520 
521 void MetalCodePrinter::bvisit(const NaN &x)
522 {
523  if (precision_ == CodePrinterPrecision::Half) {
524  str_ = "half(NAN)";
525  } else {
526  str_ = "NAN";
527  }
528 }
529 
530 void MetalCodePrinter::bvisit(const Infty &x)
531 {
532  if (x.is_negative_infinity()) {
533  str_ = precision_ == CodePrinterPrecision::Half ? "-HUGE_VALH"
534  : "-INFINITY";
535  } else if (x.is_positive_infinity()) {
536  str_ = precision_ == CodePrinterPrecision::Half ? "HUGE_VALH"
537  : "INFINITY";
538  } else {
539  throw SymEngineException("Not supported");
540  }
541 }
542 
543 void MetalCodePrinter::bvisit(const Abs &x)
544 {
545  std::ostringstream s;
546  s << "fabs(" << apply(x.get_arg()) << ")";
547  str_ = s.str();
548 }
549 
550 void MetalCodePrinter::bvisit(const Ceiling &x)
551 {
552  std::ostringstream s;
553  s << "ceil(" << apply(x.get_arg()) << ")";
554  str_ = s.str();
555 }
556 
557 void MetalCodePrinter::bvisit(const Truncate &x)
558 {
559  std::ostringstream s;
560  s << "trunc(" << apply(x.get_arg()) << ")";
561  str_ = s.str();
562 }
563 
564 void MetalCodePrinter::bvisit(const Max &x)
565 {
566  str_ = print_binary_reduction(x.get_args(), "fmax");
567 }
568 
569 void MetalCodePrinter::bvisit(const Min &x)
570 {
571  str_ = print_binary_reduction(x.get_args(), "fmin");
572 }
573 
574 void MetalCodePrinter::bvisit(const Function &x)
575 {
576  static const std::vector<std::string> names_ = init_str_printer_names();
577  std::ostringstream o;
578  o << names_[x.get_type_code()];
579  vec_basic vec = x.get_args();
580  o << parenthesize(apply(vec));
581  str_ = o.str();
582 }
583 
584 void MetalCodePrinter::_print_pow(std::ostringstream &o,
585  const RCP<const Basic> &a,
586  const RCP<const Basic> &b)
587 {
588  if (eq(*a, *E)) {
589  o << "exp(" << apply(b) << ")";
590  } else if (eq(*b, *rational(1, 2))) {
591  o << "sqrt(" << apply(a) << ")";
592  } else {
593  o << "pow(" << apply(a) << ", " << apply(b) << ")";
594  }
595 }
596 
597 void JSCodePrinter::bvisit(const Constant &x)
598 {
599  if (eq(x, *E)) {
600  str_ = "Math.E";
601  } else if (eq(x, *pi)) {
602  str_ = "Math.PI";
603  } else {
604  str_ = x.get_name();
605  }
606 }
607 void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
608  const RCP<const Basic> &b)
609 {
610  if (eq(*a, *E)) {
611  o << "Math.exp(" << apply(b) << ")";
612  } else if (eq(*b, *rational(1, 2))) {
613  o << "Math.sqrt(" << apply(a) << ")";
614  } else if (eq(*b, *rational(1, 3))) {
615  o << "Math.cbrt(" << apply(a) << ")";
616  } else {
617  o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
618  }
619 }
620 void JSCodePrinter::bvisit(const Abs &x)
621 {
622  std::ostringstream s;
623  s << "Math.abs(" << apply(x.get_arg()) << ")";
624  str_ = s.str();
625 }
626 void JSCodePrinter::bvisit(const Sin &x)
627 {
628  std::ostringstream s;
629  s << "Math.sin(" << apply(x.get_arg()) << ")";
630  str_ = s.str();
631 }
632 void JSCodePrinter::bvisit(const Cos &x)
633 {
634  std::ostringstream s;
635  s << "Math.cos(" << apply(x.get_arg()) << ")";
636  str_ = s.str();
637 }
638 void JSCodePrinter::bvisit(const Max &x)
639 {
640  const auto &args = x.get_args();
641  std::ostringstream s;
642  s << "Math.max(";
643  for (size_t i = 0; i < args.size(); ++i) {
644  s << apply(args[i]);
645  s << ((i == args.size() - 1) ? ")" : ", ");
646  }
647  str_ = s.str();
648 }
649 void JSCodePrinter::bvisit(const Min &x)
650 {
651  const auto &args = x.get_args();
652  std::ostringstream s;
653  s << "Math.min(";
654  for (size_t i = 0; i < args.size(); ++i) {
655  s << apply(args[i]);
656  s << ((i == args.size() - 1) ? ")" : ", ");
657  }
658  str_ = s.str();
659 }
660 
661 std::string ccode(const Basic &x, CodePrinterPrecision precision)
662 {
663  C99CodePrinter c(precision);
664  return c.apply(x);
665 }
666 
667 std::string cudacode(const Basic &x, CodePrinterPrecision precision)
668 {
669  CudaCodePrinter p(precision);
670  return p.apply(x);
671 }
672 
673 std::string metalcode(const Basic &x, CodePrinterPrecision precision)
674 {
675  MetalCodePrinter p(precision);
676  return p.apply(x);
677 }
678 
679 std::string jscode(const Basic &x)
680 {
681  JSCodePrinter p;
682  return p.apply(x);
683 }
684 
685 std::string inline c89code(const Basic &x)
686 {
687  C89CodePrinter p;
688  return p.apply(x);
689 }
690 
691 std::string inline c99code(const Basic &x)
692 {
693  C99CodePrinter p;
694  return p.apply(x);
695 }
696 
697 } // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328