subs.h
1 #ifndef SYMENGINE_SUBS_H
2 #define SYMENGINE_SUBS_H
3 
4 #include <symengine/logic.h>
5 #include <symengine/visitor.h>
6 
7 namespace SymEngine
8 {
9 // xreplace rebuilds the expression tree of x, replacing every subtree that
10 // is a key of subs_dict with its mapped value; cache memoizes visited nodes
11 RCP<const Basic> xreplace(const RCP<const Basic> &x,
12  const map_basic_basic &subs_dict, bool cache = true);
13 // subs substitutes expressions like xreplace, but handles Derivative and
14 // Subs nodes specially so that the result stays mathematically equivalent
15 RCP<const Basic> subs(const RCP<const Basic> &x,
16  const map_basic_basic &subs_dict, bool cache = true);
17 // port of sympy.physics.mechanics.msubs, where f'(x) and f(x) are
18 // considered independent (Derivative nodes are returned unchanged)
19 RCP<const Basic> msubs(const RCP<const Basic> &x,
20  const map_basic_basic &subs_dict, bool cache = true);
21 // port of sympy's subs, where substitutions are also applied inside derivatives
22 RCP<const Basic> ssubs(const RCP<const Basic> &x,
23  const map_basic_basic &subs_dict, bool cache = true);
24 
25 class XReplaceVisitor : public BaseVisitor<XReplaceVisitor>
26 {
27 
28 protected:
29  RCP<const Basic> result_;
30  const map_basic_basic &subs_dict_;
31  map_basic_basic visited;
32  bool cache;
33 
34 public:
35  XReplaceVisitor(const map_basic_basic &subs_dict, bool cache = true)
36  : subs_dict_(subs_dict), cache(cache)
37  {
38  if (cache) {
39  visited = subs_dict;
40  }
41  }
42  // TODO : Polynomials, Series, Sets
43  void bvisit(const Basic &x)
44  {
45  result_ = x.rcp_from_this();
46  }
47 
48  void bvisit(const Add &x)
49  {
51  RCP<const Number> coef;
52 
53  auto it = subs_dict_.find(x.get_coef());
54  if (it != subs_dict_.end()) {
55  coef = zero;
56  Add::coef_dict_add_term(outArg(coef), d, one, it->second);
57  } else {
58  coef = x.get_coef();
59  }
60 
61  for (const auto &p : x.get_dict()) {
62  auto it
63  = subs_dict_.find(Add::from_dict(zero, {{p.first, p.second}}));
64  if (it != subs_dict_.end()) {
65  Add::coef_dict_add_term(outArg(coef), d, one, it->second);
66  } else {
67  it = subs_dict_.find(p.second);
68  if (it != subs_dict_.end()) {
69  Add::coef_dict_add_term(outArg(coef), d, one,
70  mul(it->second, apply(p.first)));
71  } else {
72  Add::coef_dict_add_term(outArg(coef), d, p.second,
73  apply(p.first));
74  }
75  }
76  }
77  result_ = Add::from_dict(coef, std::move(d));
78  }
79 
80  void bvisit(const Mul &x)
81  {
82  RCP<const Number> coef = one;
84  for (const auto &p : x.get_dict()) {
85  RCP<const Basic> factor_old;
86  if (eq(*p.second, *one)) {
87  factor_old = p.first;
88  } else {
89  factor_old = make_rcp<Pow>(p.first, p.second);
90  }
91  RCP<const Basic> factor = apply(factor_old);
92  if (factor == factor_old) {
93  // TODO: Check if Mul::dict_add_term is enough
94  Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
95  } else if (is_a_Number(*factor)) {
96  imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
97  } else if (is_a<Mul>(*factor)) {
98  RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
99  imulnum(outArg(coef), tmp->get_coef());
100  for (const auto &q : tmp->get_dict()) {
101  Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
102  }
103  } else {
104  RCP<const Basic> exp, t;
105  Mul::as_base_exp(factor, outArg(exp), outArg(t));
106  Mul::dict_add_term_new(outArg(coef), d, exp, t);
107  }
108  }
109 
110  // Replace the coefficient
111  RCP<const Basic> factor = apply(x.get_coef());
112  if (is_a_Number(*factor)) {
113  imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
114  } else if (is_a<Mul>(*factor)) {
115  RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
116  imulnum(outArg(coef), tmp->get_coef());
117  for (const auto &q : tmp->get_dict()) {
118  Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
119  }
120  } else {
121  RCP<const Basic> exp, t;
122  Mul::as_base_exp(factor, outArg(exp), outArg(t));
123  Mul::dict_add_term_new(outArg(coef), d, exp, t);
124  }
125  result_ = Mul::from_dict(coef, std::move(d));
126  }
127 
128  void bvisit(const Pow &x)
129  {
130  RCP<const Basic> base_new = apply(x.get_base());
131  RCP<const Basic> exp_new = apply(x.get_exp());
132  if (base_new == x.get_base() and exp_new == x.get_exp()) {
133  result_ = x.rcp_from_this();
134  } else {
135  result_ = pow(base_new, exp_new);
136  }
137  }
138 
139  void bvisit(const OneArgFunction &x)
140  {
141  apply(x.get_arg());
142  if (result_ == x.get_arg()) {
143  result_ = x.rcp_from_this();
144  } else {
145  result_ = x.create(result_);
146  }
147  }
148 
149  template <class T>
150  void bvisit(const TwoArgBasic<T> &x)
151  {
152  RCP<const Basic> a = apply(x.get_arg1());
153  RCP<const Basic> b = apply(x.get_arg2());
154  if (a == x.get_arg1() and b == x.get_arg2())
155  result_ = x.rcp_from_this();
156  else
157  result_ = x.create(a, b);
158  }
159 
160  void bvisit(const MultiArgFunction &x)
161  {
162  vec_basic v = x.get_args();
163  for (auto &elem : v) {
164  elem = apply(elem);
165  }
166  result_ = x.create(v);
167  }
168 
169  void bvisit(const FunctionSymbol &x)
170  {
171  vec_basic v = x.get_args();
172  for (auto &elem : v) {
173  elem = apply(elem);
174  }
175  result_ = x.create(v);
176  }
177 
178  void bvisit(const Contains &x)
179  {
180  RCP<const Basic> a = apply(x.get_expr());
181  auto c = apply(x.get_set());
182  if (not is_a_Set(*c))
183  throw SymEngineException("expected an object of type Set");
184  RCP<const Set> b = rcp_static_cast<const Set>(c);
185  if (a == x.get_expr() and b == x.get_set())
186  result_ = x.rcp_from_this();
187  else
188  result_ = x.create(a, b);
189  }
190 
191  void bvisit(const And &x)
192  {
193  set_boolean v;
194  for (const auto &elem : x.get_container()) {
195  auto a = apply(elem);
196  if (not is_a_Boolean(*a))
197  throw SymEngineException("expected an object of type Boolean");
198  v.insert(rcp_static_cast<const Boolean>(a));
199  }
200  result_ = logical_and(v);
201  }
202 
203  void bvisit(const Or &x)
204  {
205  set_boolean v;
206  for (const auto &elem : x.get_container()) {
207  auto a = apply(elem);
208  if (not is_a_Boolean(*a))
209  throw SymEngineException("expected an object of type Boolean");
210  v.insert(rcp_static_cast<const Boolean>(a));
211  }
212  result_ = logical_or(v);
213  }
214 
215  void bvisit(const Not &x)
216  {
217  RCP<const Basic> a = apply(x.get_arg());
218  if (not is_a_Boolean(*a))
219  throw SymEngineException("expected an object of type Boolean");
220  result_ = logical_not(rcp_static_cast<const Boolean>(a));
221  }
222 
223  void bvisit(const Xor &x)
224  {
225  vec_boolean v;
226  for (const auto &elem : x.get_container()) {
227  auto a = apply(elem);
228  if (not is_a_Boolean(*a))
229  throw SymEngineException("expected an object of type Boolean");
230  v.push_back(rcp_static_cast<const Boolean>(a));
231  }
232  result_ = logical_xor(v);
233  }
234 
235  void bvisit(const FiniteSet &x)
236  {
237  set_basic v;
238  for (const auto &elem : x.get_container()) {
239  v.insert(apply(elem));
240  }
241  result_ = x.create(v);
242  }
243 
244  void bvisit(const ImageSet &x)
245  {
246  RCP<const Basic> s = apply(x.get_symbol());
247  RCP<const Basic> expr = apply(x.get_expr());
248  auto bs_ = apply(x.get_baseset());
249  if (not is_a_Set(*bs_))
250  throw SymEngineException("expected an object of type Set");
251  RCP<const Set> bs = rcp_static_cast<const Set>(bs_);
252  if (s == x.get_symbol() and expr == x.get_expr()
253  and bs == x.get_baseset()) {
254  result_ = x.rcp_from_this();
255  } else {
256  result_ = x.create(s, expr, bs);
257  }
258  }
259 
260  void bvisit(const Union &x)
261  {
262  set_set v;
263  for (const auto &elem : x.get_container()) {
264  auto a = apply(elem);
265  if (not is_a_Set(*a))
266  throw SymEngineException("expected an object of type Set");
267  v.insert(rcp_static_cast<const Set>(a));
268  }
269  result_ = x.create(v);
270  }
271 
272  void bvisit(const Piecewise &pw)
273  {
274  PiecewiseVec pwv;
275  pwv.reserve(pw.get_vec().size());
276  for (const auto &expr_pred : pw.get_vec()) {
277  const auto expr = apply(*expr_pred.first);
278  const auto pred = apply(*expr_pred.second);
279  pwv.emplace_back(
280  std::make_pair(expr, rcp_static_cast<const Boolean>(pred)));
281  }
282  result_ = piecewise(std::move(pwv));
283  }
284 
285  void bvisit(const Derivative &x)
286  {
287  auto expr = apply(x.get_arg());
288  for (const auto &sym : x.get_symbols()) {
289  auto s = apply(sym);
290  if (not is_a<Symbol>(*s)) {
291  throw SymEngineException("expected an object of type Symbol");
292  }
293  expr = expr->diff(rcp_static_cast<const Symbol>(s));
294  }
295  result_ = expr;
296  }
297 
298  void bvisit(const Subs &x)
299  {
300  auto expr = apply(x.get_arg());
301  map_basic_basic new_subs_dict;
302  for (const auto &sym : x.get_dict()) {
303  insert(new_subs_dict, apply(sym.first), apply(sym.second));
304  }
305  result_ = subs(expr, new_subs_dict);
306  }
307 
308  void bvisit(const ComplexBase &x)
309  {
310  auto it = subs_dict_.find(I);
311  if (it != subs_dict_.end()) {
312  result_ = add(apply(x.real_part()),
313  mul(apply(x.imaginary_part()), it->second));
314  } else {
315  result_ = x.rcp_from_this();
316  }
317  }
318 
319  RCP<const Basic> apply(const Basic &x)
320  {
321  return apply(x.rcp_from_this());
322  }
323 
324  RCP<const Basic> apply(const RCP<const Basic> &x)
325  {
326  if (cache) {
327  auto it = visited.find(x);
328  if (it != visited.end()) {
329  result_ = it->second;
330  } else {
331  x->accept(*this);
332  insert(visited, x, result_);
333  }
334  } else {
335  auto it = subs_dict_.find(x);
336  if (it != subs_dict_.end()) {
337  result_ = it->second;
338  } else {
339  x->accept(*this);
340  }
341  }
342  return result_;
343  }
344 };
345 
347 inline RCP<const Basic> xreplace(const RCP<const Basic> &x,
348  const map_basic_basic &subs_dict, bool cache)
349 {
350  XReplaceVisitor s(subs_dict, cache);
351  return s.apply(x);
352 }
353 
354 class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
355 {
356 public:
357  using XReplaceVisitor::bvisit;
358 
359  SubsVisitor(const map_basic_basic &subs_dict_, bool cache = true)
360  : BaseVisitor<SubsVisitor, XReplaceVisitor>(subs_dict_, cache)
361  {
362  }
363 
364  void bvisit(const Pow &x)
365  {
366  RCP<const Basic> base_new = apply(x.get_base());
367  RCP<const Basic> exp_new = apply(x.get_exp());
368  if (subs_dict_.size() == 1 and is_a<Pow>(*((*subs_dict_.begin()).first))
369  and not is_a<Add>(
370  *down_cast<const Pow &>(*(*subs_dict_.begin()).first)
371  .get_exp())) {
372  auto &subs_first
373  = down_cast<const Pow &>(*(*subs_dict_.begin()).first);
374  if (eq(*subs_first.get_base(), *base_new)) {
375  auto newexpo = div(exp_new, subs_first.get_exp());
376  if (is_a_Number(*newexpo) or is_a<Constant>(*newexpo)) {
377  result_ = pow((*subs_dict_.begin()).second, newexpo);
378  return;
379  }
380  }
381  }
382  if (base_new == x.get_base() and exp_new == x.get_exp()) {
383  result_ = x.rcp_from_this();
384  } else {
385  result_ = pow(base_new, exp_new);
386  }
387  }
388 
389  void bvisit(const Derivative &x)
390  {
391  RCP<const Symbol> s;
392  map_basic_basic m, n;
393  bool subs;
394 
395  for (const auto &p : subs_dict_) {
396  // If the derivative arg is to be replaced in its entirety, allow
397  // it.
398  if (eq(*x.get_arg(), *p.first)) {
399  RCP<const Basic> t = p.second;
400  for (auto &sym : x.get_symbols()) {
401  if (not is_a<Symbol>(*sym)) {
402  throw SymEngineException("Error, expected a Symbol.");
403  }
404  t = t->diff(rcp_static_cast<const Symbol>(sym));
405  }
406  result_ = t;
407  return;
408  }
409  }
410  for (const auto &p : subs_dict_) {
411  subs = true;
412  if (eq(*x.get_arg()->subs({{p.first, p.second}}), *x.get_arg()))
413  continue;
414 
415  // If p.first and p.second are symbols and arg_ is
416  // independent of p.second, p.first can be replaced
417  if (is_a<Symbol>(*p.first) and is_a<Symbol>(*p.second)
418  and eq(
419  *x.get_arg()->diff(rcp_static_cast<const Symbol>(p.second)),
420  *zero)) {
421  insert(n, p.first, p.second);
422  continue;
423  }
424  for (const auto &d : x.get_symbols()) {
425  if (is_a<Symbol>(*d)) {
426  s = rcp_static_cast<const Symbol>(d);
427  // If p.first or p.second has non zero derivates wrt to s
428  // p.first cannot be replaced
429  if (neq(*zero, *(p.first->diff(s)))
430  || neq(*zero, *(p.second->diff(s)))) {
431  subs = false;
432  break;
433  }
434  } else {
435  result_
436  = make_rcp<const Subs>(x.rcp_from_this(), subs_dict_);
437  return;
438  }
439  }
440  if (subs) {
441  insert(n, p.first, p.second);
442  } else {
443  insert(m, p.first, p.second);
444  }
445  }
446  auto t = x.get_arg()->subs(n);
447  for (auto &p : x.get_symbols()) {
448  auto t2 = p->subs(n);
449  if (not is_a<Symbol>(*t2)) {
450  throw SymEngineException("Error, expected a Symbol.");
451  }
452  t = t->diff(rcp_static_cast<const Symbol>(t2));
453  }
454  if (m.empty()) {
455  result_ = t;
456  } else {
457  result_ = make_rcp<const Subs>(t, m);
458  }
459  }
460 
461  void bvisit(const Subs &x)
462  {
463  map_basic_basic m, n;
464  for (const auto &p : subs_dict_) {
465  bool found = false;
466  for (const auto &s : x.get_dict()) {
467  if (neq(*(s.first->subs({{p.first, p.second}})), *(s.first))) {
468  found = true;
469  break;
470  }
471  }
472  // If p.first is not replaced in arg_ by dict_,
473  // store p.first in n to replace in arg_
474  if (not found) {
475  insert(n, p.first, p.second);
476  }
477  }
478  for (const auto &s : x.get_dict()) {
479  insert(m, s.first, apply(s.second));
480  }
481  RCP<const Basic> presub = x.get_arg()->subs(n);
482  if (is_a<Subs>(*presub)) {
483  for (auto &q : down_cast<const Subs &>(*presub).get_dict()) {
484  insert(m, q.first, q.second);
485  }
486  result_ = down_cast<const Subs &>(*presub).get_arg()->subs(m);
487  } else {
488  result_ = presub->subs(m);
489  }
490  }
491 };
492 
493 class MSubsVisitor : public BaseVisitor<MSubsVisitor, XReplaceVisitor>
494 {
495 public:
496  using XReplaceVisitor::bvisit;
497 
498  MSubsVisitor(const map_basic_basic &d, bool cache = true)
500  {
501  }
502 
503  void bvisit(const Derivative &x)
504  {
505  result_ = x.rcp_from_this();
506  }
507 
508  void bvisit(const Subs &x)
509  {
510  map_basic_basic m = x.get_dict();
511  for (const auto &p : subs_dict_) {
512  m[p.first] = p.second;
513  }
514  result_ = msubs(x.get_arg(), m);
515  }
516 };
517 
518 class SSubsVisitor : public BaseVisitor<SSubsVisitor, SubsVisitor>
519 {
520 public:
521  using XReplaceVisitor::bvisit;
522 
523  SSubsVisitor(const map_basic_basic &d, bool cache = true)
525  {
526  }
527 
528  void bvisit(const Derivative &x)
529  {
530  apply(x.get_arg());
531  auto t = result_;
532  multiset_basic m;
533  for (auto &p : x.get_symbols()) {
534  apply(p);
535  m.insert(result_);
536  }
537  result_ = Derivative::create(t, m);
538  }
539 
540  void bvisit(const Subs &x)
541  {
542  map_basic_basic m = x.get_dict();
543  for (const auto &p : subs_dict_) {
544  m[p.first] = p.second;
545  }
546  result_ = ssubs(x.get_arg(), m);
547  }
548 };
549 
551 inline RCP<const Basic> msubs(const RCP<const Basic> &x,
552  const map_basic_basic &subs_dict, bool cache)
553 {
554  MSubsVisitor s(subs_dict, cache);
555  return s.apply(x);
556 }
557 
559 inline RCP<const Basic> ssubs(const RCP<const Basic> &x,
560  const map_basic_basic &subs_dict, bool cache)
561 {
562  SSubsVisitor s(subs_dict, cache);
563  return s.apply(x);
564 }
565 
566 inline RCP<const Basic> subs(const RCP<const Basic> &x,
567  const map_basic_basic &subs_dict, bool cache)
568 {
569  SubsVisitor b(subs_dict, cache);
570  return b.apply(x);
571 }
572 
573 } // namespace SymEngine
574 
575 #endif // SYMENGINE_SUBS_H
T begin(T... args)
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Create an appropriate instance from dictionary quickly.
Definition: add.cpp:140
static void coef_dict_add_term(const Ptr< RCP< const Number >> &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
const RCP< const Number > & get_coef() const
Definition: add.h:142
The lowest unit of symbolic representation.
Definition: basic.h:97
RCP< const Basic > subs(const map_basic_basic &subs_dict) const
Substitutes 'subs_dict' into 'self'.
Definition: basic.cpp:80
ComplexBase Class for deriving all complex classes.
Definition: complex.h:16
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
Definition: functions.cpp:1905
static void as_base_exp(const RCP< const Basic > &self, const Ptr< RCP< const Basic >> &exp, const Ptr< RCP< const Basic >> &base)
Convert to a base and exponent form.
Definition: mul.cpp:320
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
virtual RCP< const Basic > create(const vec_basic &v) const =0
Method to construct classes with canonicalization.
vec_basic get_args() const override
Returns the list of arguments.
Definition: functions.h:159
virtual RCP< const Basic > create(const RCP< const Basic > &arg) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg() const
Definition: functions.h:36
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg1() const
Definition: functions.h:91
RCP< const Basic > get_arg2() const
Definition: functions.h:96
T emplace_back(T... args)
T end(T... args)
T find(T... args)
T insert(T... args)
T make_pair(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
Definition: mul.cpp:431
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > msubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Substitution which treats f(t) and Derivative(f(t), t) as separate variables.
Definition: subs.h:551
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
Definition: ntheory.cpp:371
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Basic > xreplace(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Mappings in the subs_dict are applied to the expression tree of x
Definition: subs.h:347
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > ssubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
SymPy compatible subs.
Definition: subs.h:559
T push_back(T... args)
T reserve(T... args)
T size(T... args)