1 #include <symengine/printers/codegen.h>
2 #include <symengine/printers.h>
3 #include <symengine/symengine_exception.h>
// NOTE(review): this listing appears to be extracted from a rendered source
// view: braces and several statements are missing between the numbered lines.
// Overload for the generic Basic type: code generation is refused by throwing
// SymEngineException("Not supported").
8 void CodePrinter::bvisit(
const Basic &x)
10 throw SymEngineException(
"Not supported");
// Complex numbers are not handled by this printer; throws
// NotImplementedError("Not implemented").
12 void CodePrinter::bvisit(
const Complex &x)
14 throw NotImplementedError(
"Not implemented");
// Prints a Dummy symbol as "<name>_<index>" (e.g. name "x", index 3 -> "x_3").
// The stream `s` is declared on a line not shown in this listing.
16 void CodePrinter::bvisit(
const Dummy &x)
19 s << x.get_name() <<
'_' << x.get_index();
// Prints an Interval as a condition on a variable. The variable text is copied
// from str_, which presumably holds the expression printed just before this
// call (Contains visits its expression before its set) -- confirm.
22 void CodePrinter::bvisit(
const Interval &x)
24 std::string var = str_;
// Remember whether the lower bound is -oo; presumably no lower-bound test is
// emitted in that case (the code using is_inf is not shown).
26 bool is_inf =
eq(*x.get_start(), *NegInf);
// Open vs. closed lower bound selects the comparison (operators not shown).
29 if (x.get_left_open()) {
34 s << apply(x.get_start());
// The upper-bound test is guarded by end != +oo; the right-open check and the
// printed end below presumably live inside this branch.
36 if (
neq(*x.get_end(), *Inf)) {
41 if (x.get_right_open()) {
46 s << apply(x.get_end());
// Prints Contains(expr, set) by visiting the expression and then the set with
// this same printer. The set's overload presumably consumes the expression's
// printed form (the Interval overload reads str_ as its variable) -- confirm.
50 void CodePrinter::bvisit(
const Contains &x)
52 x.get_expr()->accept(*
this);
53 x.get_set()->accept(*
this);
// Prints a Piecewise as a chain of conditional expressions.
// vec holds (expression, condition) pairs: .first is printed as the branch
// value and .second as its condition. The last pair must have condition True,
// otherwise SymEngineException is thrown.
// NOTE(review): `auto vec = x.get_vec();` copies the pair vector; binding with
// `const auto &` would avoid the copy if get_vec() returns a reference.
// NOTE(review): the first loop has no condition and relies on the
// i == vec.size() - 1 branch to end it. For an empty vec, size() - 1 wraps to
// SIZE_MAX and vec[i] would be read out of range. Presumably a Piecewise always
// has at least one branch -- confirm.
55 void CodePrinter::bvisit(
const Piecewise &x)
58 auto vec = x.get_vec();
59 for (
size_t i = 0;; ++i) {
60 if (i == vec.size() - 1) {
61 if (
neq(*vec[i].second, *boolTrue)) {
62 throw SymEngineException(
63 "Code generation requires a (Expr, True) at the end");
// Final (True) branch: only its expression is printed.
65 s <<
"(\n " << apply(vec[i].first) <<
"\n";
// Earlier branches: condition, then expression (surrounding text not shown).
69 s << apply(vec[i].second);
71 s << apply(vec[i].first);
// Second pass over all branches; its body is not shown (presumably it emits
// the closing parentheses -- confirm).
75 for (
size_t i = 0; i < vec.size(); i++) {
// Prints a Rational p/q as "<p>/<q>", with both parts formatted by
// print_double. Presumably print_double emits a floating literal, so the
// generated code divides in floating point rather than integers -- confirm.
// NOTE(review): mp_get_d converts numerator and denominator to double. Integers
// above 2^53 lose exactness, and very large values may overflow.
80 void CodePrinter::bvisit(
const Rational &x)
83 double n = mp_get_d(get_num(x.as_rational_class()));
84 double d = mp_get_d(get_den(x.as_rational_class()));
85 o << print_double(n) <<
"/" << print_double(d);
// The Reals set cannot be emitted as code; throws
// SymEngineException("Not supported").
88 void CodePrinter::bvisit(
const Reals &x)
90 throw SymEngineException(
"Not supported");
// The Rationals set cannot be emitted as code; throws
// SymEngineException("Not supported").
92 void CodePrinter::bvisit(
const Rationals &x)
94 throw SymEngineException(
"Not supported");
// The Integers set cannot be emitted as code; throws
// SymEngineException("Not supported").
96 void CodePrinter::bvisit(
const Integers &x)
98 throw SymEngineException(
"Not supported");
// The empty set cannot be emitted as code; throws
// SymEngineException("Not supported").
100 void CodePrinter::bvisit(
const EmptySet &x)
102 throw SymEngineException(
"Not supported");
// Finite sets cannot be emitted as code; throws
// SymEngineException("Not supported").
104 void CodePrinter::bvisit(
const FiniteSet &x)
106 throw SymEngineException(
"Not supported");
// The universal set cannot be emitted as code; throws
// SymEngineException("Not supported").
108 void CodePrinter::bvisit(
const UniversalSet &x)
110 throw SymEngineException(
"Not supported");
// Prints Abs(arg) as the C call "fabs(arg)" into the local stream `s`; the
// statement that publishes s is not shown in this listing.
112 void CodePrinter::bvisit(
const Abs &x)
114 std::ostringstream s;
115 s <<
"fabs(" << apply(x.get_arg()) <<
")";
// Prints Ceiling(arg) as the C call "ceil(arg)" into the local stream `s`.
118 void CodePrinter::bvisit(
const Ceiling &x)
120 std::ostringstream s;
121 s <<
"ceil(" << apply(x.get_arg()) <<
")";
// Prints Truncate(arg) as the C call "trunc(arg)" into the local stream `s`.
124 void CodePrinter::bvisit(
const Truncate &x)
126 std::ostringstream s;
127 s <<
"trunc(" << apply(x.get_arg()) <<
")";
// Prints Max(a1, ..., an) using the binary C function fmax.
// The case labels are not shown:
//  - one switch branch throws SymEngineException("Impossible"), presumably for
//    fewer than two arguments;
//  - two arguments print as fmax(a1, a2);
//  - otherwise, presumably the default case, args[1..] is rebuilt with max()
//    and printed through apply(), which nests: fmax(a1, fmax(a2, ...)).
130 void CodePrinter::bvisit(
const Max &x)
132 std::ostringstream s;
133 const auto &args = x.get_args();
134 switch (args.size()) {
137 throw SymEngineException(
"Impossible");
139 s <<
"fmax(" << apply(args[0]) <<
", " << apply(args[1]) <<
")";
142 vec_basic inner_args(args.begin() + 1, args.end());
143 auto inner =
max(inner_args);
144 s <<
"fmax(" << apply(args[0]) <<
", " << apply(inner) <<
")";
// Prints Min(a1, ..., an) using the binary C function fmin.
// The case labels are not shown:
//  - one switch branch throws SymEngineException("Impossible"), presumably for
//    fewer than two arguments;
//  - two arguments print as fmin(a1, a2);
//  - otherwise, presumably the default case, args[1..] is rebuilt with min()
//    and printed through apply(), which nests: fmin(a1, fmin(a2, ...)).
150 void CodePrinter::bvisit(
const Min &x)
152 std::ostringstream s;
153 const auto &args = x.get_args();
154 switch (args.size()) {
157 throw SymEngineException(
"Impossible");
159 s <<
"fmin(" << apply(args[0]) <<
", " << apply(args[1]) <<
")";
162 vec_basic inner_args(args.begin() + 1, args.end());
163 auto inner =
min(inner_args);
164 s <<
"fmin(" << apply(args[0]) <<
", " << apply(inner) <<
")";
// Prints a named Constant. Only the branch testing eq(x, *pi) is visible here;
// the other branches and the emitted literals are on lines not shown.
170 void CodePrinter::bvisit(
const Constant &x)
174 }
else if (
eq(x, *pi)) {
// Prints NaN. Only the local stream declaration is visible in this listing;
// the emitted text is on lines not shown.
180 void CodePrinter::bvisit(
const NaN &x)
182 std::ostringstream s;
// Prints Equality(a, b) as the expression "a == b".
186 void CodePrinter::bvisit(
const Equality &x)
188 std::ostringstream s;
189 s << apply(x.get_arg1()) <<
" == " << apply(x.get_arg2());
// Prints Unequality(a, b) as the expression "a != b".
192 void CodePrinter::bvisit(
const Unequality &x)
194 std::ostringstream s;
195 s << apply(x.get_arg1()) <<
" != " << apply(x.get_arg2());
// Prints LessThan(a, b) as the non-strict comparison "a <= b".
198 void CodePrinter::bvisit(
const LessThan &x)
200 std::ostringstream s;
201 s << apply(x.get_arg1()) <<
" <= " << apply(x.get_arg2());
// Prints StrictLessThan(a, b) as the strict comparison "a < b".
204 void CodePrinter::bvisit(
const StrictLessThan &x)
206 std::ostringstream s;
207 s << apply(x.get_arg1()) <<
" < " << apply(x.get_arg2());
// Univariate series cannot be emitted as code; throws
// SymEngineException("Not supported").
210 void CodePrinter::bvisit(
const UnivariateSeries &x)
212 throw SymEngineException(
"Not supported");
// Unevaluated derivatives cannot be emitted as code; throws
// SymEngineException("Not supported").
214 void CodePrinter::bvisit(
const Derivative &x)
216 throw SymEngineException(
"Not supported");
// Unevaluated substitutions cannot be emitted as code; throws
// SymEngineException("Not supported").
218 void CodePrinter::bvisit(
const Subs &x)
220 throw SymEngineException(
"Not supported");
// Galois-field polynomials cannot be emitted as code; throws
// SymEngineException("Not supported").
222 void CodePrinter::bvisit(
const GaloisField &x)
224 throw SymEngineException(
"Not supported");
// C89 printing of infinities. Negative and positive infinity each get their
// own branch (the emitted text is on lines not shown). Any other Infty value
// throws SymEngineException("Not supported").
227 void C89CodePrinter::bvisit(
const Infty &x)
229 std::ostringstream s;
230 if (x.is_negative_infinity())
232 else if (x.is_positive_infinity())
235 throw SymEngineException(
"Not supported");
// Writes the power a**b into `o` using C89 math functions:
//   exp(b)    -- presumably when the base a is E,
//   sqrt(a)   -- presumably when the exponent b is 1/2,
//   pow(a, b) -- the general case.
// The conditions selecting each form are on lines not shown -- confirm.
238 void C89CodePrinter::_print_pow(std::ostringstream &o,
239 const RCP<const Basic> &a,
240 const RCP<const Basic> &b)
243 o <<
"exp(" << apply(b) <<
")";
245 o <<
"sqrt(" << apply(a) <<
")";
247 o <<
"pow(" << apply(a) <<
", " << apply(b) <<
")";
// C99 printing of infinities. Negative and positive infinity each get their
// own branch (the emitted text is on lines not shown). Any other Infty value
// throws SymEngineException("Not supported").
251 void C99CodePrinter::bvisit(
const Infty &x)
253 std::ostringstream s;
254 if (x.is_negative_infinity())
256 else if (x.is_positive_infinity())
259 throw SymEngineException(
"Not supported");
// Writes the power a**b into `o` using C99 math functions. This is the same
// scheme as C89 plus cbrt:
//   exp(b)    -- presumably base E,
//   sqrt(a)   -- presumably exponent 1/2,
//   cbrt(a)   -- presumably exponent 1/3,
//   pow(a, b) -- the general case.
// The conditions selecting each form are on lines not shown -- confirm.
262 void C99CodePrinter::_print_pow(std::ostringstream &o,
263 const RCP<const Basic> &a,
264 const RCP<const Basic> &b)
267 o <<
"exp(" << apply(b) <<
")";
269 o <<
"sqrt(" << apply(a) <<
")";
271 o <<
"cbrt(" << apply(a) <<
")";
273 o <<
"pow(" << apply(a) <<
", " << apply(b) <<
")";
// Prints Gamma(arg) as the C99 call "tgamma(arg)" into the local stream `s`.
276 void C99CodePrinter::bvisit(
const Gamma &x)
278 std::ostringstream s;
279 s <<
"tgamma(" << apply(x.get_arg()) <<
")";
// Prints LogGamma(arg) as the C99 call "lgamma(arg)" into the local stream `s`.
282 void C99CodePrinter::bvisit(
const LogGamma &x)
284 std::ostringstream s;
285 s <<
"lgamma(" << apply(x.get_arg()) <<
")";
// JavaScript printing of named constants. Only the branch testing
// eq(x, *pi) is visible; the other branches and emitted text are not shown.
289 void JSCodePrinter::bvisit(
const Constant &x)
293 }
else if (
eq(x, *pi)) {
// Writes the power a**b into `o` using the JavaScript Math object:
//   Math.exp(b), Math.sqrt(a), Math.cbrt(a), or the general Math.pow(a, b).
// The conditions selecting each form are on lines not shown; presumably they
// mirror the C99 printer (base E, exponent 1/2, exponent 1/3) -- confirm.
299 void JSCodePrinter::_print_pow(std::ostringstream &o,
const RCP<const Basic> &a,
300 const RCP<const Basic> &b)
303 o <<
"Math.exp(" << apply(b) <<
")";
305 o <<
"Math.sqrt(" << apply(a) <<
")";
307 o <<
"Math.cbrt(" << apply(a) <<
")";
309 o <<
"Math.pow(" << apply(a) <<
", " << apply(b) <<
")";
// Prints Abs(arg) as the JavaScript call "Math.abs(arg)".
312 void JSCodePrinter::bvisit(
const Abs &x)
314 std::ostringstream s;
315 s <<
"Math.abs(" << apply(x.get_arg()) <<
")";
// Prints Sin(arg) as the JavaScript call "Math.sin(arg)".
318 void JSCodePrinter::bvisit(
const Sin &x)
320 std::ostringstream s;
321 s <<
"Math.sin(" << apply(x.get_arg()) <<
")";
// Prints Cos(arg) as the JavaScript call "Math.cos(arg)".
324 void JSCodePrinter::bvisit(
const Cos &x)
326 std::ostringstream s;
327 s <<
"Math.cos(" << apply(x.get_arg()) <<
")";
// Prints Max as one variadic call, unlike the C printer's nested fmax.
// Every argument is followed by ", " except the last, which closes with ")".
// The opening text (presumably "Math.max(") and the per-argument print are on
// lines not shown -- confirm.
330 void JSCodePrinter::bvisit(
const Max &x)
332 const auto &args = x.get_args();
333 std::ostringstream s;
335 for (
size_t i = 0; i < args.size(); ++i) {
337 s << ((i == args.size() - 1) ?
")" :
", ");
// Prints Min as one variadic call, unlike the C printer's nested fmin.
// Every argument is followed by ", " except the last, which closes with ")".
// The opening text (presumably "Math.min(") and the per-argument print are on
// lines not shown -- confirm.
341 void JSCodePrinter::bvisit(
const Min &x)
343 const auto &args = x.get_args();
344 std::ostringstream s;
346 for (
size_t i = 0; i < args.size(); ++i) {
348 s << ((i == args.size() - 1) ?
")" :
", ");
// Entry point `ccode`, returning a std::string for expression x. Its body is
// not shown in this listing; presumably it runs a C code printer on x -- confirm.
353 std::string ccode(
const Basic &x)
// Entry point `jscode`, returning a std::string for expression x. Its body is
// not shown; presumably it runs JSCodePrinter on x -- confirm.
359 std::string jscode(
const Basic &x)
// Inline entry point `c89code`, returning a std::string for expression x. Its
// body is not shown; presumably it runs C89CodePrinter on x -- confirm.
365 std::string
inline c89code(
const Basic &x)
// Inline entry point `c99code`, returning a std::string for expression x. Its
// body is not shown; presumably it runs C99CodePrinter on x -- confirm.
371 std::string
inline c99code(
const Basic &x)
Main namespace for SymEngine package.
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
RCP< const Number > rational(long n, long d)
convenience creator from two longs