latex.cpp
1 #include <symengine/printers/latex.h>
2 #include <symengine/printers.h>
3 #include <symengine/basic.h>
4 
5 namespace SymEngine
6 {
7 
8 std::string latex(const Basic &x)
9 {
10  LatexPrinter p;
11  return p.apply(x);
12 }
13 
14 void print_rational_class(const rational_class &r, std::ostringstream &s)
15 {
16  if (get_den(r) == 1) {
17  s << get_num(r);
18  } else {
19  s << "\\frac{" << get_num(r) << "}{" << get_den(r) << "}";
20  }
21 }
22 
23 void LatexPrinter::bvisit(const Symbol &x)
24 {
25  std::string name = x.get_name();
26 
27  if (name.find('\\') != std::string::npos
28  or name.find('{') != std::string::npos) {
29  str_ = name;
30  return;
31  }
32  if (name[0] == '_') {
33  name = name.substr(1, name.size());
34  }
36  = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta", "epsilon",
37  "zeta", "eta", "theta", "Theta", "iota", "kappa", "lambda",
38  "Lambda", "mu", "nu", "xi", "omicron", "pi", "Pi",
39  "rho", "sigma", "Sigma", "tau", "upsilon", "Upsilon", "phi",
40  "Phi", "chi", "psi", "Psi", "omega", "Omega"};
41 
42  str_ = name;
43  for (auto &letter : greeks) {
44  if (name == letter) {
45  str_ = "\\" + name;
46  break;
47  }
48  if (name.size() > letter.size() and name.find(letter + "_") == 0) {
49  str_ = "\\" + name;
50  break;
51  }
52  }
53 
54  if (name.find("_") != std::string::npos) {
55  int count = 0;
56  std::string prev = str_;
57  str_ = "";
58  for (size_t i = 0; i < prev.size(); i++) {
59  const char &c = prev[i];
60  if (c == '_' and prev.size() > i + 2) {
61  str_ += "_{";
62  count++;
63  } else {
64  str_ += c;
65  }
66  }
67  str_ += std::string(count, '}');
68  }
69  return;
70 }
71 
72 void LatexPrinter::bvisit(const Rational &x)
73 {
74  const auto &rational = x.as_rational_class();
76  print_rational_class(rational, s);
77  str_ = s.str();
78 }
79 
80 void LatexPrinter::bvisit(const Complex &x)
81 {
83  if (x.real_ != 0) {
84  print_rational_class(x.real_, s);
85  // Since Complex is in canonical form, imaginary_ is not 0.
86  if (mp_sign(x.imaginary_) == 1) {
87  s << " + ";
88  } else {
89  s << " - ";
90  }
91  // If imaginary_ is not 1 or -1, print the absolute value
92  if (x.imaginary_ != mp_sign(x.imaginary_)) {
93  print_rational_class(mp_abs(x.imaginary_), s);
94  s << "j";
95  } else {
96  s << "j";
97  }
98  } else {
99  if (x.imaginary_ != mp_sign(x.imaginary_)) {
100  print_rational_class(x.imaginary_, s);
101  s << "j";
102  } else {
103  if (mp_sign(x.imaginary_) == 1) {
104  s << "j";
105  } else {
106  s << "-j";
107  }
108  }
109  }
110  str_ = s.str();
111 }
112 
113 void LatexPrinter::bvisit(const ComplexBase &x)
114 {
115  RCP<const Number> imag = x.imaginary_part();
116  if (imag->is_negative()) {
117  std::string str = apply(imag);
118  str = str.substr(1, str.length() - 1);
119  str_ = apply(x.real_part()) + " - " + str + "j";
120  } else {
121  str_ = apply(x.real_part()) + " + " + apply(imag) + "j";
122  }
123 }
124 void LatexPrinter::bvisit(const ComplexDouble &x)
125 {
126  bvisit(static_cast<const ComplexBase &>(x));
127 }
128 
129 #ifdef HAVE_SYMENGINE_MPC
130 void LatexPrinter::bvisit(const ComplexMPC &x)
131 {
132  bvisit(static_cast<const ComplexBase &>(x));
133 }
134 #endif
135 
136 void LatexPrinter::bvisit(const Infty &x)
137 {
138  if (x.is_negative_infinity()) {
139  str_ = "-\\infty";
140  } else if (x.is_positive_infinity()) {
141  str_ = "\\infty";
142  } else {
143  str_ = "\\tilde{\\infty}";
144  }
145 }
146 
147 void LatexPrinter::bvisit(const NaN &x)
148 {
149  str_ = "\\mathrm{NaN}";
150 }
151 
152 void LatexPrinter::bvisit(const Constant &x)
153 {
154  if (eq(x, *pi)) {
155  str_ = "\\pi";
156  } else if (eq(x, *E)) {
157  str_ = "e";
158  } else if (eq(x, *EulerGamma)) {
159  str_ = "\\gamma";
160  } else if (eq(x, *Catalan)) {
161  str_ = "G";
162  } else if (eq(x, *GoldenRatio)) {
163  str_ = "\\phi";
164  } else {
165  throw NotImplementedError("Constant " + x.get_name()
166  + " is not implemented.");
167  }
168 }
169 
170 void LatexPrinter::bvisit(const Derivative &x)
171 {
172  const auto &symbols = x.get_symbols();
174  if (symbols.size() == 1) {
175  if (free_symbols(*x.get_arg()).size() == 1) {
176  s << "\\frac{d}{d " << apply(*symbols.begin());
177  } else {
178  s << "\\frac{\\partial}{\\partial " << apply(*symbols.begin());
179  }
180  } else {
181  s << "\\frac{\\partial^" << symbols.size() << "}{";
182  unsigned count = 1;
183  auto it = symbols.begin();
184  RCP<const Basic> prev = *it;
185  ++it;
186  for (; it != symbols.end(); ++it) {
187  if (neq(*prev, **it)) {
188  if (count == 1) {
189  s << "\\partial " << apply(*prev) << " ";
190  } else {
191  s << "\\partial " << apply(*prev) << "^" << count << " ";
192  }
193  count = 1;
194  } else {
195  count++;
196  }
197  prev = *it;
198  }
199  if (count == 1) {
200  s << "\\partial " << apply(*prev) << " ";
201  } else {
202  s << "\\partial " << apply(*prev) << "^" << count << " ";
203  }
204  }
205  s << "} " << apply(x.get_arg());
206  str_ = s.str();
207 }
208 
209 void LatexPrinter::bvisit(const Subs &x)
210 {
212  o << "\\left. " << apply(x.get_arg()) << "\\right|_{\\substack{";
213  for (auto p = x.get_dict().begin(); p != x.get_dict().end(); p++) {
214  if (p != x.get_dict().begin()) {
215  o << " \\\\ ";
216  }
217  o << apply(p->first) << "=" << apply(p->second);
218  }
219  o << "}}";
220  str_ = o.str();
221 }
222 
223 void LatexPrinter::bvisit(const Equality &x)
224 {
226  s << apply(x.get_arg1()) << " = " << apply(x.get_arg2());
227  str_ = s.str();
228 }
229 
230 void LatexPrinter::bvisit(const Unequality &x)
231 {
233  s << apply(x.get_arg1()) << " \\neq " << apply(x.get_arg2());
234  str_ = s.str();
235 }
236 
237 void LatexPrinter::bvisit(const LessThan &x)
238 {
240  s << apply(x.get_arg1()) << " \\leq " << apply(x.get_arg2());
241  str_ = s.str();
242 }
243 
244 void LatexPrinter::bvisit(const StrictLessThan &x)
245 {
247  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
248  str_ = s.str();
249 }
250 
251 std::string latex(const DenseMatrix &m, const unsigned max_rows,
252  const unsigned max_cols)
253 {
254  const unsigned int nrows = m.nrows();
255  const unsigned int ncols = m.ncols();
256  unsigned int nrows_display = nrows;
257  if (nrows > max_rows)
258  nrows_display = max_rows - 1;
259  unsigned int ncols_display = ncols;
260  if (ncols > max_cols)
261  ncols_display = max_cols - 1;
262 
264  s << "\\left[\\begin{matrix}" << std::endl;
265 
266  std::string end_of_line = " \\\\\n";
267  if (ncols_display < ncols) {
268  end_of_line = " & \\cdots" + end_of_line;
269  }
270  for (unsigned int row_index = 0; row_index < nrows_display; row_index++) {
271  for (unsigned int column_index = 0; column_index < ncols_display;
272  column_index++) {
273  RCP<const Basic> v = m.get(row_index, column_index);
274 
275  if (v.is_null()) {
276  // element has not been initalized
277  throw SymEngineException(
278  "cannot display uninitialized element");
279  } else {
280  s << latex(*v);
281  }
282  if (column_index < ncols_display - 1)
283  s << " & ";
284  }
285  s << end_of_line;
286  }
287  if (nrows_display < nrows) {
288  for (unsigned int column_index = 0; column_index < ncols_display;
289  column_index++) {
290  s << "\\vdots";
291  if (column_index < ncols_display - 1)
292  s << " & ";
293  }
294  s << end_of_line;
295  }
296  s << "\\end{matrix}\\right]\n";
297 
298  return s.str();
299 }
300 
301 void LatexPrinter::bvisit(const Interval &x)
302 {
304  if (x.get_left_open())
305  s << "\\left(";
306  else
307  s << "\\left[";
308  s << *x.get_start() << ", " << *x.get_end();
309  if (x.get_right_open())
310  s << "\\right)";
311  else
312  s << "\\right]";
313  str_ = s.str();
314 }
315 
316 void LatexPrinter::bvisit(const BooleanAtom &x)
317 {
318  if (x.get_val()) {
319  str_ = "\\mathrm{True}";
320  } else {
321  str_ = "\\mathrm{False}";
322  }
323 }
324 
325 void LatexPrinter::bvisit(const And &x)
326 {
328  auto container = x.get_container();
329  if (is_a<Or>(**container.begin()) or is_a<Xor>(**container.begin())) {
330  s << parenthesize(apply(*container.begin()));
331  } else {
332  s << apply(*container.begin());
333  }
334 
335  for (auto it = ++(container.begin()); it != container.end(); ++it) {
336  s << " \\wedge ";
337  if (is_a<Or>(**it) or is_a<Xor>(**it)) {
338  s << parenthesize(apply(*it));
339  } else {
340  s << apply(*it);
341  }
342  }
343  str_ = s.str();
344 }
345 
346 void LatexPrinter::bvisit(const Or &x)
347 {
349  auto container = x.get_container();
350  if (is_a<And>(**container.begin()) or is_a<Xor>(**container.begin())) {
351  s << parenthesize(apply(*container.begin()));
352  } else {
353  s << apply(*container.begin());
354  }
355 
356  for (auto it = ++(container.begin()); it != container.end(); ++it) {
357  s << " \\vee ";
358  if (is_a<And>(**it) or is_a<Xor>(**it)) {
359  s << parenthesize(apply(*it));
360  } else {
361  s << apply(*it);
362  }
363  }
364  str_ = s.str();
365 }
366 
367 void LatexPrinter::bvisit(const Xor &x)
368 {
370  auto container = x.get_container();
371  if (is_a<Or>(**container.begin()) or is_a<And>(**container.begin())) {
372  s << parenthesize(apply(*container.begin()));
373  } else {
374  s << apply(*container.begin());
375  }
376 
377  for (auto it = ++(container.begin()); it != container.end(); ++it) {
378  s << " \\veebar ";
379  if (is_a<Or>(**it) or is_a<And>(**it)) {
380  s << parenthesize(apply(*it));
381  } else {
382  s << apply(*it);
383  }
384  }
385  str_ = s.str();
386 }
387 
388 void LatexPrinter::print_with_args(const Basic &x, const std::string &join,
390 {
391  vec_basic v = x.get_args();
392  s << apply(*v.begin());
393 
394  for (auto it = ++(v.begin()); it != v.end(); ++it) {
395  s << " " << join << " " << apply(*it);
396  }
397 }
398 
399 void LatexPrinter::bvisit(const Not &x)
400 {
401  str_ = "\\neg " + apply(*x.get_arg());
402 }
403 
404 void LatexPrinter::bvisit(const Union &x)
405 {
407  print_with_args(x, "\\cup", s);
408  str_ = s.str();
409 }
410 
411 void LatexPrinter::bvisit(const Intersection &x)
412 {
414  print_with_args(x, "\\cap", s);
415  str_ = s.str();
416 }
417 
418 void LatexPrinter::bvisit(const Complement &x)
419 {
421  s << apply(x.get_universe()) << " \\setminus " << apply(x.get_container());
422  str_ = s.str();
423 }
424 
425 void LatexPrinter::bvisit(const ImageSet &x)
426 {
428  s << "\\left\\{" << apply(*x.get_expr()) << "\\; |\\; ";
429  s << apply(*x.get_symbol());
430  s << " \\in " << apply(*x.get_baseset()) << "\\right\\}";
431  str_ = s.str();
432 }
433 
434 void LatexPrinter::bvisit(const ConditionSet &x)
435 {
437  s << "\\left\\{" << apply(*x.get_symbol()) << "\\; |\\; ";
438  s << apply(x.get_condition()) << "\\right\\}";
439  str_ = s.str();
440 }
441 
442 void LatexPrinter::bvisit(const EmptySet &x)
443 {
444  str_ = "\\emptyset";
445 }
446 
447 void LatexPrinter::bvisit(const Complexes &x)
448 {
449  str_ = "\\mathbb{C}";
450 }
451 
452 void LatexPrinter::bvisit(const Reals &x)
453 {
454  str_ = "\\mathbb{R}";
455 }
456 
457 void LatexPrinter::bvisit(const Rationals &x)
458 {
459  str_ = "\\mathbb{Q}";
460 }
461 
462 void LatexPrinter::bvisit(const Integers &x)
463 {
464  str_ = "\\mathbb{Z}";
465 }
466 
467 void LatexPrinter::bvisit(const Naturals &x)
468 {
469  str_ = "\\mathbb{N}";
470 }
471 
472 void LatexPrinter::bvisit(const Naturals0 &x)
473 {
474  str_ = "\\mathbb{N}_0";
475 }
476 
477 void LatexPrinter::bvisit(const FiniteSet &x)
478 {
480  s << "\\left{";
481  print_with_args(x, ",", s);
482  s << "\\right}";
483  str_ = s.str();
484 }
485 
486 void LatexPrinter::bvisit(const Contains &x)
487 {
489  s << apply(x.get_expr()) << " \\in " << apply(x.get_set());
490  str_ = s.str();
491 }
492 
493 std::string LatexPrinter::print_mul()
494 {
495  return " ";
496 }
497 
498 bool LatexPrinter::split_mul_coef()
499 {
500  return true;
501 }
502 
503 std::vector<std::string> init_latex_printer_names()
504 {
505  std::vector<std::string> names = init_str_printer_names();
506 
507  for (unsigned i = 0; i < names.size(); i++) {
508  if (names[i] != "") {
509  names[i] = "\\operatorname{" + names[i] + "}";
510  }
511  }
512  names[SYMENGINE_SIN] = "\\sin";
513  names[SYMENGINE_COS] = "\\cos";
514  names[SYMENGINE_TAN] = "\\tan";
515  names[SYMENGINE_COT] = "\\cot";
516  names[SYMENGINE_CSC] = "\\csc";
517  names[SYMENGINE_SEC] = "\\sec";
518  names[SYMENGINE_ATAN2] = "\\operatorname{atan_2}";
519  names[SYMENGINE_SINH] = "\\sinh";
520  names[SYMENGINE_COSH] = "\\cosh";
521  names[SYMENGINE_TANH] = "\\tanh";
522  names[SYMENGINE_COTH] = "\\coth";
523  names[SYMENGINE_LOG] = "\\log";
524  names[SYMENGINE_ZETA] = "\\zeta";
525  names[SYMENGINE_LAMBERTW] = "\\operatorname{W}";
526  names[SYMENGINE_DIRICHLET_ETA] = "\\eta";
527  names[SYMENGINE_KRONECKERDELTA] = "\\delta_";
528  names[SYMENGINE_LEVICIVITA] = "\\varepsilon_";
529  names[SYMENGINE_LOWERGAMMA] = "\\gamma";
530  names[SYMENGINE_UPPERGAMMA] = "\\Gamma";
531  names[SYMENGINE_BETA] = "\\operatorname{B}";
532  names[SYMENGINE_LOG] = "\\log";
533  names[SYMENGINE_GAMMA] = "\\Gamma";
534  names[SYMENGINE_TRUNCATE] = "\\operatorname{truncate}";
535  names[SYMENGINE_PRIMEPI] = "\\pi";
536  return names;
537 }
538 
539 void LatexPrinter::bvisit(const Function &x)
540 {
541  static const std::vector<std::string> names_ = init_latex_printer_names();
543  o << names_[x.get_type_code()] << "{";
544  vec_basic vec = x.get_args();
545  o << parenthesize(apply(vec)) << "}";
546  str_ = o.str();
547 }
548 
549 void LatexPrinter::bvisit(const Floor &x)
550 {
552  o << "\\lfloor{" << apply(x.get_arg()) << "}\\rfloor";
553  str_ = o.str();
554 }
555 
556 void LatexPrinter::bvisit(const Ceiling &x)
557 {
559  o << "\\lceil{" << apply(x.get_arg()) << "}\\rceil";
560  str_ = o.str();
561 }
562 
563 void LatexPrinter::bvisit(const Abs &x)
564 {
566  o << "\\left|" << apply(x.get_arg()) << "\\right|";
567  str_ = o.str();
568 }
569 
570 std::string LatexPrinter::parenthesize(const std::string &expr)
571 {
572  return "\\left(" + expr + "\\right)";
573 }
574 
575 void LatexPrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
576  const RCP<const Basic> &b)
577 {
578  if (eq(*a, *E)) {
579  o << "e^{" << apply(b) << "}";
580  } else if (eq(*b, *rational(1, 2))) {
581  o << "\\sqrt{" << apply(a) << "}";
582  } else if (is_a<Rational>(*b)
583  and eq(*static_cast<const Rational &>(*b).get_num(), *one)) {
584  o << "\\sqrt[" << apply(static_cast<const Rational &>(*b).get_den())
585  << "]{" << apply(a) << "}";
586  } else {
587  o << parenthesizeLE(a, PrecedenceEnum::Pow);
588  Precedence prec;
589  auto b_str = apply(b);
590  if (b_str.size() > 1) {
591  o << "^{" << b_str << "}";
592  } else {
593  o << "^" << b_str;
594  }
595  }
596 }
597 
598 std::string LatexPrinter::print_div(const std::string &num,
599  const std::string &den, bool paren)
600 {
601  return "\\frac{" + num + "}{" + den + "}";
602 }
603 
604 void LatexPrinter::bvisit(const Piecewise &x)
605 {
607  s << "\\begin{cases} ";
608  const auto &vec = x.get_vec();
609  auto it = vec.begin();
610  auto it_last = --vec.end();
611  while (it != vec.end()) {
612  s << apply(it->first);
613  if (it == it_last) {
614  if (eq(*it->second, *boolTrue)) {
615  s << " & \\text{otherwise} \\end{cases}";
616  } else {
617  s << " & \\text{for}\\: ";
618  s << apply(it->second);
619  s << " \\end{cases}";
620  }
621  } else {
622  s << " & \\text{for}\\: ";
623  s << apply(it->second);
624  s << "\\\\";
625  }
626  it++;
627  }
628  str_ = s.str();
629 }
630 
631 void LatexPrinter::bvisit(const Tuple &x)
632 {
634  vec_basic vec = x.get_args();
635  o << parenthesize(apply(vec));
636  str_ = o.str();
637 }
638 
639 } // namespace SymEngine
The base class for SymEngine.
T count(T... args)
T endl(T... args)
T find(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
T prev(T... args)
T size(T... args)
T str(T... args)
T substr(T... args)