latex.cpp
1 #include <symengine/printers/latex.h>
2 #include <symengine/printers.h>
3 #include <symengine/basic.h>
4 
5 namespace SymEngine
6 {
7 
8 std::string latex(const Basic &x)
9 {
10  LatexPrinter p;
11  return p.apply(x);
12 }
13 
14 void print_rational_class(const rational_class &r, std::ostringstream &s)
15 {
16  if (get_den(r) == 1) {
17  s << get_num(r);
18  } else {
19  s << "\\frac{" << get_num(r) << "}{" << get_den(r) << "}";
20  }
21 }
22 
23 void LatexPrinter::bvisit(const Symbol &x)
24 {
25  str_ = _print_symbol(x.get_name());
26 }
27 
28 std::string LatexPrinter::_print_symbol(const std::string &name)
29 {
30  if (name.find('\\') != std::string::npos
31  or name.find('{') != std::string::npos) {
32  return name;
33  }
34  if (name[0] == '_') {
35  return _print_symbol(name.substr(1, name.size() - 1));
36  }
37  std::vector<std::string> greeks
38  = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta", "epsilon",
39  "zeta", "eta", "theta", "Theta", "iota", "kappa", "lambda",
40  "Lambda", "mu", "nu", "xi", "omicron", "pi", "Pi",
41  "rho", "sigma", "Sigma", "tau", "upsilon", "Upsilon", "phi",
42  "Phi", "chi", "psi", "Psi", "omega", "Omega"};
43 
44  for (auto &letter : greeks) {
45  if (name == letter) {
46  return "\\" + name;
47  }
48  }
49 
50  size_t idx = name.find("_");
51  if (idx == std::string::npos || idx == name.size() - 1) {
52  return name;
53  } else if (idx < name.size() - 2) {
54  return _print_symbol(name.substr(0, idx)) + "_{"
55  + _print_symbol(name.substr(idx + 1, name.size() - idx - 1))
56  + "}";
57  } else {
58  return _print_symbol(name.substr(0, idx)) + "_"
59  + name.substr(idx + 1, name.size() - idx - 1);
60  }
61 }
62 
63 void LatexPrinter::bvisit(const Rational &x)
64 {
65  const auto &rational = x.as_rational_class();
66  std::ostringstream s;
67  print_rational_class(rational, s);
68  str_ = s.str();
69 }
70 
71 void LatexPrinter::bvisit(const Complex &x)
72 {
73  std::ostringstream s;
74  if (x.real_ != 0) {
75  print_rational_class(x.real_, s);
76  // Since Complex is in canonical form, imaginary_ is not 0.
77  if (mp_sign(x.imaginary_) == 1) {
78  s << " + ";
79  } else {
80  s << " - ";
81  }
82  // If imaginary_ is not 1 or -1, print the absolute value
83  if (x.imaginary_ != mp_sign(x.imaginary_)) {
84  print_rational_class(mp_abs(x.imaginary_), s);
85  s << "j";
86  } else {
87  s << "j";
88  }
89  } else {
90  if (x.imaginary_ != mp_sign(x.imaginary_)) {
91  print_rational_class(x.imaginary_, s);
92  s << "j";
93  } else {
94  if (mp_sign(x.imaginary_) == 1) {
95  s << "j";
96  } else {
97  s << "-j";
98  }
99  }
100  }
101  str_ = s.str();
102 }
103 
104 void LatexPrinter::bvisit(const ComplexBase &x)
105 {
106  RCP<const Number> imag = x.imaginary_part();
107  if (imag->is_negative()) {
108  std::string str = apply(imag);
109  str = str.substr(1, str.length() - 1);
110  str_ = apply(x.real_part()) + " - " + str + "j";
111  } else {
112  str_ = apply(x.real_part()) + " + " + apply(imag) + "j";
113  }
114 }
115 void LatexPrinter::bvisit(const ComplexDouble &x)
116 {
117  bvisit(static_cast<const ComplexBase &>(x));
118 }
119 
#ifdef HAVE_SYMENGINE_MPC
// ComplexMPC is printed by the generic ComplexBase routine.
void LatexPrinter::bvisit(const ComplexMPC &x)
{
    const ComplexBase &base = x;
    bvisit(base);
}
#endif
126 
127 void LatexPrinter::bvisit(const Infty &x)
128 {
129  if (x.is_negative_infinity()) {
130  str_ = "-\\infty";
131  } else if (x.is_positive_infinity()) {
132  str_ = "\\infty";
133  } else {
134  str_ = "\\tilde{\\infty}";
135  }
136 }
137 
138 void LatexPrinter::bvisit(const NaN &x)
139 {
140  str_ = "\\mathrm{NaN}";
141 }
142 
143 void LatexPrinter::bvisit(const Constant &x)
144 {
145  if (eq(x, *pi)) {
146  str_ = "\\pi";
147  } else if (eq(x, *E)) {
148  str_ = "e";
149  } else if (eq(x, *EulerGamma)) {
150  str_ = "\\gamma";
151  } else if (eq(x, *Catalan)) {
152  str_ = "G";
153  } else if (eq(x, *GoldenRatio)) {
154  str_ = "\\phi";
155  } else {
156  throw NotImplementedError("Constant " + x.get_name()
157  + " is not implemented.");
158  }
159 }
160 
161 void LatexPrinter::bvisit(const Derivative &x)
162 {
163  const auto &symbols = x.get_symbols();
164  std::ostringstream s;
165  if (symbols.size() == 1) {
166  if (free_symbols(*x.get_arg()).size() == 1) {
167  s << "\\frac{d}{d " << apply(*symbols.begin());
168  } else {
169  s << "\\frac{\\partial}{\\partial " << apply(*symbols.begin());
170  }
171  } else {
172  s << "\\frac{\\partial^" << symbols.size() << "}{";
173  unsigned count = 1;
174  auto it = symbols.begin();
175  RCP<const Basic> prev = *it;
176  ++it;
177  for (; it != symbols.end(); ++it) {
178  if (neq(*prev, **it)) {
179  if (count == 1) {
180  s << "\\partial " << apply(*prev) << " ";
181  } else {
182  s << "\\partial " << apply(*prev) << "^" << count << " ";
183  }
184  count = 1;
185  } else {
186  count++;
187  }
188  prev = *it;
189  }
190  if (count == 1) {
191  s << "\\partial " << apply(*prev) << " ";
192  } else {
193  s << "\\partial " << apply(*prev) << "^" << count << " ";
194  }
195  }
196  s << "} " << apply(x.get_arg());
197  str_ = s.str();
198 }
199 
200 void LatexPrinter::bvisit(const Subs &x)
201 {
202  std::ostringstream o;
203  o << "\\left. " << apply(x.get_arg()) << "\\right|_{\\substack{";
204  for (auto p = x.get_dict().begin(); p != x.get_dict().end(); p++) {
205  if (p != x.get_dict().begin()) {
206  o << " \\\\ ";
207  }
208  o << apply(p->first) << "=" << apply(p->second);
209  }
210  o << "}}";
211  str_ = o.str();
212 }
213 
214 void LatexPrinter::bvisit(const Equality &x)
215 {
216  std::ostringstream s;
217  s << apply(x.get_arg1()) << " = " << apply(x.get_arg2());
218  str_ = s.str();
219 }
220 
221 void LatexPrinter::bvisit(const Unequality &x)
222 {
223  std::ostringstream s;
224  s << apply(x.get_arg1()) << " \\neq " << apply(x.get_arg2());
225  str_ = s.str();
226 }
227 
228 void LatexPrinter::bvisit(const LessThan &x)
229 {
230  std::ostringstream s;
231  s << apply(x.get_arg1()) << " \\leq " << apply(x.get_arg2());
232  str_ = s.str();
233 }
234 
235 void LatexPrinter::bvisit(const StrictLessThan &x)
236 {
237  std::ostringstream s;
238  s << apply(x.get_arg1()) << " < " << apply(x.get_arg2());
239  str_ = s.str();
240 }
241 
242 std::string latex(const DenseMatrix &m, const unsigned max_rows,
243  const unsigned max_cols)
244 {
245  const unsigned int nrows = m.nrows();
246  const unsigned int ncols = m.ncols();
247  unsigned int nrows_display = nrows;
248  if (nrows > max_rows)
249  nrows_display = max_rows - 1;
250  unsigned int ncols_display = ncols;
251  if (ncols > max_cols)
252  ncols_display = max_cols - 1;
253 
254  std::ostringstream s;
255  s << "\\left[\\begin{matrix}" << std::endl;
256 
257  std::string end_of_line = " \\\\\n";
258  if (ncols_display < ncols) {
259  end_of_line = " & \\cdots" + end_of_line;
260  }
261  for (unsigned int row_index = 0; row_index < nrows_display; row_index++) {
262  for (unsigned int column_index = 0; column_index < ncols_display;
263  column_index++) {
264  RCP<const Basic> v = m.get(row_index, column_index);
265 
266  if (v.is_null()) {
267  // element has not been initalized
268  throw SymEngineException(
269  "cannot display uninitialized element");
270  } else {
271  s << latex(*v);
272  }
273  if (column_index < ncols_display - 1)
274  s << " & ";
275  }
276  s << end_of_line;
277  }
278  if (nrows_display < nrows) {
279  for (unsigned int column_index = 0; column_index < ncols_display;
280  column_index++) {
281  s << "\\vdots";
282  if (column_index < ncols_display - 1)
283  s << " & ";
284  }
285  s << end_of_line;
286  }
287  s << "\\end{matrix}\\right]\n";
288 
289  return s.str();
290 }
291 
292 void LatexPrinter::bvisit(const Interval &x)
293 {
294  std::ostringstream s;
295  if (x.get_left_open())
296  s << "\\left(";
297  else
298  s << "\\left[";
299  s << *x.get_start() << ", " << *x.get_end();
300  if (x.get_right_open())
301  s << "\\right)";
302  else
303  s << "\\right]";
304  str_ = s.str();
305 }
306 
307 void LatexPrinter::bvisit(const BooleanAtom &x)
308 {
309  if (x.get_val()) {
310  str_ = "\\mathrm{True}";
311  } else {
312  str_ = "\\mathrm{False}";
313  }
314 }
315 
316 void LatexPrinter::bvisit(const And &x)
317 {
318  std::ostringstream s;
319  auto container = x.get_container();
320  if (is_a<Or>(**container.begin()) or is_a<Xor>(**container.begin())) {
321  s << parenthesize(apply(*container.begin()));
322  } else {
323  s << apply(*container.begin());
324  }
325 
326  for (auto it = ++(container.begin()); it != container.end(); ++it) {
327  s << " \\wedge ";
328  if (is_a<Or>(**it) or is_a<Xor>(**it)) {
329  s << parenthesize(apply(*it));
330  } else {
331  s << apply(*it);
332  }
333  }
334  str_ = s.str();
335 }
336 
337 void LatexPrinter::bvisit(const Or &x)
338 {
339  std::ostringstream s;
340  auto container = x.get_container();
341  if (is_a<And>(**container.begin()) or is_a<Xor>(**container.begin())) {
342  s << parenthesize(apply(*container.begin()));
343  } else {
344  s << apply(*container.begin());
345  }
346 
347  for (auto it = ++(container.begin()); it != container.end(); ++it) {
348  s << " \\vee ";
349  if (is_a<And>(**it) or is_a<Xor>(**it)) {
350  s << parenthesize(apply(*it));
351  } else {
352  s << apply(*it);
353  }
354  }
355  str_ = s.str();
356 }
357 
358 void LatexPrinter::bvisit(const Xor &x)
359 {
360  std::ostringstream s;
361  auto container = x.get_container();
362  if (is_a<Or>(**container.begin()) or is_a<And>(**container.begin())) {
363  s << parenthesize(apply(*container.begin()));
364  } else {
365  s << apply(*container.begin());
366  }
367 
368  for (auto it = ++(container.begin()); it != container.end(); ++it) {
369  s << " \\veebar ";
370  if (is_a<Or>(**it) or is_a<And>(**it)) {
371  s << parenthesize(apply(*it));
372  } else {
373  s << apply(*it);
374  }
375  }
376  str_ = s.str();
377 }
378 
379 void LatexPrinter::print_with_args(const Basic &x, const std::string &join,
380  std::ostringstream &s)
381 {
382  vec_basic v = x.get_args();
383  s << apply(*v.begin());
384 
385  for (auto it = ++(v.begin()); it != v.end(); ++it) {
386  s << " " << join << " " << apply(*it);
387  }
388 }
389 
390 void LatexPrinter::bvisit(const Not &x)
391 {
392  str_ = "\\neg " + apply(*x.get_arg());
393 }
394 
395 void LatexPrinter::bvisit(const Union &x)
396 {
397  std::ostringstream s;
398  print_with_args(x, "\\cup", s);
399  str_ = s.str();
400 }
401 
402 void LatexPrinter::bvisit(const Intersection &x)
403 {
404  std::ostringstream s;
405  print_with_args(x, "\\cap", s);
406  str_ = s.str();
407 }
408 
409 void LatexPrinter::bvisit(const Complement &x)
410 {
411  std::ostringstream s;
412  s << apply(x.get_universe()) << " \\setminus " << apply(x.get_container());
413  str_ = s.str();
414 }
415 
416 void LatexPrinter::bvisit(const ImageSet &x)
417 {
418  std::ostringstream s;
419  s << "\\left\\{" << apply(*x.get_expr()) << "\\; |\\; ";
420  s << apply(*x.get_symbol());
421  s << " \\in " << apply(*x.get_baseset()) << "\\right\\}";
422  str_ = s.str();
423 }
424 
425 void LatexPrinter::bvisit(const ConditionSet &x)
426 {
427  std::ostringstream s;
428  s << "\\left\\{" << apply(*x.get_symbol()) << "\\; |\\; ";
429  s << apply(x.get_condition()) << "\\right\\}";
430  str_ = s.str();
431 }
432 
433 void LatexPrinter::bvisit(const EmptySet &x)
434 {
435  str_ = "\\emptyset";
436 }
437 
438 void LatexPrinter::bvisit(const Complexes &x)
439 {
440  str_ = "\\mathbb{C}";
441 }
442 
443 void LatexPrinter::bvisit(const Reals &x)
444 {
445  str_ = "\\mathbb{R}";
446 }
447 
448 void LatexPrinter::bvisit(const Rationals &x)
449 {
450  str_ = "\\mathbb{Q}";
451 }
452 
453 void LatexPrinter::bvisit(const Integers &x)
454 {
455  str_ = "\\mathbb{Z}";
456 }
457 
458 void LatexPrinter::bvisit(const Naturals &x)
459 {
460  str_ = "\\mathbb{N}";
461 }
462 
463 void LatexPrinter::bvisit(const Naturals0 &x)
464 {
465  str_ = "\\mathbb{N}_0";
466 }
467 
468 void LatexPrinter::bvisit(const FiniteSet &x)
469 {
470  std::ostringstream s;
471  s << "\\left{";
472  print_with_args(x, ",", s);
473  s << "\\right}";
474  str_ = s.str();
475 }
476 
477 void LatexPrinter::bvisit(const Contains &x)
478 {
479  std::ostringstream s;
480  s << apply(x.get_expr()) << " \\in " << apply(x.get_set());
481  str_ = s.str();
482 }
483 
484 std::string LatexPrinter::print_mul()
485 {
486  return " ";
487 }
488 
489 bool LatexPrinter::split_mul_coef()
490 {
491  return true;
492 }
493 
494 std::vector<std::string> init_latex_printer_names()
495 {
496  std::vector<std::string> names = init_str_printer_names();
497 
498  for (unsigned i = 0; i < names.size(); i++) {
499  if (names[i] != "") {
500  names[i] = "\\operatorname{" + names[i] + "}";
501  }
502  }
503  names[SYMENGINE_SIN] = "\\sin";
504  names[SYMENGINE_COS] = "\\cos";
505  names[SYMENGINE_TAN] = "\\tan";
506  names[SYMENGINE_COT] = "\\cot";
507  names[SYMENGINE_CSC] = "\\csc";
508  names[SYMENGINE_SEC] = "\\sec";
509  names[SYMENGINE_ATAN2] = "\\operatorname{atan_2}";
510  names[SYMENGINE_SINH] = "\\sinh";
511  names[SYMENGINE_COSH] = "\\cosh";
512  names[SYMENGINE_TANH] = "\\tanh";
513  names[SYMENGINE_COTH] = "\\coth";
514  names[SYMENGINE_LOG] = "\\log";
515  names[SYMENGINE_ZETA] = "\\zeta";
516  names[SYMENGINE_LAMBERTW] = "\\operatorname{W}";
517  names[SYMENGINE_DIRICHLET_ETA] = "\\eta";
518  names[SYMENGINE_KRONECKERDELTA] = "\\delta_";
519  names[SYMENGINE_LEVICIVITA] = "\\varepsilon_";
520  names[SYMENGINE_LOWERGAMMA] = "\\gamma";
521  names[SYMENGINE_UPPERGAMMA] = "\\Gamma";
522  names[SYMENGINE_BETA] = "\\operatorname{B}";
523  names[SYMENGINE_LOG] = "\\log";
524  names[SYMENGINE_GAMMA] = "\\Gamma";
525  names[SYMENGINE_TRUNCATE] = "\\operatorname{truncate}";
526  names[SYMENGINE_PRIMEPI] = "\\pi";
527  return names;
528 }
529 
530 void LatexPrinter::bvisit(const Function &x)
531 {
532  static const std::vector<std::string> names_ = init_latex_printer_names();
533  std::ostringstream o;
534  o << names_[x.get_type_code()] << "{";
535  vec_basic vec = x.get_args();
536  o << parenthesize(apply(vec)) << "}";
537  str_ = o.str();
538 }
539 
540 void LatexPrinter::bvisit(const Floor &x)
541 {
542  std::ostringstream o;
543  o << "\\lfloor{" << apply(x.get_arg()) << "}\\rfloor";
544  str_ = o.str();
545 }
546 
547 void LatexPrinter::bvisit(const Ceiling &x)
548 {
549  std::ostringstream o;
550  o << "\\lceil{" << apply(x.get_arg()) << "}\\rceil";
551  str_ = o.str();
552 }
553 
554 void LatexPrinter::bvisit(const Abs &x)
555 {
556  std::ostringstream o;
557  o << "\\left|" << apply(x.get_arg()) << "\\right|";
558  str_ = o.str();
559 }
560 
561 std::string LatexPrinter::parenthesize(const std::string &expr)
562 {
563  return "\\left(" + expr + "\\right)";
564 }
565 
566 void LatexPrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
567  const RCP<const Basic> &b)
568 {
569  if (eq(*a, *E)) {
570  o << "e^{" << apply(b) << "}";
571  } else if (eq(*b, *rational(1, 2))) {
572  o << "\\sqrt{" << apply(a) << "}";
573  } else if (is_a<Rational>(*b)
574  and eq(*static_cast<const Rational &>(*b).get_num(), *one)) {
575  o << "\\sqrt[" << apply(static_cast<const Rational &>(*b).get_den())
576  << "]{" << apply(a) << "}";
577  } else {
578  o << parenthesizeLE(a, PrecedenceEnum::Pow);
579  Precedence prec;
580  auto b_str = apply(b);
581  if (b_str.size() > 1) {
582  o << "^{" << b_str << "}";
583  } else {
584  o << "^" << b_str;
585  }
586  }
587 }
588 
589 std::string LatexPrinter::print_div(const std::string &num,
590  const std::string &den, bool paren)
591 {
592  return "\\frac{" + num + "}{" + den + "}";
593 }
594 
595 void LatexPrinter::bvisit(const Piecewise &x)
596 {
597  std::ostringstream s;
598  s << "\\begin{cases} ";
599  const auto &vec = x.get_vec();
600  auto it = vec.begin();
601  auto it_last = --vec.end();
602  while (it != vec.end()) {
603  s << apply(it->first);
604  if (it == it_last) {
605  if (eq(*it->second, *boolTrue)) {
606  s << " & \\text{otherwise} \\end{cases}";
607  } else {
608  s << " & \\text{for}\\: ";
609  s << apply(it->second);
610  s << " \\end{cases}";
611  }
612  } else {
613  s << " & \\text{for}\\: ";
614  s << apply(it->second);
615  s << "\\\\";
616  }
617  it++;
618  }
619  str_ = s.str();
620 }
621 
622 void LatexPrinter::bvisit(const Tuple &x)
623 {
624  std::ostringstream o;
625  vec_basic vec = x.get_args();
626  o << parenthesize(apply(vec));
627  str_ = o.str();
628 }
629 
630 } // namespace SymEngine
The base class for SymEngine.
Main namespace for SymEngine package.
Definition: add.cpp:19
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328