codegen.cpp
1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
4 
5 namespace SymEngine
6 {
7 
8 void CodePrinter::bvisit(const Basic &x)
9 {
10  throw SymEngineException("Not supported");
11 }
12 void CodePrinter::bvisit(const Complex &x)
13 {
14  throw NotImplementedError("Not implemented");
15 }
16 void CodePrinter::bvisit(const Interval &x)
17 {
18  std::string var = str_;
20  bool is_inf = eq(*x.get_start(), *NegInf);
21  if (not is_inf) {
22  s << var;
23  if (x.get_left_open()) {
24  s << " > ";
25  } else {
26  s << " >= ";
27  }
28  s << apply(x.get_start());
29  }
30  if (neq(*x.get_end(), *Inf)) {
31  if (not is_inf) {
32  s << " && ";
33  }
34  s << var;
35  if (x.get_right_open()) {
36  s << " < ";
37  } else {
38  s << " <= ";
39  }
40  s << apply(x.get_end());
41  }
42  str_ = s.str();
43 }
44 void CodePrinter::bvisit(const Contains &x)
45 {
46  x.get_expr()->accept(*this);
47  x.get_set()->accept(*this);
48 }
49 void CodePrinter::bvisit(const Piecewise &x)
50 {
52  auto vec = x.get_vec();
53  for (size_t i = 0;; ++i) {
54  if (i == vec.size() - 1) {
55  if (neq(*vec[i].second, *boolTrue)) {
56  throw SymEngineException(
57  "Code generation requires a (Expr, True) at the end");
58  }
59  s << "(\n " << apply(vec[i].first) << "\n";
60  break;
61  } else {
62  s << "((";
63  s << apply(vec[i].second);
64  s << ") ? (\n ";
65  s << apply(vec[i].first);
66  s << "\n)\n: ";
67  }
68  }
69  for (size_t i = 0; i < vec.size(); i++) {
70  s << ")";
71  }
72  str_ = s.str();
73 }
74 void CodePrinter::bvisit(const Rational &x)
75 {
77  double n = mp_get_d(get_num(x.as_rational_class()));
78  double d = mp_get_d(get_den(x.as_rational_class()));
79  o << print_double(n) << "/" << print_double(d);
80  str_ = o.str();
81 }
82 void CodePrinter::bvisit(const Reals &x)
83 {
84  throw SymEngineException("Not supported");
85 }
86 void CodePrinter::bvisit(const Rationals &x)
87 {
88  throw SymEngineException("Not supported");
89 }
90 void CodePrinter::bvisit(const Integers &x)
91 {
92  throw SymEngineException("Not supported");
93 }
94 void CodePrinter::bvisit(const EmptySet &x)
95 {
96  throw SymEngineException("Not supported");
97 }
98 void CodePrinter::bvisit(const FiniteSet &x)
99 {
100  throw SymEngineException("Not supported");
101 }
102 void CodePrinter::bvisit(const UniversalSet &x)
103 {
104  throw SymEngineException("Not supported");
105 }
106 void CodePrinter::bvisit(const Abs &x)
107 {
109  s << "fabs(" << apply(x.get_arg()) << ")";
110  str_ = s.str();
111 }
112 void CodePrinter::bvisit(const Ceiling &x)
113 {
115  s << "ceil(" << apply(x.get_arg()) << ")";
116  str_ = s.str();
117 }
118 void CodePrinter::bvisit(const Truncate &x)
119 {
121  s << "trunc(" << apply(x.get_arg()) << ")";
122  str_ = s.str();
123 }
124 void CodePrinter::bvisit(const Max &x)
125 {
127  const auto &args = x.get_args();
128  switch (args.size()) {
129  case 0:
130  case 1:
131  throw SymEngineException("Impossible");
132  case 2:
133  s << "fmax(" << apply(args[0]) << ", " << apply(args[1]) << ")";
134  break;
135  default: {
136  vec_basic inner_args(args.begin() + 1, args.end());
137  auto inner = max(inner_args);
138  s << "fmax(" << apply(args[0]) << ", " << apply(inner) << ")";
139  break;
140  }
141  }
142  str_ = s.str();
143 }
144 void CodePrinter::bvisit(const Min &x)
145 {
147  const auto &args = x.get_args();
148  switch (args.size()) {
149  case 0:
150  case 1:
151  throw SymEngineException("Impossible");
152  case 2:
153  s << "fmin(" << apply(args[0]) << ", " << apply(args[1]) << ")";
154  break;
155  default: {
156  vec_basic inner_args(args.begin() + 1, args.end());
157  auto inner = min(inner_args);
158  s << "fmin(" << apply(args[0]) << ", " << apply(inner) << ")";
159  break;
160  }
161  }
162  str_ = s.str();
163 }
164 void CodePrinter::bvisit(const Constant &x)
165 {
166  if (eq(x, *E)) {
167  str_ = "exp(1)";
168  } else if (eq(x, *pi)) {
169  str_ = "acos(-1)";
170  } else {
171  str_ = x.get_name();
172  }
173 }
174 void CodePrinter::bvisit(const NaN &x)
175 {
177  s << "NAN";
178  str_ = s.str();
179 }
180 void CodePrinter::bvisit(const Equality &x)
181 {
183  s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
184  str_ = s.str();
185 }
186 void CodePrinter::bvisit(const Unequality &x)
187 {
189  s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
190  str_ = s.str();
191 }
192 void CodePrinter::bvisit(const LessThan &x)
193 {
195  s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
196  str_ = s.str();
197 }
198 void CodePrinter::bvisit(const StrictLessThan &x)
199 {
201  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
202  str_ = s.str();
203 }
204 void CodePrinter::bvisit(const UnivariateSeries &x)
205 {
206  throw SymEngineException("Not supported");
207 }
208 void CodePrinter::bvisit(const Derivative &x)
209 {
210  throw SymEngineException("Not supported");
211 }
212 void CodePrinter::bvisit(const Subs &x)
213 {
214  throw SymEngineException("Not supported");
215 }
216 void CodePrinter::bvisit(const GaloisField &x)
217 {
218  throw SymEngineException("Not supported");
219 }
220 
221 void C89CodePrinter::bvisit(const Infty &x)
222 {
224  if (x.is_negative_infinity())
225  s << "-HUGE_VAL";
226  else if (x.is_positive_infinity())
227  s << "HUGE_VAL";
228  else
229  throw SymEngineException("Not supported");
230  str_ = s.str();
231 }
232 void C89CodePrinter::_print_pow(std::ostringstream &o,
233  const RCP<const Basic> &a,
234  const RCP<const Basic> &b)
235 {
236  if (eq(*a, *E)) {
237  o << "exp(" << apply(b) << ")";
238  } else if (eq(*b, *rational(1, 2))) {
239  o << "sqrt(" << apply(a) << ")";
240  } else {
241  o << "pow(" << apply(a) << ", " << apply(b) << ")";
242  }
243 }
244 
245 void C99CodePrinter::bvisit(const Infty &x)
246 {
248  if (x.is_negative_infinity())
249  s << "-INFINITY";
250  else if (x.is_positive_infinity())
251  s << "INFINITY";
252  else
253  throw SymEngineException("Not supported");
254  str_ = s.str();
255 }
256 void C99CodePrinter::_print_pow(std::ostringstream &o,
257  const RCP<const Basic> &a,
258  const RCP<const Basic> &b)
259 {
260  if (eq(*a, *E)) {
261  o << "exp(" << apply(b) << ")";
262  } else if (eq(*b, *rational(1, 2))) {
263  o << "sqrt(" << apply(a) << ")";
264  } else if (eq(*b, *rational(1, 3))) {
265  o << "cbrt(" << apply(a) << ")";
266  } else {
267  o << "pow(" << apply(a) << ", " << apply(b) << ")";
268  }
269 }
270 void C99CodePrinter::bvisit(const Gamma &x)
271 {
273  s << "tgamma(" << apply(x.get_arg()) << ")";
274  str_ = s.str();
275 }
276 void C99CodePrinter::bvisit(const LogGamma &x)
277 {
279  s << "lgamma(" << apply(x.get_arg()) << ")";
280  str_ = s.str();
281 }
282 
283 void JSCodePrinter::bvisit(const Constant &x)
284 {
285  if (eq(x, *E)) {
286  str_ = "Math.E";
287  } else if (eq(x, *pi)) {
288  str_ = "Math.PI";
289  } else {
290  str_ = x.get_name();
291  }
292 }
293 void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
294  const RCP<const Basic> &b)
295 {
296  if (eq(*a, *E)) {
297  o << "Math.exp(" << apply(b) << ")";
298  } else if (eq(*b, *rational(1, 2))) {
299  o << "Math.sqrt(" << apply(a) << ")";
300  } else if (eq(*b, *rational(1, 3))) {
301  o << "Math.cbrt(" << apply(a) << ")";
302  } else {
303  o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
304  }
305 }
306 void JSCodePrinter::bvisit(const Abs &x)
307 {
309  s << "Math.abs(" << apply(x.get_arg()) << ")";
310  str_ = s.str();
311 }
312 void JSCodePrinter::bvisit(const Sin &x)
313 {
315  s << "Math.sin(" << apply(x.get_arg()) << ")";
316  str_ = s.str();
317 }
318 void JSCodePrinter::bvisit(const Cos &x)
319 {
321  s << "Math.cos(" << apply(x.get_arg()) << ")";
322  str_ = s.str();
323 }
324 void JSCodePrinter::bvisit(const Max &x)
325 {
326  const auto &args = x.get_args();
328  s << "Math.max(";
329  for (size_t i = 0; i < args.size(); ++i) {
330  s << apply(args[i]);
331  s << ((i == args.size() - 1) ? ")" : ", ");
332  }
333  str_ = s.str();
334 }
335 void JSCodePrinter::bvisit(const Min &x)
336 {
337  const auto &args = x.get_args();
339  s << "Math.min(";
340  for (size_t i = 0; i < args.size(); ++i) {
341  s << apply(args[i]);
342  s << ((i == args.size() - 1) ? ")" : ", ");
343  }
344  str_ = s.str();
345 }
346 
347 std::string ccode(const Basic &x)
348 {
349  C99CodePrinter c;
350  return c.apply(x);
351 }
352 
353 std::string jscode(const Basic &x)
354 {
355  JSCodePrinter p;
356  return p.apply(x);
357 }
358 
359 std::string inline c89code(const Basic &x)
360 {
361  C89CodePrinter p;
362  return p.apply(x);
363 }
364 
365 std::string inline c99code(const Basic &x)
366 {
367  C99CodePrinter p;
368  return p.apply(x);
369 }
370 
371 } // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
Definition: functions.cpp:3555
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
Definition: functions.cpp:3659
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
T str(T... args)