latex.cpp
1 #include <symengine/printers/latex.h>
2 #include <symengine/printers.h>
3 #include <symengine/basic.h>
4 
5 namespace SymEngine
6 {
7 
8 std::string latex(const Basic &x)
9 {
10  LatexPrinter p;
11  return p.apply(x);
12 }
13 
14 void print_rational_class(const rational_class &r, std::ostringstream &s)
15 {
16  if (get_den(r) == 1) {
17  s << get_num(r);
18  } else {
19  s << "\\frac{" << get_num(r) << "}{" << get_den(r) << "}";
20  }
21 }
22 
23 void LatexPrinter::bvisit(const Symbol &x)
24 {
25  std::string name = x.get_name();
26 
27  if (name.find('\\') != std::string::npos
28  or name.find('{') != std::string::npos) {
29  str_ = name;
30  return;
31  }
32  if (name[0] == '_') {
33  name = name.substr(1, name.size());
34  }
36  = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta", "epsilon",
37  "zeta", "eta", "theta", "Theta", "iota", "kappa", "lambda",
38  "Lambda", "mu", "nu", "xi", "omicron", "pi", "Pi",
39  "rho", "sigma", "Sigma", "tau", "upsilon", "Upsilon", "phi",
40  "Phi", "chi", "psi", "Psi", "omega", "Omega"};
41 
42  for (auto &letter : greeks) {
43  if (name == letter) {
44  str_ = "\\" + name;
45  return;
46  }
47  if (name.size() > letter.size() and name.find(letter + "_") == 0) {
48  str_ = "\\" + name;
49  return;
50  }
51  }
52  str_ = name;
53  return;
54 }
55 
56 void LatexPrinter::bvisit(const Rational &x)
57 {
58  const auto &rational = x.as_rational_class();
60  print_rational_class(rational, s);
61  str_ = s.str();
62 }
63 
64 void LatexPrinter::bvisit(const Complex &x)
65 {
67  if (x.real_ != 0) {
68  print_rational_class(x.real_, s);
69  // Since Complex is in canonical form, imaginary_ is not 0.
70  if (mp_sign(x.imaginary_) == 1) {
71  s << " + ";
72  } else {
73  s << " - ";
74  }
75  // If imaginary_ is not 1 or -1, print the absolute value
76  if (x.imaginary_ != mp_sign(x.imaginary_)) {
77  print_rational_class(mp_abs(x.imaginary_), s);
78  s << "j";
79  } else {
80  s << "j";
81  }
82  } else {
83  if (x.imaginary_ != mp_sign(x.imaginary_)) {
84  print_rational_class(x.imaginary_, s);
85  s << "j";
86  } else {
87  if (mp_sign(x.imaginary_) == 1) {
88  s << "j";
89  } else {
90  s << "-j";
91  }
92  }
93  }
94  str_ = s.str();
95 }
96 
97 void LatexPrinter::bvisit(const ComplexBase &x)
98 {
99  RCP<const Number> imag = x.imaginary_part();
100  if (imag->is_negative()) {
101  std::string str = apply(imag);
102  str = str.substr(1, str.length() - 1);
103  str_ = apply(x.real_part()) + " - " + str + "j";
104  } else {
105  str_ = apply(x.real_part()) + " + " + apply(imag) + "j";
106  }
107 }
108 void LatexPrinter::bvisit(const ComplexDouble &x)
109 {
110  bvisit(static_cast<const ComplexBase &>(x));
111 }
112 
113 #ifdef HAVE_SYMENGINE_MPC
114 void LatexPrinter::bvisit(const ComplexMPC &x)
115 {
116  bvisit(static_cast<const ComplexBase &>(x));
117 }
118 #endif
119 
120 void LatexPrinter::bvisit(const Infty &x)
121 {
122  if (x.is_negative_infinity()) {
123  str_ = "-\\infty";
124  } else if (x.is_positive_infinity()) {
125  str_ = "\\infty";
126  } else {
127  str_ = "\\tilde{\\infty}";
128  }
129 }
130 
131 void LatexPrinter::bvisit(const NaN &x)
132 {
133  str_ = "\\mathrm{NaN}";
134 }
135 
136 void LatexPrinter::bvisit(const Constant &x)
137 {
138  if (eq(x, *pi)) {
139  str_ = "\\pi";
140  } else if (eq(x, *E)) {
141  str_ = "e";
142  } else if (eq(x, *EulerGamma)) {
143  str_ = "\\gamma";
144  } else if (eq(x, *Catalan)) {
145  str_ = "G";
146  } else if (eq(x, *GoldenRatio)) {
147  str_ = "\\phi";
148  } else {
149  throw NotImplementedError("Constant " + x.get_name()
150  + " is not implemented.");
151  }
152 }
153 
154 void LatexPrinter::bvisit(const Derivative &x)
155 {
156  const auto &symbols = x.get_symbols();
158  if (symbols.size() == 1) {
159  if (free_symbols(*x.get_arg()).size() == 1) {
160  s << "\\frac{d}{d " << apply(*symbols.begin());
161  } else {
162  s << "\\frac{\\partial}{\\partial " << apply(*symbols.begin());
163  }
164  } else {
165  s << "\\frac{\\partial^" << symbols.size() << "}{";
166  unsigned count = 1;
167  auto it = symbols.begin();
168  RCP<const Basic> prev = *it;
169  ++it;
170  for (; it != symbols.end(); ++it) {
171  if (neq(*prev, **it)) {
172  if (count == 1) {
173  s << "\\partial " << apply(*prev) << " ";
174  } else {
175  s << "\\partial " << apply(*prev) << "^" << count << " ";
176  }
177  count = 1;
178  } else {
179  count++;
180  }
181  prev = *it;
182  }
183  if (count == 1) {
184  s << "\\partial " << apply(*prev) << " ";
185  } else {
186  s << "\\partial " << apply(*prev) << "^" << count << " ";
187  }
188  }
189  s << "} " << apply(x.get_arg());
190  str_ = s.str();
191 }
192 
193 void LatexPrinter::bvisit(const Subs &x)
194 {
196  o << "\\left. " << apply(x.get_arg()) << "\\right|_{\\substack{";
197  for (auto p = x.get_dict().begin(); p != x.get_dict().end(); p++) {
198  if (p != x.get_dict().begin()) {
199  o << " \\\\ ";
200  }
201  o << apply(p->first) << "=" << apply(p->second);
202  }
203  o << "}}";
204  str_ = o.str();
205 }
206 
207 void LatexPrinter::bvisit(const Equality &x)
208 {
210  s << apply(x.get_arg1()) << " = " << apply(x.get_arg2());
211  str_ = s.str();
212 }
213 
214 void LatexPrinter::bvisit(const Unequality &x)
215 {
217  s << apply(x.get_arg1()) << " \\neq " << apply(x.get_arg2());
218  str_ = s.str();
219 }
220 
221 void LatexPrinter::bvisit(const LessThan &x)
222 {
224  s << apply(x.get_arg1()) << " \\leq " << apply(x.get_arg2());
225  str_ = s.str();
226 }
227 
228 void LatexPrinter::bvisit(const StrictLessThan &x)
229 {
231  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
232  str_ = s.str();
233 }
234 
235 std::string latex(const DenseMatrix &m, const unsigned max_rows,
236  const unsigned max_cols)
237 {
238  const unsigned int nrows = m.nrows();
239  const unsigned int ncols = m.ncols();
240  unsigned int nrows_display = nrows;
241  if (nrows > max_rows)
242  nrows_display = max_rows - 1;
243  unsigned int ncols_display = ncols;
244  if (ncols > max_cols)
245  ncols_display = max_cols - 1;
246 
248  s << "\\left[\\begin{matrix}" << std::endl;
249 
250  std::string end_of_line = " \\\\\n";
251  if (ncols_display < ncols) {
252  end_of_line = " & \\cdots" + end_of_line;
253  }
254  for (unsigned int row_index = 0; row_index < nrows_display; row_index++) {
255  for (unsigned int column_index = 0; column_index < ncols_display;
256  column_index++) {
257  RCP<const Basic> v = m.get(row_index, column_index);
258 
259  if (v.is_null()) {
260  // element has not been initalized
261  throw SymEngineException(
262  "cannot display uninitialized element");
263  } else {
264  s << latex(*v);
265  }
266  if (column_index < ncols_display - 1)
267  s << " & ";
268  }
269  s << end_of_line;
270  }
271  if (nrows_display < nrows) {
272  for (unsigned int column_index = 0; column_index < ncols_display;
273  column_index++) {
274  s << "\\vdots";
275  if (column_index < ncols_display - 1)
276  s << " & ";
277  }
278  s << end_of_line;
279  }
280  s << "\\end{matrix}\\right]\n";
281 
282  return s.str();
283 }
284 
285 void LatexPrinter::bvisit(const Interval &x)
286 {
288  if (x.get_left_open())
289  s << "\\left(";
290  else
291  s << "\\left[";
292  s << *x.get_start() << ", " << *x.get_end();
293  if (x.get_right_open())
294  s << "\\right)";
295  else
296  s << "\\right]";
297  str_ = s.str();
298 }
299 
300 void LatexPrinter::bvisit(const BooleanAtom &x)
301 {
302  if (x.get_val()) {
303  str_ = "\\mathrm{True}";
304  } else {
305  str_ = "\\mathrm{False}";
306  }
307 }
308 
309 void LatexPrinter::bvisit(const And &x)
310 {
312  auto container = x.get_container();
313  if (is_a<Or>(**container.begin()) or is_a<Xor>(**container.begin())) {
314  s << parenthesize(apply(*container.begin()));
315  } else {
316  s << apply(*container.begin());
317  }
318 
319  for (auto it = ++(container.begin()); it != container.end(); ++it) {
320  s << " \\wedge ";
321  if (is_a<Or>(**it) or is_a<Xor>(**it)) {
322  s << parenthesize(apply(*it));
323  } else {
324  s << apply(*it);
325  }
326  }
327  str_ = s.str();
328 }
329 
330 void LatexPrinter::bvisit(const Or &x)
331 {
333  auto container = x.get_container();
334  if (is_a<And>(**container.begin()) or is_a<Xor>(**container.begin())) {
335  s << parenthesize(apply(*container.begin()));
336  } else {
337  s << apply(*container.begin());
338  }
339 
340  for (auto it = ++(container.begin()); it != container.end(); ++it) {
341  s << " \\vee ";
342  if (is_a<And>(**it) or is_a<Xor>(**it)) {
343  s << parenthesize(apply(*it));
344  } else {
345  s << apply(*it);
346  }
347  }
348  str_ = s.str();
349 }
350 
351 void LatexPrinter::bvisit(const Xor &x)
352 {
354  auto container = x.get_container();
355  if (is_a<Or>(**container.begin()) or is_a<And>(**container.begin())) {
356  s << parenthesize(apply(*container.begin()));
357  } else {
358  s << apply(*container.begin());
359  }
360 
361  for (auto it = ++(container.begin()); it != container.end(); ++it) {
362  s << " \\veebar ";
363  if (is_a<Or>(**it) or is_a<And>(**it)) {
364  s << parenthesize(apply(*it));
365  } else {
366  s << apply(*it);
367  }
368  }
369  str_ = s.str();
370 }
371 
372 void LatexPrinter::print_with_args(const Basic &x, const std::string &join,
374 {
375  vec_basic v = x.get_args();
376  s << apply(*v.begin());
377 
378  for (auto it = ++(v.begin()); it != v.end(); ++it) {
379  s << " " << join << " " << apply(*it);
380  }
381 }
382 
383 void LatexPrinter::bvisit(const Not &x)
384 {
385  str_ = "\\neg " + apply(*x.get_arg());
386 }
387 
388 void LatexPrinter::bvisit(const Union &x)
389 {
391  print_with_args(x, "\\cup", s);
392  str_ = s.str();
393 }
394 
395 void LatexPrinter::bvisit(const Intersection &x)
396 {
398  print_with_args(x, "\\cap", s);
399  str_ = s.str();
400 }
401 
402 void LatexPrinter::bvisit(const Complement &x)
403 {
405  s << apply(x.get_universe()) << " \\setminus " << apply(x.get_container());
406  str_ = s.str();
407 }
408 
409 void LatexPrinter::bvisit(const ImageSet &x)
410 {
412  s << "\\left\\{" << apply(*x.get_expr()) << "\\; |\\; ";
413  s << apply(*x.get_symbol());
414  s << " \\in " << apply(*x.get_baseset()) << "\\right\\}";
415  str_ = s.str();
416 }
417 
418 void LatexPrinter::bvisit(const ConditionSet &x)
419 {
421  s << "\\left\\{" << apply(*x.get_symbol()) << "\\; |\\; ";
422  s << apply(x.get_condition()) << "\\right\\}";
423  str_ = s.str();
424 }
425 
426 void LatexPrinter::bvisit(const EmptySet &x)
427 {
428  str_ = "\\emptyset";
429 }
430 
431 void LatexPrinter::bvisit(const Complexes &x)
432 {
433  str_ = "\\mathbb{C}";
434 }
435 
436 void LatexPrinter::bvisit(const Reals &x)
437 {
438  str_ = "\\mathbb{R}";
439 }
440 
441 void LatexPrinter::bvisit(const Rationals &x)
442 {
443  str_ = "\\mathbb{Q}";
444 }
445 
446 void LatexPrinter::bvisit(const Integers &x)
447 {
448  str_ = "\\mathbb{Z}";
449 }
450 
451 void LatexPrinter::bvisit(const Naturals &x)
452 {
453  str_ = "\\mathbb{N}";
454 }
455 
456 void LatexPrinter::bvisit(const Naturals0 &x)
457 {
458  str_ = "\\mathbb{N}_0";
459 }
460 
461 void LatexPrinter::bvisit(const FiniteSet &x)
462 {
464  s << "\\left{";
465  print_with_args(x, ",", s);
466  s << "\\right}";
467  str_ = s.str();
468 }
469 
470 void LatexPrinter::bvisit(const Contains &x)
471 {
473  s << apply(x.get_expr()) << " \\in " << apply(x.get_set());
474  str_ = s.str();
475 }
476 
477 std::string LatexPrinter::print_mul()
478 {
479  return " ";
480 }
481 
482 bool LatexPrinter::split_mul_coef()
483 {
484  return true;
485 }
486 
487 std::vector<std::string> init_latex_printer_names()
488 {
489  std::vector<std::string> names = init_str_printer_names();
490 
491  for (unsigned i = 0; i < names.size(); i++) {
492  if (names[i] != "") {
493  names[i] = "\\operatorname{" + names[i] + "}";
494  }
495  }
496  names[SYMENGINE_SIN] = "\\sin";
497  names[SYMENGINE_COS] = "\\cos";
498  names[SYMENGINE_TAN] = "\\tan";
499  names[SYMENGINE_COT] = "\\cot";
500  names[SYMENGINE_CSC] = "\\csc";
501  names[SYMENGINE_SEC] = "\\sec";
502  names[SYMENGINE_ATAN2] = "\\operatorname{atan_2}";
503  names[SYMENGINE_SINH] = "\\sinh";
504  names[SYMENGINE_COSH] = "\\cosh";
505  names[SYMENGINE_TANH] = "\\tanh";
506  names[SYMENGINE_COTH] = "\\coth";
507  names[SYMENGINE_LOG] = "\\log";
508  names[SYMENGINE_ZETA] = "\\zeta";
509  names[SYMENGINE_LAMBERTW] = "\\operatorname{W}";
510  names[SYMENGINE_DIRICHLET_ETA] = "\\eta";
511  names[SYMENGINE_KRONECKERDELTA] = "\\delta_";
512  names[SYMENGINE_LEVICIVITA] = "\\varepsilon_";
513  names[SYMENGINE_LOWERGAMMA] = "\\gamma";
514  names[SYMENGINE_UPPERGAMMA] = "\\Gamma";
515  names[SYMENGINE_BETA] = "\\operatorname{B}";
516  names[SYMENGINE_LOG] = "\\log";
517  names[SYMENGINE_GAMMA] = "\\Gamma";
518  names[SYMENGINE_TRUNCATE] = "\\operatorname{truncate}";
519  names[SYMENGINE_PRIMEPI] = "\\pi";
520  return names;
521 }
522 
523 void LatexPrinter::bvisit(const Function &x)
524 {
525  static const std::vector<std::string> names_ = init_latex_printer_names();
527  o << names_[x.get_type_code()] << "{";
528  vec_basic vec = x.get_args();
529  o << parenthesize(apply(vec)) << "}";
530  str_ = o.str();
531 }
532 
533 void LatexPrinter::bvisit(const Floor &x)
534 {
536  o << "\\lfloor{" << apply(x.get_arg()) << "}\\rfloor";
537  str_ = o.str();
538 }
539 
540 void LatexPrinter::bvisit(const Ceiling &x)
541 {
543  o << "\\lceil{" << apply(x.get_arg()) << "}\\rceil";
544  str_ = o.str();
545 }
546 
547 void LatexPrinter::bvisit(const Abs &x)
548 {
550  o << "\\left|" << apply(x.get_arg()) << "}\\right|";
551  str_ = o.str();
552 }
553 
554 std::string LatexPrinter::parenthesize(const std::string &expr)
555 {
556  return "\\left(" + expr + "\\right)";
557 }
558 
559 void LatexPrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
560  const RCP<const Basic> &b)
561 {
562  if (eq(*a, *E)) {
563  o << "e^{" << apply(b) << "}";
564  } else if (eq(*b, *rational(1, 2))) {
565  o << "\\sqrt{" << apply(a) << "}";
566  } else if (is_a<Rational>(*b)
567  and eq(*static_cast<const Rational &>(*b).get_num(), *one)) {
568  o << "\\sqrt[" << apply(static_cast<const Rational &>(*b).get_den())
569  << "]{" << apply(a) << "}";
570  } else {
571  o << parenthesizeLE(a, PrecedenceEnum::Pow);
572  Precedence prec;
573  auto b_str = apply(b);
574  if (b_str.size() > 1) {
575  o << "^{" << b_str << "}";
576  } else {
577  o << "^" << b_str;
578  }
579  }
580 }
581 
582 std::string LatexPrinter::print_div(const std::string &num,
583  const std::string &den, bool paren)
584 {
585  return "\\frac{" + num + "}{" + den + "}";
586 }
587 
588 void LatexPrinter::bvisit(const Piecewise &x)
589 {
591  s << "\\begin{cases} ";
592  const auto &vec = x.get_vec();
593  auto it = vec.begin();
594  auto it_last = --vec.end();
595  while (it != vec.end()) {
596  s << apply(it->first);
597  if (it == it_last) {
598  if (eq(*it->second, *boolTrue)) {
599  s << " & \\text{otherwise} \\end{cases}";
600  } else {
601  s << " & \\text{for}\\: ";
602  s << apply(it->second);
603  s << " \\end{cases}";
604  }
605  } else {
606  s << " & \\text{for}\\: ";
607  s << apply(it->second);
608  s << "\\\\";
609  }
610  it++;
611  }
612  str_ = s.str();
613 }
614 
615 void LatexPrinter::bvisit(const Tuple &x)
616 {
618  vec_basic vec = x.get_args();
619  o << parenthesize(apply(vec));
620  str_ = o.str();
621 }
622 
623 } // namespace SymEngine
The base class for SymEngine.
T count(T... args)
T endl(T... args)
T find(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
T prev(T... args)
T size(T... args)
T str(T... args)
T substr(T... args)