Loading...
Searching...
No Matches
latex.cpp
1#include <symengine/printers/latex.h>
2#include <symengine/printers.h>
3#include <symengine/basic.h>
4
5namespace SymEngine
6{
7
8std::string latex(const Basic &x)
9{
10 LatexPrinter p;
11 return p.apply(x);
12}
13
14void print_rational_class(const rational_class &r, std::ostringstream &s)
15{
16 if (get_den(r) == 1) {
17 s << get_num(r);
18 } else {
19 s << "\\frac{" << get_num(r) << "}{" << get_den(r) << "}";
20 }
21}
22
23void LatexPrinter::bvisit(const Symbol &x)
24{
25 std::string name = x.get_name();
26
27 if (name.find('\\') != std::string::npos
28 or name.find('{') != std::string::npos) {
29 str_ = name;
30 return;
31 }
32 if (name[0] == '_') {
33 name = name.substr(1, name.size());
34 }
36 = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta", "epsilon",
37 "zeta", "eta", "theta", "Theta", "iota", "kappa", "lambda",
38 "Lambda", "mu", "nu", "xi", "omicron", "pi", "Pi",
39 "rho", "sigma", "Sigma", "tau", "upsilon", "Upsilon", "phi",
40 "Phi", "chi", "psi", "Psi", "omega", "Omega"};
41
42 for (auto &letter : greeks) {
43 if (name == letter) {
44 str_ = "\\" + name;
45 return;
46 }
47 if (name.size() > letter.size() and name.find(letter + "_") == 0) {
48 str_ = "\\" + name;
49 return;
50 }
51 }
52 str_ = name;
53 return;
54}
55
56void LatexPrinter::bvisit(const Rational &x)
57{
58 const auto &rational = x.as_rational_class();
60 print_rational_class(rational, s);
61 str_ = s.str();
62}
63
64void LatexPrinter::bvisit(const Complex &x)
65{
67 if (x.real_ != 0) {
68 print_rational_class(x.real_, s);
69 // Since Complex is in canonical form, imaginary_ is not 0.
70 if (mp_sign(x.imaginary_) == 1) {
71 s << " + ";
72 } else {
73 s << " - ";
74 }
75 // If imaginary_ is not 1 or -1, print the absolute value
76 if (x.imaginary_ != mp_sign(x.imaginary_)) {
77 print_rational_class(mp_abs(x.imaginary_), s);
78 s << "j";
79 } else {
80 s << "j";
81 }
82 } else {
83 if (x.imaginary_ != mp_sign(x.imaginary_)) {
84 print_rational_class(x.imaginary_, s);
85 s << "j";
86 } else {
87 if (mp_sign(x.imaginary_) == 1) {
88 s << "j";
89 } else {
90 s << "-j";
91 }
92 }
93 }
94 str_ = s.str();
95}
96
97void LatexPrinter::bvisit(const ComplexBase &x)
98{
99 RCP<const Number> imag = x.imaginary_part();
100 if (imag->is_negative()) {
101 std::string str = apply(imag);
102 str = str.substr(1, str.length() - 1);
103 str_ = apply(x.real_part()) + " - " + str + "j";
104 } else {
105 str_ = apply(x.real_part()) + " + " + apply(imag) + "j";
106 }
107}
108void LatexPrinter::bvisit(const ComplexDouble &x)
109{
110 bvisit(static_cast<const ComplexBase &>(x));
111}
112
113#ifdef HAVE_SYMENGINE_MPC
114void LatexPrinter::bvisit(const ComplexMPC &x)
115{
116 bvisit(static_cast<const ComplexBase &>(x));
117}
118#endif
119
120void LatexPrinter::bvisit(const Infty &x)
121{
122 if (x.is_negative_infinity()) {
123 str_ = "-\\infty";
124 } else if (x.is_positive_infinity()) {
125 str_ = "\\infty";
126 } else {
127 str_ = "\\tilde{\\infty}";
128 }
129}
130
131void LatexPrinter::bvisit(const NaN &x)
132{
133 str_ = "\\mathrm{NaN}";
134}
135
136void LatexPrinter::bvisit(const Constant &x)
137{
138 if (eq(x, *pi)) {
139 str_ = "\\pi";
140 } else if (eq(x, *E)) {
141 str_ = "e";
142 } else if (eq(x, *EulerGamma)) {
143 str_ = "\\gamma";
144 } else if (eq(x, *Catalan)) {
145 str_ = "G";
146 } else if (eq(x, *GoldenRatio)) {
147 str_ = "\\phi";
148 } else {
149 throw NotImplementedError("Constant " + x.get_name()
150 + " is not implemented.");
151 }
152}
153
154void LatexPrinter::bvisit(const Derivative &x)
155{
156 const auto &symbols = x.get_symbols();
158 if (symbols.size() == 1) {
159 if (free_symbols(*x.get_arg()).size() == 1) {
160 s << "\\frac{d}{d " << apply(*symbols.begin());
161 } else {
162 s << "\\frac{\\partial}{\\partial " << apply(*symbols.begin());
163 }
164 } else {
165 s << "\\frac{\\partial^" << symbols.size() << "}{";
166 unsigned count = 1;
167 auto it = symbols.begin();
168 RCP<const Basic> prev = *it;
169 ++it;
170 for (; it != symbols.end(); ++it) {
171 if (neq(*prev, **it)) {
172 if (count == 1) {
173 s << "\\partial " << apply(*prev) << " ";
174 } else {
175 s << "\\partial " << apply(*prev) << "^" << count << " ";
176 }
177 count = 1;
178 } else {
179 count++;
180 }
181 prev = *it;
182 }
183 if (count == 1) {
184 s << "\\partial " << apply(*prev) << " ";
185 } else {
186 s << "\\partial " << apply(*prev) << "^" << count << " ";
187 }
188 }
189 s << "} " << apply(x.get_arg());
190 str_ = s.str();
191}
192
193void LatexPrinter::bvisit(const Subs &x)
194{
196 o << "\\left. " << apply(x.get_arg()) << "\\right|_{\\substack{";
197 for (auto p = x.get_dict().begin(); p != x.get_dict().end(); p++) {
198 if (p != x.get_dict().begin()) {
199 o << " \\\\ ";
200 }
201 o << apply(p->first) << "=" << apply(p->second);
202 }
203 o << "}}";
204 str_ = o.str();
205}
206
207void LatexPrinter::bvisit(const Equality &x)
208{
210 s << apply(x.get_arg1()) << " = " << apply(x.get_arg2());
211 str_ = s.str();
212}
213
214void LatexPrinter::bvisit(const Unequality &x)
215{
217 s << apply(x.get_arg1()) << " \\neq " << apply(x.get_arg2());
218 str_ = s.str();
219}
220
221void LatexPrinter::bvisit(const LessThan &x)
222{
224 s << apply(x.get_arg1()) << " \\leq " << apply(x.get_arg2());
225 str_ = s.str();
226}
227
228void LatexPrinter::bvisit(const StrictLessThan &x)
229{
231 s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
232 str_ = s.str();
233}
234
235std::string latex(const DenseMatrix &m, const unsigned max_rows,
236 const unsigned max_cols)
237{
238 const unsigned int nrows = m.nrows();
239 const unsigned int ncols = m.ncols();
240 unsigned int nrows_display = nrows;
241 if (nrows > max_rows)
242 nrows_display = max_rows - 1;
243 unsigned int ncols_display = ncols;
244 if (ncols > max_cols)
245 ncols_display = max_cols - 1;
246
248 s << "\\left[\\begin{matrix}" << std::endl;
249
250 std::string end_of_line = " \\\\\n";
251 if (ncols_display < ncols) {
252 end_of_line = " & \\cdots" + end_of_line;
253 }
254 for (unsigned int row_index = 0; row_index < nrows_display; row_index++) {
255 for (unsigned int column_index = 0; column_index < ncols_display;
256 column_index++) {
257 RCP<const Basic> v = m.get(row_index, column_index);
258
259 if (v.is_null()) {
260 // element has not been initalized
261 throw SymEngineException(
262 "cannot display uninitialized element");
263 } else {
264 s << latex(*v);
265 }
266 if (column_index < ncols_display - 1)
267 s << " & ";
268 }
269 s << end_of_line;
270 }
271 if (nrows_display < nrows) {
272 for (unsigned int column_index = 0; column_index < ncols_display;
273 column_index++) {
274 s << "\\vdots";
275 if (column_index < ncols_display - 1)
276 s << " & ";
277 }
278 s << end_of_line;
279 }
280 s << "\\end{matrix}\\right]\n";
281
282 return s.str();
283}
284
285void LatexPrinter::bvisit(const Interval &x)
286{
288 if (x.get_left_open())
289 s << "\\left(";
290 else
291 s << "\\left[";
292 s << *x.get_start() << ", " << *x.get_end();
293 if (x.get_right_open())
294 s << "\\right)";
295 else
296 s << "\\right]";
297 str_ = s.str();
298}
299
300void LatexPrinter::bvisit(const BooleanAtom &x)
301{
302 if (x.get_val()) {
303 str_ = "\\mathrm{True}";
304 } else {
305 str_ = "\\mathrm{False}";
306 }
307}
308
309void LatexPrinter::bvisit(const And &x)
310{
312 auto container = x.get_container();
313 if (is_a<Or>(**container.begin()) or is_a<Xor>(**container.begin())) {
314 s << parenthesize(apply(*container.begin()));
315 } else {
316 s << apply(*container.begin());
317 }
318
319 for (auto it = ++(container.begin()); it != container.end(); ++it) {
320 s << " \\wedge ";
321 if (is_a<Or>(**it) or is_a<Xor>(**it)) {
322 s << parenthesize(apply(*it));
323 } else {
324 s << apply(*it);
325 }
326 }
327 str_ = s.str();
328}
329
330void LatexPrinter::bvisit(const Or &x)
331{
333 auto container = x.get_container();
334 if (is_a<And>(**container.begin()) or is_a<Xor>(**container.begin())) {
335 s << parenthesize(apply(*container.begin()));
336 } else {
337 s << apply(*container.begin());
338 }
339
340 for (auto it = ++(container.begin()); it != container.end(); ++it) {
341 s << " \\vee ";
342 if (is_a<And>(**it) or is_a<Xor>(**it)) {
343 s << parenthesize(apply(*it));
344 } else {
345 s << apply(*it);
346 }
347 }
348 str_ = s.str();
349}
350
351void LatexPrinter::bvisit(const Xor &x)
352{
354 auto container = x.get_container();
355 if (is_a<Or>(**container.begin()) or is_a<And>(**container.begin())) {
356 s << parenthesize(apply(*container.begin()));
357 } else {
358 s << apply(*container.begin());
359 }
360
361 for (auto it = ++(container.begin()); it != container.end(); ++it) {
362 s << " \\veebar ";
363 if (is_a<Or>(**it) or is_a<And>(**it)) {
364 s << parenthesize(apply(*it));
365 } else {
366 s << apply(*it);
367 }
368 }
369 str_ = s.str();
370}
371
372void LatexPrinter::print_with_args(const Basic &x, const std::string &join,
374{
375 vec_basic v = x.get_args();
376 s << apply(*v.begin());
377
378 for (auto it = ++(v.begin()); it != v.end(); ++it) {
379 s << " " << join << " " << apply(*it);
380 }
381}
382
383void LatexPrinter::bvisit(const Not &x)
384{
385 str_ = "\\neg " + apply(*x.get_arg());
386}
387
388void LatexPrinter::bvisit(const Union &x)
389{
391 print_with_args(x, "\\cup", s);
392 str_ = s.str();
393}
394
395void LatexPrinter::bvisit(const Intersection &x)
396{
398 print_with_args(x, "\\cap", s);
399 str_ = s.str();
400}
401
402void LatexPrinter::bvisit(const Complement &x)
403{
405 s << apply(x.get_universe()) << " \\setminus " << apply(x.get_container());
406 str_ = s.str();
407}
408
409void LatexPrinter::bvisit(const ImageSet &x)
410{
412 s << "\\left\\{" << apply(*x.get_expr()) << "\\; |\\; ";
413 s << apply(*x.get_symbol());
414 s << " \\in " << apply(*x.get_baseset()) << "\\right\\}";
415 str_ = s.str();
416}
417
418void LatexPrinter::bvisit(const ConditionSet &x)
419{
421 s << "\\left\\{" << apply(*x.get_symbol()) << "\\; |\\; ";
422 s << apply(x.get_condition()) << "\\right\\}";
423 str_ = s.str();
424}
425
426void LatexPrinter::bvisit(const EmptySet &x)
427{
428 str_ = "\\emptyset";
429}
430
431void LatexPrinter::bvisit(const Complexes &x)
432{
433 str_ = "\\mathbb{C}";
434}
435
436void LatexPrinter::bvisit(const Reals &x)
437{
438 str_ = "\\mathbb{R}";
439}
440
441void LatexPrinter::bvisit(const Rationals &x)
442{
443 str_ = "\\mathbb{Q}";
444}
445
446void LatexPrinter::bvisit(const Integers &x)
447{
448 str_ = "\\mathbb{Z}";
449}
450
451void LatexPrinter::bvisit(const Naturals &x)
452{
453 str_ = "\\mathbb{N}";
454}
455
456void LatexPrinter::bvisit(const Naturals0 &x)
457{
458 str_ = "\\mathbb{N}_0";
459}
460
461void LatexPrinter::bvisit(const FiniteSet &x)
462{
464 s << "\\left{";
465 print_with_args(x, ",", s);
466 s << "\\right}";
467 str_ = s.str();
468}
469
470void LatexPrinter::bvisit(const Contains &x)
471{
473 s << apply(x.get_expr()) << " \\in " << apply(x.get_set());
474 str_ = s.str();
475}
476
477std::string LatexPrinter::print_mul()
478{
479 return " ";
480}
481
482bool LatexPrinter::split_mul_coef()
483{
484 return true;
485}
486
487std::vector<std::string> init_latex_printer_names()
488{
489 std::vector<std::string> names = init_str_printer_names();
490
491 for (unsigned i = 0; i < names.size(); i++) {
492 if (names[i] != "") {
493 names[i] = "\\operatorname{" + names[i] + "}";
494 }
495 }
496 names[SYMENGINE_SIN] = "\\sin";
497 names[SYMENGINE_COS] = "\\cos";
498 names[SYMENGINE_TAN] = "\\tan";
499 names[SYMENGINE_COT] = "\\cot";
500 names[SYMENGINE_CSC] = "\\csc";
501 names[SYMENGINE_SEC] = "\\sec";
502 names[SYMENGINE_ATAN2] = "\\operatorname{atan_2}";
503 names[SYMENGINE_SINH] = "\\sinh";
504 names[SYMENGINE_COSH] = "\\cosh";
505 names[SYMENGINE_TANH] = "\\tanh";
506 names[SYMENGINE_COTH] = "\\coth";
507 names[SYMENGINE_LOG] = "\\log";
508 names[SYMENGINE_ZETA] = "\\zeta";
509 names[SYMENGINE_LAMBERTW] = "\\operatorname{W}";
510 names[SYMENGINE_DIRICHLET_ETA] = "\\eta";
511 names[SYMENGINE_KRONECKERDELTA] = "\\delta_";
512 names[SYMENGINE_LEVICIVITA] = "\\varepsilon_";
513 names[SYMENGINE_LOWERGAMMA] = "\\gamma";
514 names[SYMENGINE_UPPERGAMMA] = "\\Gamma";
515 names[SYMENGINE_BETA] = "\\operatorname{B}";
516 names[SYMENGINE_LOG] = "\\log";
517 names[SYMENGINE_GAMMA] = "\\Gamma";
518 names[SYMENGINE_TRUNCATE] = "\\operatorname{truncate}";
519 names[SYMENGINE_PRIMEPI] = "\\pi";
520 return names;
521}
522
523void LatexPrinter::bvisit(const Function &x)
524{
525 static const std::vector<std::string> names_ = init_latex_printer_names();
527 o << names_[x.get_type_code()] << "{";
528 vec_basic vec = x.get_args();
529 o << parenthesize(apply(vec)) << "}";
530 str_ = o.str();
531}
532
533void LatexPrinter::bvisit(const Floor &x)
534{
536 o << "\\lfloor{" << apply(x.get_arg()) << "}\\rfloor";
537 str_ = o.str();
538}
539
540void LatexPrinter::bvisit(const Ceiling &x)
541{
543 o << "\\lceil{" << apply(x.get_arg()) << "}\\rceil";
544 str_ = o.str();
545}
546
547void LatexPrinter::bvisit(const Abs &x)
548{
550 o << "\\left|" << apply(x.get_arg()) << "}\\right|";
551 str_ = o.str();
552}
553
554std::string LatexPrinter::parenthesize(const std::string &expr)
555{
556 return "\\left(" + expr + "\\right)";
557}
558
559void LatexPrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
560 const RCP<const Basic> &b)
561{
562 if (eq(*a, *E)) {
563 o << "e^{" << apply(b) << "}";
564 } else if (eq(*b, *rational(1, 2))) {
565 o << "\\sqrt{" << apply(a) << "}";
566 } else if (is_a<Rational>(*b)
567 and eq(*static_cast<const Rational &>(*b).get_num(), *one)) {
568 o << "\\sqrt[" << apply(static_cast<const Rational &>(*b).get_den())
569 << "]{" << apply(a) << "}";
570 } else {
571 o << parenthesizeLE(a, PrecedenceEnum::Pow);
572 Precedence prec;
573 auto b_str = apply(b);
574 if (b_str.size() > 1) {
575 o << "^{" << b_str << "}";
576 } else {
577 o << "^" << b_str;
578 }
579 }
580}
581
582std::string LatexPrinter::print_div(const std::string &num,
583 const std::string &den, bool paren)
584{
585 return "\\frac{" + num + "}{" + den + "}";
586}
587
588void LatexPrinter::bvisit(const Piecewise &x)
589{
591 s << "\\begin{cases} ";
592 const auto &vec = x.get_vec();
593 auto it = vec.begin();
594 auto it_last = --vec.end();
595 while (it != vec.end()) {
596 s << apply(it->first);
597 if (it == it_last) {
598 if (eq(*it->second, *boolTrue)) {
599 s << " & \\text{otherwise} \\end{cases}";
600 } else {
601 s << " & \\text{for}\\: ";
602 s << apply(it->second);
603 s << " \\end{cases}";
604 }
605 } else {
606 s << " & \\text{for}\\: ";
607 s << apply(it->second);
608 s << "\\\\";
609 }
610 it++;
611 }
612 str_ = s.str();
613}
614
615void LatexPrinter::bvisit(const Tuple &x)
616{
618 vec_basic vec = x.get_args();
619 o << parenthesize(apply(vec));
620 str_ = o.str();
621}
622
623} // namespace SymEngine
The base class for SymEngine.
T count(T... args)
T endl(T... args)
T find(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
T prev(T... args)
T size(T... args)
T str(T... args)
T substr(T... args)