codegen.cpp
1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
4 
5 namespace SymEngine
6 {
7 
8 namespace
9 {
10 
11 std::string print_float_literal(double d)
12 {
13  return print_double(d) + "f";
14 }
15 
16 CodePrinterPrecision
17 normalize_code_printer_precision(CodePrinterPrecision precision)
18 {
19  switch (precision) {
20  case CodePrinterPrecision::Double:
21  case CodePrinterPrecision::Float:
22  return precision;
23  default:
24  throw SymEngineException("Unknown code printer precision");
25  }
26 }
27 
28 } // namespace
29 
30 CodePrinter::CodePrinter(CodePrinterPrecision precision)
31  : precision_{normalize_code_printer_precision(precision)}
32 {
33 }
34 
35 std::string CodePrinter::print_scalar_literal(double d) const
36 {
37  if (precision_ == CodePrinterPrecision::Float) {
38  return print_float_literal(d);
39  }
40  return print_double(d);
41 }
42 
43 std::string CodePrinter::print_math_function(const std::string &name) const
44 {
45  if (precision_ == CodePrinterPrecision::Float) {
46  return name + "f";
47  }
48  return name;
49 }
50 
51 void CodePrinter::bvisit(const Basic &x)
52 {
53  throw SymEngineException("Not supported");
54 }
55 void CodePrinter::bvisit(const Complex &x)
56 {
57  throw NotImplementedError("Not implemented");
58 }
59 void CodePrinter::bvisit(const Dummy &x)
60 {
61  std::ostringstream s;
62  s << x.get_name() << '_' << x.get_index();
63  str_ = s.str();
64 }
65 void CodePrinter::bvisit(const Interval &x)
66 {
67  std::string var = str_;
68  std::ostringstream s;
69  bool is_inf = eq(*x.get_start(), *NegInf);
70  if (not is_inf) {
71  s << var;
72  if (x.get_left_open()) {
73  s << " > ";
74  } else {
75  s << " >= ";
76  }
77  s << apply(x.get_start());
78  }
79  if (neq(*x.get_end(), *Inf)) {
80  if (not is_inf) {
81  s << " && ";
82  }
83  s << var;
84  if (x.get_right_open()) {
85  s << " < ";
86  } else {
87  s << " <= ";
88  }
89  s << apply(x.get_end());
90  }
91  str_ = s.str();
92 }
93 void CodePrinter::bvisit(const Contains &x)
94 {
95  x.get_expr()->accept(*this);
96  x.get_set()->accept(*this);
97 }
98 void CodePrinter::bvisit(const Piecewise &x)
99 {
100  std::ostringstream s;
101  auto vec = x.get_vec();
102  for (size_t i = 0;; ++i) {
103  if (i == vec.size() - 1) {
104  if (neq(*vec[i].second, *boolTrue)) {
105  throw SymEngineException(
106  "Code generation requires a (Expr, True) at the end");
107  }
108  s << "(\n " << apply(vec[i].first) << "\n";
109  break;
110  } else {
111  s << "((";
112  s << apply(vec[i].second);
113  s << ") ? (\n ";
114  s << apply(vec[i].first);
115  s << "\n)\n: ";
116  }
117  }
118  for (size_t i = 0; i < vec.size(); i++) {
119  s << ")";
120  }
121  str_ = s.str();
122 }
123 void CodePrinter::bvisit(const BooleanAtom &x)
124 {
125  str_ = print_scalar_literal(x.get_val() ? 1.0 : 0.0);
126 }
127 void CodePrinter::bvisit(const Integer &x)
128 {
129  if (precision_ == CodePrinterPrecision::Float) {
130  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
131  } else {
132  StrPrinter::bvisit(x);
133  }
134 }
135 void CodePrinter::bvisit(const And &x)
136 {
137  std::ostringstream s;
138  const auto &container = x.get_container();
139  s << "(";
140  for (auto it = container.begin(); it != container.end(); ++it) {
141  if (it != container.begin()) {
142  s << " && ";
143  }
144  s << "(" << apply(*(*it)) << ")";
145  }
146  s << ")";
147  str_ = s.str();
148 }
149 void CodePrinter::bvisit(const Or &x)
150 {
151  std::ostringstream s;
152  const auto &container = x.get_container();
153  s << "(";
154  for (auto it = container.begin(); it != container.end(); ++it) {
155  if (it != container.begin()) {
156  s << " || ";
157  }
158  s << "(" << apply(*(*it)) << ")";
159  }
160  s << ")";
161  str_ = s.str();
162 }
163 void CodePrinter::bvisit(const Xor &x)
164 {
165  std::ostringstream s;
166  const auto &container = x.get_container();
167  s << "(";
168  for (auto it = container.begin(); it != container.end(); ++it) {
169  if (it != container.begin()) {
170  s << " != ";
171  }
172  s << "((" << apply(*(*it)) << ") != 0)";
173  }
174  s << ")";
175  str_ = s.str();
176 }
177 void CodePrinter::bvisit(const Not &x)
178 {
179  std::ostringstream s;
180  s << "!(" << apply(*x.get_arg()) << ")";
181  str_ = s.str();
182 }
183 void CodePrinter::bvisit(const Rational &x)
184 {
185  std::ostringstream o;
186  double n = mp_get_d(get_num(x.as_rational_class()));
187  double d = mp_get_d(get_den(x.as_rational_class()));
188  o << print_scalar_literal(n) << "/" << print_scalar_literal(d);
189  str_ = o.str();
190 }
191 void CodePrinter::bvisit(const Reals &x)
192 {
193  throw SymEngineException("Not supported");
194 }
195 void CodePrinter::bvisit(const Rationals &x)
196 {
197  throw SymEngineException("Not supported");
198 }
199 void CodePrinter::bvisit(const Integers &x)
200 {
201  throw SymEngineException("Not supported");
202 }
203 void CodePrinter::bvisit(const EmptySet &x)
204 {
205  throw SymEngineException("Not supported");
206 }
207 void CodePrinter::bvisit(const FiniteSet &x)
208 {
209  throw SymEngineException("Not supported");
210 }
211 void CodePrinter::bvisit(const UniversalSet &x)
212 {
213  throw SymEngineException("Not supported");
214 }
215 void CodePrinter::bvisit(const Abs &x)
216 {
217  std::ostringstream s;
218  s << print_math_function("fabs") << "(" << apply(x.get_arg()) << ")";
219  str_ = s.str();
220 }
221 void CodePrinter::bvisit(const Ceiling &x)
222 {
223  std::ostringstream s;
224  s << print_math_function("ceil") << "(" << apply(x.get_arg()) << ")";
225  str_ = s.str();
226 }
227 void CodePrinter::bvisit(const Truncate &x)
228 {
229  std::ostringstream s;
230  s << print_math_function("trunc") << "(" << apply(x.get_arg()) << ")";
231  str_ = s.str();
232 }
233 void CodePrinter::bvisit(const Max &x)
234 {
235  std::ostringstream s;
236  const auto &args = x.get_args();
237  switch (args.size()) {
238  case 0:
239  case 1:
240  throw SymEngineException("Impossible");
241  case 2:
242  s << print_math_function("fmax") << "(" << apply(args[0]) << ", "
243  << apply(args[1]) << ")";
244  break;
245  default: {
246  vec_basic inner_args(args.begin() + 1, args.end());
247  auto inner = max(inner_args);
248  s << print_math_function("fmax") << "(" << apply(args[0]) << ", "
249  << apply(inner) << ")";
250  break;
251  }
252  }
253  str_ = s.str();
254 }
255 void CodePrinter::bvisit(const Min &x)
256 {
257  std::ostringstream s;
258  const auto &args = x.get_args();
259  switch (args.size()) {
260  case 0:
261  case 1:
262  throw SymEngineException("Impossible");
263  case 2:
264  s << print_math_function("fmin") << "(" << apply(args[0]) << ", "
265  << apply(args[1]) << ")";
266  break;
267  default: {
268  vec_basic inner_args(args.begin() + 1, args.end());
269  auto inner = min(inner_args);
270  s << print_math_function("fmin") << "(" << apply(args[0]) << ", "
271  << apply(inner) << ")";
272  break;
273  }
274  }
275  str_ = s.str();
276 }
277 void CodePrinter::bvisit(const Constant &x)
278 {
279  if (eq(x, *E)) {
280  str_ = precision_ == CodePrinterPrecision::Float
281  ? print_math_function("exp") + "("
282  + print_scalar_literal(1.0) + ")"
283  : "exp(1)";
284  } else if (eq(x, *pi)) {
285  str_ = precision_ == CodePrinterPrecision::Float
286  ? print_math_function("acos") + "("
287  + print_scalar_literal(-1.0) + ")"
288  : "acos(-1)";
289  } else {
290  str_ = x.get_name();
291  }
292 }
293 void CodePrinter::bvisit(const NaN &x)
294 {
295  std::ostringstream s;
296  s << "NAN";
297  str_ = s.str();
298 }
299 void CodePrinter::bvisit(const Equality &x)
300 {
301  std::ostringstream s;
302  s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
303  str_ = s.str();
304 }
305 void CodePrinter::bvisit(const Unequality &x)
306 {
307  std::ostringstream s;
308  s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
309  str_ = s.str();
310 }
311 void CodePrinter::bvisit(const LessThan &x)
312 {
313  std::ostringstream s;
314  s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
315  str_ = s.str();
316 }
317 void CodePrinter::bvisit(const StrictLessThan &x)
318 {
319  std::ostringstream s;
320  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
321  str_ = s.str();
322 }
323 void CodePrinter::bvisit(const Sign &x)
324 {
325  const std::string arg = apply(x.get_arg());
326  const std::string zero = print_scalar_literal(0.0);
327  const std::string one = print_scalar_literal(1.0);
328  const std::string minus_one = print_scalar_literal(-1.0);
329  std::ostringstream s;
330  s << "((" << arg << " == " << zero << ") ? (" << zero << ") : ((" << arg
331  << " < " << zero << ") ? (" << minus_one << ") : (" << one << ")))";
332  str_ = s.str();
333 }
334 void CodePrinter::bvisit(const UnevaluatedExpr &x)
335 {
336  str_ = apply(x.get_arg());
337 }
338 void CodePrinter::bvisit(const UnivariateSeries &x)
339 {
340  throw SymEngineException("Not supported");
341 }
342 void CodePrinter::bvisit(const Derivative &x)
343 {
344  throw SymEngineException("Not supported");
345 }
346 void CodePrinter::bvisit(const Subs &x)
347 {
348  throw SymEngineException("Not supported");
349 }
350 void CodePrinter::bvisit(const GaloisField &x)
351 {
352  throw SymEngineException("Not supported");
353 }
354 
355 void CodePrinter::bvisit(const Function &x)
356 {
357  static const std::vector<std::string> names_ = init_str_printer_names();
358  std::ostringstream o;
359  o << print_math_function(names_[x.get_type_code()]);
360  vec_basic vec = x.get_args();
361  o << parenthesize(apply(vec));
362  str_ = o.str();
363 }
364 
365 void CodePrinter::bvisit(const RealDouble &x)
366 {
367  if (precision_ == CodePrinterPrecision::Float) {
368  str_ = print_scalar_literal(x.i);
369  } else {
370  StrPrinter::bvisit(x);
371  }
372 }
373 
#ifdef HAVE_SYMENGINE_MPFR
//! Reuse StrPrinter's decimal rendering of the MPFR value. In Float mode,
//! append `f` to turn it into a float literal.
void CodePrinter::bvisit(const RealMPFR &x)
{
    StrPrinter::bvisit(x);
    if (precision_ != CodePrinterPrecision::Float) {
        return;
    }
    str_ += 'f';
}
#endif
383 
384 C89CodePrinter::C89CodePrinter(CodePrinterPrecision precision)
385  : BaseVisitor<C89CodePrinter, CodePrinter>(precision)
386 {
387 }
388 
389 void C89CodePrinter::bvisit(const Infty &x)
390 {
391  std::ostringstream s;
392  if (x.is_negative_infinity())
393  s << "-HUGE_VAL";
394  else if (x.is_positive_infinity())
395  s << "HUGE_VAL";
396  else
397  throw SymEngineException("Not supported");
398  str_ = s.str();
399 }
400 void C89CodePrinter::_print_pow(std::ostringstream &o,
401  const RCP<const Basic> &a,
402  const RCP<const Basic> &b)
403 {
404  if (eq(*a, *E)) {
405  o << print_math_function("exp") << "(" << apply(b) << ")";
406  } else if (eq(*b, *rational(1, 2))) {
407  o << print_math_function("sqrt") << "(" << apply(a) << ")";
408  } else {
409  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
410  << ")";
411  }
412 }
413 
414 C99CodePrinter::C99CodePrinter(CodePrinterPrecision precision)
415  : BaseVisitor<C99CodePrinter, C89CodePrinter>(precision)
416 {
417 }
418 
419 void C99CodePrinter::bvisit(const Infty &x)
420 {
421  std::ostringstream s;
422  if (x.is_negative_infinity())
423  s << "-INFINITY";
424  else if (x.is_positive_infinity())
425  s << "INFINITY";
426  else
427  throw SymEngineException("Not supported");
428  str_ = s.str();
429 }
430 void C99CodePrinter::_print_pow(std::ostringstream &o,
431  const RCP<const Basic> &a,
432  const RCP<const Basic> &b)
433 {
434  if (eq(*a, *E)) {
435  o << print_math_function("exp") << "(" << apply(b) << ")";
436  } else if (eq(*b, *rational(1, 2))) {
437  o << print_math_function("sqrt") << "(" << apply(a) << ")";
438  } else if (eq(*b, *rational(1, 3))) {
439  o << print_math_function("cbrt") << "(" << apply(a) << ")";
440  } else {
441  o << print_math_function("pow") << "(" << apply(a) << ", " << apply(b)
442  << ")";
443  }
444 }
445 void C99CodePrinter::bvisit(const Gamma &x)
446 {
447  std::ostringstream s;
448  s << print_math_function("tgamma") << "(" << apply(x.get_arg()) << ")";
449  str_ = s.str();
450 }
451 void C99CodePrinter::bvisit(const LogGamma &x)
452 {
453  std::ostringstream s;
454  s << print_math_function("lgamma") << "(" << apply(x.get_arg()) << ")";
455  str_ = s.str();
456 }
457 
458 CudaCodePrinter::CudaCodePrinter(CodePrinterPrecision precision)
459  : BaseVisitor<CudaCodePrinter, C99CodePrinter>(precision)
460 {
461 }
462 
463 void CudaCodePrinter::bvisit(const Integer &x)
464 {
465  str_ = print_scalar_literal(mp_get_d(x.as_integer_class()));
466 }
467 void CudaCodePrinter::bvisit(const Constant &x)
468 {
469  if (eq(x, *E)) {
470  str_ = print_math_function("exp") + "(" + print_scalar_literal(1.0)
471  + ")";
472  } else if (eq(x, *pi)) {
473  str_ = print_math_function("acos") + "(" + print_scalar_literal(-1.0)
474  + ")";
475  } else {
476  str_ = x.get_name();
477  }
478 }
479 
480 void CudaCodePrinter::bvisit(const NaN &x)
481 {
482  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_NAN_F"
483  : "CUDART_NAN";
484 }
485 
486 void CudaCodePrinter::bvisit(const Infty &x)
487 {
488  if (x.is_negative_infinity())
489  str_ = precision_ == CodePrinterPrecision::Float ? "-CUDART_INF_F"
490  : "-CUDART_INF";
491  else if (x.is_positive_infinity())
492  str_ = precision_ == CodePrinterPrecision::Float ? "CUDART_INF_F"
493  : "CUDART_INF";
494  else
495  throw SymEngineException("Not supported");
496 }
497 
498 void JSCodePrinter::bvisit(const Constant &x)
499 {
500  if (eq(x, *E)) {
501  str_ = "Math.E";
502  } else if (eq(x, *pi)) {
503  str_ = "Math.PI";
504  } else {
505  str_ = x.get_name();
506  }
507 }
508 void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
509  const RCP<const Basic> &b)
510 {
511  if (eq(*a, *E)) {
512  o << "Math.exp(" << apply(b) << ")";
513  } else if (eq(*b, *rational(1, 2))) {
514  o << "Math.sqrt(" << apply(a) << ")";
515  } else if (eq(*b, *rational(1, 3))) {
516  o << "Math.cbrt(" << apply(a) << ")";
517  } else {
518  o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
519  }
520 }
521 void JSCodePrinter::bvisit(const Abs &x)
522 {
523  std::ostringstream s;
524  s << "Math.abs(" << apply(x.get_arg()) << ")";
525  str_ = s.str();
526 }
527 void JSCodePrinter::bvisit(const Sin &x)
528 {
529  std::ostringstream s;
530  s << "Math.sin(" << apply(x.get_arg()) << ")";
531  str_ = s.str();
532 }
533 void JSCodePrinter::bvisit(const Cos &x)
534 {
535  std::ostringstream s;
536  s << "Math.cos(" << apply(x.get_arg()) << ")";
537  str_ = s.str();
538 }
539 void JSCodePrinter::bvisit(const Max &x)
540 {
541  const auto &args = x.get_args();
542  std::ostringstream s;
543  s << "Math.max(";
544  for (size_t i = 0; i < args.size(); ++i) {
545  s << apply(args[i]);
546  s << ((i == args.size() - 1) ? ")" : ", ");
547  }
548  str_ = s.str();
549 }
550 void JSCodePrinter::bvisit(const Min &x)
551 {
552  const auto &args = x.get_args();
553  std::ostringstream s;
554  s << "Math.min(";
555  for (size_t i = 0; i < args.size(); ++i) {
556  s << apply(args[i]);
557  s << ((i == args.size() - 1) ? ")" : ", ");
558  }
559  str_ = s.str();
560 }
561 
562 std::string ccode(const Basic &x, CodePrinterPrecision precision)
563 {
564  C99CodePrinter c(precision);
565  return c.apply(x);
566 }
567 
568 std::string cudacode(const Basic &x, CodePrinterPrecision precision)
569 {
570  CudaCodePrinter p(precision);
571  return p.apply(x);
572 }
573 
574 std::string jscode(const Basic &x)
575 {
576  JSCodePrinter p;
577  return p.apply(x);
578 }
579 
580 std::string inline c89code(const Basic &x)
581 {
582  C89CodePrinter p;
583  return p.apply(x);
584 }
585 
586 std::string inline c99code(const Basic &x)
587 {
588  C99CodePrinter p;
589  return p.apply(x);
590 }
591 
592 } // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
Definition: functions.cpp:3555
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
Definition: functions.cpp:3659
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328