codegen.cpp
1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
4 
5 namespace SymEngine
6 {
7 
8 void CodePrinter::bvisit(const Basic &x)
9 {
10  throw SymEngineException("Not supported");
11 }
12 void CodePrinter::bvisit(const Complex &x)
13 {
14  throw NotImplementedError("Not implemented");
15 }
16 void CodePrinter::bvisit(const Dummy &x)
17 {
18  std::ostringstream s;
19  s << x.get_name() << '_' << x.get_index();
20  str_ = s.str();
21 }
22 void CodePrinter::bvisit(const Interval &x)
23 {
24  std::string var = str_;
25  std::ostringstream s;
26  bool is_inf = eq(*x.get_start(), *NegInf);
27  if (not is_inf) {
28  s << var;
29  if (x.get_left_open()) {
30  s << " > ";
31  } else {
32  s << " >= ";
33  }
34  s << apply(x.get_start());
35  }
36  if (neq(*x.get_end(), *Inf)) {
37  if (not is_inf) {
38  s << " && ";
39  }
40  s << var;
41  if (x.get_right_open()) {
42  s << " < ";
43  } else {
44  s << " <= ";
45  }
46  s << apply(x.get_end());
47  }
48  str_ = s.str();
49 }
50 void CodePrinter::bvisit(const Contains &x)
51 {
52  x.get_expr()->accept(*this);
53  x.get_set()->accept(*this);
54 }
55 void CodePrinter::bvisit(const Piecewise &x)
56 {
57  std::ostringstream s;
58  auto vec = x.get_vec();
59  for (size_t i = 0;; ++i) {
60  if (i == vec.size() - 1) {
61  if (neq(*vec[i].second, *boolTrue)) {
62  throw SymEngineException(
63  "Code generation requires a (Expr, True) at the end");
64  }
65  s << "(\n " << apply(vec[i].first) << "\n";
66  break;
67  } else {
68  s << "((";
69  s << apply(vec[i].second);
70  s << ") ? (\n ";
71  s << apply(vec[i].first);
72  s << "\n)\n: ";
73  }
74  }
75  for (size_t i = 0; i < vec.size(); i++) {
76  s << ")";
77  }
78  str_ = s.str();
79 }
80 void CodePrinter::bvisit(const Rational &x)
81 {
82  std::ostringstream o;
83  double n = mp_get_d(get_num(x.as_rational_class()));
84  double d = mp_get_d(get_den(x.as_rational_class()));
85  o << print_double(n) << "/" << print_double(d);
86  str_ = o.str();
87 }
88 void CodePrinter::bvisit(const Reals &x)
89 {
90  throw SymEngineException("Not supported");
91 }
92 void CodePrinter::bvisit(const Rationals &x)
93 {
94  throw SymEngineException("Not supported");
95 }
96 void CodePrinter::bvisit(const Integers &x)
97 {
98  throw SymEngineException("Not supported");
99 }
100 void CodePrinter::bvisit(const EmptySet &x)
101 {
102  throw SymEngineException("Not supported");
103 }
104 void CodePrinter::bvisit(const FiniteSet &x)
105 {
106  throw SymEngineException("Not supported");
107 }
108 void CodePrinter::bvisit(const UniversalSet &x)
109 {
110  throw SymEngineException("Not supported");
111 }
112 void CodePrinter::bvisit(const Abs &x)
113 {
114  std::ostringstream s;
115  s << "fabs(" << apply(x.get_arg()) << ")";
116  str_ = s.str();
117 }
118 void CodePrinter::bvisit(const Ceiling &x)
119 {
120  std::ostringstream s;
121  s << "ceil(" << apply(x.get_arg()) << ")";
122  str_ = s.str();
123 }
124 void CodePrinter::bvisit(const Truncate &x)
125 {
126  std::ostringstream s;
127  s << "trunc(" << apply(x.get_arg()) << ")";
128  str_ = s.str();
129 }
130 void CodePrinter::bvisit(const Max &x)
131 {
132  std::ostringstream s;
133  const auto &args = x.get_args();
134  switch (args.size()) {
135  case 0:
136  case 1:
137  throw SymEngineException("Impossible");
138  case 2:
139  s << "fmax(" << apply(args[0]) << ", " << apply(args[1]) << ")";
140  break;
141  default: {
142  vec_basic inner_args(args.begin() + 1, args.end());
143  auto inner = max(inner_args);
144  s << "fmax(" << apply(args[0]) << ", " << apply(inner) << ")";
145  break;
146  }
147  }
148  str_ = s.str();
149 }
150 void CodePrinter::bvisit(const Min &x)
151 {
152  std::ostringstream s;
153  const auto &args = x.get_args();
154  switch (args.size()) {
155  case 0:
156  case 1:
157  throw SymEngineException("Impossible");
158  case 2:
159  s << "fmin(" << apply(args[0]) << ", " << apply(args[1]) << ")";
160  break;
161  default: {
162  vec_basic inner_args(args.begin() + 1, args.end());
163  auto inner = min(inner_args);
164  s << "fmin(" << apply(args[0]) << ", " << apply(inner) << ")";
165  break;
166  }
167  }
168  str_ = s.str();
169 }
170 void CodePrinter::bvisit(const Constant &x)
171 {
172  if (eq(x, *E)) {
173  str_ = "exp(1)";
174  } else if (eq(x, *pi)) {
175  str_ = "acos(-1)";
176  } else {
177  str_ = x.get_name();
178  }
179 }
180 void CodePrinter::bvisit(const NaN &x)
181 {
182  std::ostringstream s;
183  s << "NAN";
184  str_ = s.str();
185 }
186 void CodePrinter::bvisit(const Equality &x)
187 {
188  std::ostringstream s;
189  s << apply(x.get_arg1()) << " == " << apply(x.get_arg2());
190  str_ = s.str();
191 }
192 void CodePrinter::bvisit(const Unequality &x)
193 {
194  std::ostringstream s;
195  s << apply(x.get_arg1()) << " != " << apply(x.get_arg2());
196  str_ = s.str();
197 }
198 void CodePrinter::bvisit(const LessThan &x)
199 {
200  std::ostringstream s;
201  s << apply(x.get_arg1()) << " <= " << apply(x.get_arg2());
202  str_ = s.str();
203 }
204 void CodePrinter::bvisit(const StrictLessThan &x)
205 {
206  std::ostringstream s;
207  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
208  str_ = s.str();
209 }
210 void CodePrinter::bvisit(const UnivariateSeries &x)
211 {
212  throw SymEngineException("Not supported");
213 }
214 void CodePrinter::bvisit(const Derivative &x)
215 {
216  throw SymEngineException("Not supported");
217 }
218 void CodePrinter::bvisit(const Subs &x)
219 {
220  throw SymEngineException("Not supported");
221 }
222 void CodePrinter::bvisit(const GaloisField &x)
223 {
224  throw SymEngineException("Not supported");
225 }
226 
227 void C89CodePrinter::bvisit(const Infty &x)
228 {
229  std::ostringstream s;
230  if (x.is_negative_infinity())
231  s << "-HUGE_VAL";
232  else if (x.is_positive_infinity())
233  s << "HUGE_VAL";
234  else
235  throw SymEngineException("Not supported");
236  str_ = s.str();
237 }
238 void C89CodePrinter::_print_pow(std::ostringstream &o,
239  const RCP<const Basic> &a,
240  const RCP<const Basic> &b)
241 {
242  if (eq(*a, *E)) {
243  o << "exp(" << apply(b) << ")";
244  } else if (eq(*b, *rational(1, 2))) {
245  o << "sqrt(" << apply(a) << ")";
246  } else {
247  o << "pow(" << apply(a) << ", " << apply(b) << ")";
248  }
249 }
250 
251 void C99CodePrinter::bvisit(const Infty &x)
252 {
253  std::ostringstream s;
254  if (x.is_negative_infinity())
255  s << "-INFINITY";
256  else if (x.is_positive_infinity())
257  s << "INFINITY";
258  else
259  throw SymEngineException("Not supported");
260  str_ = s.str();
261 }
262 void C99CodePrinter::_print_pow(std::ostringstream &o,
263  const RCP<const Basic> &a,
264  const RCP<const Basic> &b)
265 {
266  if (eq(*a, *E)) {
267  o << "exp(" << apply(b) << ")";
268  } else if (eq(*b, *rational(1, 2))) {
269  o << "sqrt(" << apply(a) << ")";
270  } else if (eq(*b, *rational(1, 3))) {
271  o << "cbrt(" << apply(a) << ")";
272  } else {
273  o << "pow(" << apply(a) << ", " << apply(b) << ")";
274  }
275 }
276 void C99CodePrinter::bvisit(const Gamma &x)
277 {
278  std::ostringstream s;
279  s << "tgamma(" << apply(x.get_arg()) << ")";
280  str_ = s.str();
281 }
282 void C99CodePrinter::bvisit(const LogGamma &x)
283 {
284  std::ostringstream s;
285  s << "lgamma(" << apply(x.get_arg()) << ")";
286  str_ = s.str();
287 }
288 
289 void JSCodePrinter::bvisit(const Constant &x)
290 {
291  if (eq(x, *E)) {
292  str_ = "Math.E";
293  } else if (eq(x, *pi)) {
294  str_ = "Math.PI";
295  } else {
296  str_ = x.get_name();
297  }
298 }
299 void JSCodePrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
300  const RCP<const Basic> &b)
301 {
302  if (eq(*a, *E)) {
303  o << "Math.exp(" << apply(b) << ")";
304  } else if (eq(*b, *rational(1, 2))) {
305  o << "Math.sqrt(" << apply(a) << ")";
306  } else if (eq(*b, *rational(1, 3))) {
307  o << "Math.cbrt(" << apply(a) << ")";
308  } else {
309  o << "Math.pow(" << apply(a) << ", " << apply(b) << ")";
310  }
311 }
312 void JSCodePrinter::bvisit(const Abs &x)
313 {
314  std::ostringstream s;
315  s << "Math.abs(" << apply(x.get_arg()) << ")";
316  str_ = s.str();
317 }
318 void JSCodePrinter::bvisit(const Sin &x)
319 {
320  std::ostringstream s;
321  s << "Math.sin(" << apply(x.get_arg()) << ")";
322  str_ = s.str();
323 }
324 void JSCodePrinter::bvisit(const Cos &x)
325 {
326  std::ostringstream s;
327  s << "Math.cos(" << apply(x.get_arg()) << ")";
328  str_ = s.str();
329 }
330 void JSCodePrinter::bvisit(const Max &x)
331 {
332  const auto &args = x.get_args();
333  std::ostringstream s;
334  s << "Math.max(";
335  for (size_t i = 0; i < args.size(); ++i) {
336  s << apply(args[i]);
337  s << ((i == args.size() - 1) ? ")" : ", ");
338  }
339  str_ = s.str();
340 }
341 void JSCodePrinter::bvisit(const Min &x)
342 {
343  const auto &args = x.get_args();
344  std::ostringstream s;
345  s << "Math.min(";
346  for (size_t i = 0; i < args.size(); ++i) {
347  s << apply(args[i]);
348  s << ((i == args.size() - 1) ? ")" : ", ");
349  }
350  str_ = s.str();
351 }
352 
353 std::string ccode(const Basic &x)
354 {
355  C99CodePrinter c;
356  return c.apply(x);
357 }
358 
359 std::string jscode(const Basic &x)
360 {
361  JSCodePrinter p;
362  return p.apply(x);
363 }
364 
365 std::string inline c89code(const Basic &x)
366 {
367  C89CodePrinter p;
368  return p.apply(x);
369 }
370 
371 std::string inline c99code(const Basic &x)
372 {
373  C99CodePrinter p;
374  return p.apply(x);
375 }
376 
377 } // namespace SymEngine
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
Definition: functions.cpp:3555
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
Definition: functions.cpp:3659
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328