subs.h
1 #ifndef SYMENGINE_SUBS_H
2 #define SYMENGINE_SUBS_H
3 
4 #include <symengine/logic.h>
5 #include <symengine/visitor.h>
6 
7 namespace SymEngine
8 {
9 // xreplace replaces subtrees of a node in the expression tree
10 // with a new subtree
11 RCP<const Basic> xreplace(const RCP<const Basic> &x,
12  const map_basic_basic &subs_dict, bool cache = true);
13 // subs substitutes expressions similar to xreplace, but keeps
14 // the mathematical equivalence for derivatives and subs
15 RCP<const Basic> subs(const RCP<const Basic> &x,
16  const map_basic_basic &subs_dict, bool cache = true);
17 // port of sympy.physics.mechanics.msubs where f'(x) and f(x)
18 // are considered independent
19 RCP<const Basic> msubs(const RCP<const Basic> &x,
20  const map_basic_basic &subs_dict, bool cache = true);
21 // port of sympy's subs where subs inside derivatives are done
22 RCP<const Basic> ssubs(const RCP<const Basic> &x,
23  const map_basic_basic &subs_dict, bool cache = true);
24 
25 class XReplaceVisitor : public BaseVisitor<XReplaceVisitor>
26 {
27 
28 protected:
29  RCP<const Basic> result_;
30  const map_basic_basic &subs_dict_;
31  map_basic_basic visited;
32  bool cache;
33 
34 public:
35  XReplaceVisitor(const map_basic_basic &subs_dict, bool cache = true)
36  : subs_dict_(subs_dict), cache(cache)
37  {
38  if (cache) {
39  visited = subs_dict;
40  }
41  }
42  // TODO : Polynomials, Series, Sets
43  void bvisit(const Basic &x)
44  {
45  result_ = x.rcp_from_this();
46  }
47 
48  void bvisit(const Add &x)
49  {
51  RCP<const Number> coef;
52 
53  auto it = subs_dict_.find(x.get_coef());
54  if (it != subs_dict_.end()) {
55  coef = zero;
56  Add::coef_dict_add_term(outArg(coef), d, one, it->second);
57  } else {
58  coef = x.get_coef();
59  }
60 
61  for (const auto &p : x.get_dict()) {
62  auto it
63  = subs_dict_.find(Add::from_dict(zero, {{p.first, p.second}}));
64  if (it != subs_dict_.end()) {
65  Add::coef_dict_add_term(outArg(coef), d, one, it->second);
66  } else {
67  it = subs_dict_.find(p.second);
68  if (it != subs_dict_.end()) {
69  Add::coef_dict_add_term(outArg(coef), d, one,
70  mul(it->second, apply(p.first)));
71  } else {
72  Add::coef_dict_add_term(outArg(coef), d, p.second,
73  apply(p.first));
74  }
75  }
76  }
77  result_ = Add::from_dict(coef, std::move(d));
78  }
79 
80  void bvisit(const Mul &x)
81  {
82  RCP<const Number> coef = one;
84  for (const auto &p : x.get_dict()) {
85  RCP<const Basic> factor_old;
86  if (eq(*p.second, *one)) {
87  factor_old = p.first;
88  } else {
89  factor_old = make_rcp<Pow>(p.first, p.second);
90  }
91  RCP<const Basic> factor = apply(factor_old);
92  if (factor == factor_old) {
93  // TODO: Check if Mul::dict_add_term is enough
94  Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
95  } else if (is_a_Number(*factor)) {
96  if (down_cast<const Number &>(*factor).is_zero()) {
97  result_ = factor;
98  return;
99  }
100  imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
101  } else if (is_a<Mul>(*factor)) {
102  RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
103  imulnum(outArg(coef), tmp->get_coef());
104  for (const auto &q : tmp->get_dict()) {
105  Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
106  }
107  } else {
108  RCP<const Basic> exp, t;
109  Mul::as_base_exp(factor, outArg(exp), outArg(t));
110  Mul::dict_add_term_new(outArg(coef), d, exp, t);
111  }
112  }
113 
114  // Replace the coefficient
115  RCP<const Basic> factor = apply(x.get_coef());
116  RCP<const Basic> exp, t;
117  Mul::as_base_exp(factor, outArg(exp), outArg(t));
118  Mul::dict_add_term_new(outArg(coef), d, exp, t);
119 
120  result_ = Mul::from_dict(coef, std::move(d));
121  }
122 
123  void bvisit(const Pow &x)
124  {
125  RCP<const Basic> base_new = apply(x.get_base());
126  RCP<const Basic> exp_new = apply(x.get_exp());
127  if (base_new == x.get_base() and exp_new == x.get_exp()) {
128  result_ = x.rcp_from_this();
129  } else {
130  result_ = pow(base_new, exp_new);
131  }
132  }
133 
134  void bvisit(const OneArgFunction &x)
135  {
136  apply(x.get_arg());
137  if (result_ == x.get_arg()) {
138  result_ = x.rcp_from_this();
139  } else {
140  result_ = x.create(result_);
141  }
142  }
143 
144  template <class T>
145  void bvisit(const TwoArgBasic<T> &x)
146  {
147  RCP<const Basic> a = apply(x.get_arg1());
148  RCP<const Basic> b = apply(x.get_arg2());
149  if (a == x.get_arg1() and b == x.get_arg2())
150  result_ = x.rcp_from_this();
151  else
152  result_ = x.create(a, b);
153  }
154 
155  void bvisit(const MultiArgFunction &x)
156  {
157  vec_basic v = x.get_args();
158  for (auto &elem : v) {
159  elem = apply(elem);
160  }
161  result_ = x.create(v);
162  }
163 
164  void bvisit(const FunctionSymbol &x)
165  {
166  vec_basic v = x.get_args();
167  for (auto &elem : v) {
168  elem = apply(elem);
169  }
170  result_ = x.create(v);
171  }
172 
173  void bvisit(const Contains &x)
174  {
175  RCP<const Basic> a = apply(x.get_expr());
176  auto c = apply(x.get_set());
177  if (not is_a_Set(*c))
178  throw SymEngineException("expected an object of type Set");
179  RCP<const Set> b = rcp_static_cast<const Set>(c);
180  if (a == x.get_expr() and b == x.get_set())
181  result_ = x.rcp_from_this();
182  else
183  result_ = x.create(a, b);
184  }
185 
186  void bvisit(const And &x)
187  {
188  set_boolean v;
189  for (const auto &elem : x.get_container()) {
190  auto a = apply(elem);
191  if (not is_a_Boolean(*a))
192  throw SymEngineException("expected an object of type Boolean");
193  v.insert(rcp_static_cast<const Boolean>(a));
194  }
195  result_ = logical_and(v);
196  }
197 
198  void bvisit(const Or &x)
199  {
200  set_boolean v;
201  for (const auto &elem : x.get_container()) {
202  auto a = apply(elem);
203  if (not is_a_Boolean(*a))
204  throw SymEngineException("expected an object of type Boolean");
205  v.insert(rcp_static_cast<const Boolean>(a));
206  }
207  result_ = logical_or(v);
208  }
209 
210  void bvisit(const Not &x)
211  {
212  RCP<const Basic> a = apply(x.get_arg());
213  if (not is_a_Boolean(*a))
214  throw SymEngineException("expected an object of type Boolean");
215  result_ = logical_not(rcp_static_cast<const Boolean>(a));
216  }
217 
218  void bvisit(const Xor &x)
219  {
220  vec_boolean v;
221  for (const auto &elem : x.get_container()) {
222  auto a = apply(elem);
223  if (not is_a_Boolean(*a))
224  throw SymEngineException("expected an object of type Boolean");
225  v.push_back(rcp_static_cast<const Boolean>(a));
226  }
227  result_ = logical_xor(v);
228  }
229 
230  void bvisit(const FiniteSet &x)
231  {
232  set_basic v;
233  for (const auto &elem : x.get_container()) {
234  v.insert(apply(elem));
235  }
236  result_ = x.create(v);
237  }
238 
239  void bvisit(const ImageSet &x)
240  {
241  RCP<const Basic> s = apply(x.get_symbol());
242  RCP<const Basic> expr = apply(x.get_expr());
243  auto bs_ = apply(x.get_baseset());
244  if (not is_a_Set(*bs_))
245  throw SymEngineException("expected an object of type Set");
246  RCP<const Set> bs = rcp_static_cast<const Set>(bs_);
247  if (s == x.get_symbol() and expr == x.get_expr()
248  and bs == x.get_baseset()) {
249  result_ = x.rcp_from_this();
250  } else {
251  result_ = x.create(s, expr, bs);
252  }
253  }
254 
255  void bvisit(const Union &x)
256  {
257  set_set v;
258  for (const auto &elem : x.get_container()) {
259  auto a = apply(elem);
260  if (not is_a_Set(*a))
261  throw SymEngineException("expected an object of type Set");
262  v.insert(rcp_static_cast<const Set>(a));
263  }
264  result_ = x.create(v);
265  }
266 
267  void bvisit(const Piecewise &pw)
268  {
269  PiecewiseVec pwv;
270  pwv.reserve(pw.get_vec().size());
271  for (const auto &expr_pred : pw.get_vec()) {
272  const auto expr = apply(*expr_pred.first);
273  const auto pred = apply(*expr_pred.second);
274  pwv.emplace_back(
275  std::make_pair(expr, rcp_static_cast<const Boolean>(pred)));
276  }
277  result_ = piecewise(std::move(pwv));
278  }
279 
280  void bvisit(const Derivative &x)
281  {
282  auto expr = apply(x.get_arg());
283  for (const auto &sym : x.get_symbols()) {
284  auto s = apply(sym);
285  if (not is_a<Symbol>(*s)) {
286  throw SymEngineException("expected an object of type Symbol");
287  }
288  expr = expr->diff(rcp_static_cast<const Symbol>(s));
289  }
290  result_ = expr;
291  }
292 
293  void bvisit(const Subs &x)
294  {
295  auto expr = apply(x.get_arg());
296  map_basic_basic new_subs_dict;
297  for (const auto &sym : x.get_dict()) {
298  insert(new_subs_dict, apply(sym.first), apply(sym.second));
299  }
300  result_ = subs(expr, new_subs_dict);
301  }
302 
303  RCP<const Basic> apply(const Basic &x)
304  {
305  return apply(x.rcp_from_this());
306  }
307 
308  RCP<const Basic> apply(const RCP<const Basic> &x)
309  {
310  if (cache) {
311  auto it = visited.find(x);
312  if (it != visited.end()) {
313  result_ = it->second;
314  } else {
315  x->accept(*this);
316  insert(visited, x, result_);
317  }
318  } else {
319  auto it = subs_dict_.find(x);
320  if (it != subs_dict_.end()) {
321  result_ = it->second;
322  } else {
323  x->accept(*this);
324  }
325  }
326  return result_;
327  }
328 };
329 
331 inline RCP<const Basic> xreplace(const RCP<const Basic> &x,
332  const map_basic_basic &subs_dict, bool cache)
333 {
334  XReplaceVisitor s(subs_dict, cache);
335  return s.apply(x);
336 }
337 
338 class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
339 {
340 public:
341  using XReplaceVisitor::bvisit;
342 
343  SubsVisitor(const map_basic_basic &subs_dict_, bool cache = true)
344  : BaseVisitor<SubsVisitor, XReplaceVisitor>(subs_dict_, cache)
345  {
346  }
347 
348  void bvisit(const Pow &x)
349  {
350  RCP<const Basic> base_new = apply(x.get_base());
351  RCP<const Basic> exp_new = apply(x.get_exp());
352  if (subs_dict_.size() == 1 and is_a<Pow>(*((*subs_dict_.begin()).first))
353  and not is_a<Add>(
354  *down_cast<const Pow &>(*(*subs_dict_.begin()).first)
355  .get_exp())) {
356  auto &subs_first
357  = down_cast<const Pow &>(*(*subs_dict_.begin()).first);
358  if (eq(*subs_first.get_base(), *base_new)) {
359  auto newexpo = div(exp_new, subs_first.get_exp());
360  if (is_a_Number(*newexpo) or is_a<Constant>(*newexpo)) {
361  result_ = pow((*subs_dict_.begin()).second, newexpo);
362  return;
363  }
364  }
365  }
366  if (base_new == x.get_base() and exp_new == x.get_exp()) {
367  result_ = x.rcp_from_this();
368  } else {
369  result_ = pow(base_new, exp_new);
370  }
371  }
372 
373  void bvisit(const Derivative &x)
374  {
375  RCP<const Symbol> s;
376  map_basic_basic m, n;
377  bool subs;
378 
379  for (const auto &p : subs_dict_) {
380  // If the derivative arg is to be replaced in its entirety, allow
381  // it.
382  if (eq(*x.get_arg(), *p.first)) {
383  RCP<const Basic> t = p.second;
384  for (auto &sym : x.get_symbols()) {
385  if (not is_a<Symbol>(*sym)) {
386  throw SymEngineException("Error, expected a Symbol.");
387  }
388  t = t->diff(rcp_static_cast<const Symbol>(sym));
389  }
390  result_ = t;
391  return;
392  }
393  }
394  for (const auto &p : subs_dict_) {
395  subs = true;
396  if (eq(*x.get_arg()->subs({{p.first, p.second}}), *x.get_arg()))
397  continue;
398 
399  // If p.first and p.second are symbols and arg_ is
400  // independent of p.second, p.first can be replaced
401  if (is_a<Symbol>(*p.first) and is_a<Symbol>(*p.second)
402  and eq(
403  *x.get_arg()->diff(rcp_static_cast<const Symbol>(p.second)),
404  *zero)) {
405  insert(n, p.first, p.second);
406  continue;
407  }
408  for (const auto &d : x.get_symbols()) {
409  if (is_a<Symbol>(*d)) {
410  s = rcp_static_cast<const Symbol>(d);
411  // If p.first or p.second has non zero derivates wrt to s
412  // p.first cannot be replaced
413  if (neq(*zero, *(p.first->diff(s)))
414  || neq(*zero, *(p.second->diff(s)))) {
415  subs = false;
416  break;
417  }
418  } else {
419  result_
420  = make_rcp<const Subs>(x.rcp_from_this(), subs_dict_);
421  return;
422  }
423  }
424  if (subs) {
425  insert(n, p.first, p.second);
426  } else {
427  insert(m, p.first, p.second);
428  }
429  }
430  auto t = x.get_arg()->subs(n);
431  for (auto &p : x.get_symbols()) {
432  auto t2 = p->subs(n);
433  if (not is_a<Symbol>(*t2)) {
434  throw SymEngineException("Error, expected a Symbol.");
435  }
436  t = t->diff(rcp_static_cast<const Symbol>(t2));
437  }
438  if (m.empty()) {
439  result_ = t;
440  } else {
441  result_ = make_rcp<const Subs>(t, m);
442  }
443  }
444 
445  void bvisit(const Subs &x)
446  {
447  map_basic_basic m, n;
448  for (const auto &p : subs_dict_) {
449  bool found = false;
450  for (const auto &s : x.get_dict()) {
451  if (neq(*(s.first->subs({{p.first, p.second}})), *(s.first))) {
452  found = true;
453  break;
454  }
455  }
456  // If p.first is not replaced in arg_ by dict_,
457  // store p.first in n to replace in arg_
458  if (not found) {
459  insert(n, p.first, p.second);
460  }
461  }
462  for (const auto &s : x.get_dict()) {
463  insert(m, s.first, apply(s.second));
464  }
465  RCP<const Basic> presub = x.get_arg()->subs(n);
466  if (is_a<Subs>(*presub)) {
467  for (auto &q : down_cast<const Subs &>(*presub).get_dict()) {
468  insert(m, q.first, q.second);
469  }
470  result_ = down_cast<const Subs &>(*presub).get_arg()->subs(m);
471  } else {
472  result_ = presub->subs(m);
473  }
474  }
475 };
476 
477 class MSubsVisitor : public BaseVisitor<MSubsVisitor, XReplaceVisitor>
478 {
479 public:
480  using XReplaceVisitor::bvisit;
481 
482  MSubsVisitor(const map_basic_basic &d, bool cache = true)
484  {
485  }
486 
487  void bvisit(const Derivative &x)
488  {
489  result_ = x.rcp_from_this();
490  }
491 
492  void bvisit(const Subs &x)
493  {
494  map_basic_basic m = x.get_dict();
495  for (const auto &p : subs_dict_) {
496  m[p.first] = p.second;
497  }
498  result_ = msubs(x.get_arg(), m);
499  }
500 };
501 
502 class SSubsVisitor : public BaseVisitor<SSubsVisitor, SubsVisitor>
503 {
504 public:
505  using XReplaceVisitor::bvisit;
506 
507  SSubsVisitor(const map_basic_basic &d, bool cache = true)
509  {
510  }
511 
512  void bvisit(const Derivative &x)
513  {
514  apply(x.get_arg());
515  auto t = result_;
516  multiset_basic m;
517  for (auto &p : x.get_symbols()) {
518  apply(p);
519  m.insert(result_);
520  }
521  result_ = Derivative::create(t, m);
522  }
523 
524  void bvisit(const Subs &x)
525  {
526  map_basic_basic m = x.get_dict();
527  for (const auto &p : subs_dict_) {
528  m[p.first] = p.second;
529  }
530  result_ = ssubs(x.get_arg(), m);
531  }
532 };
533 
535 inline RCP<const Basic> msubs(const RCP<const Basic> &x,
536  const map_basic_basic &subs_dict, bool cache)
537 {
538  MSubsVisitor s(subs_dict, cache);
539  return s.apply(x);
540 }
541 
543 inline RCP<const Basic> ssubs(const RCP<const Basic> &x,
544  const map_basic_basic &subs_dict, bool cache)
545 {
546  SSubsVisitor s(subs_dict, cache);
547  return s.apply(x);
548 }
549 
550 inline RCP<const Basic> subs(const RCP<const Basic> &x,
551  const map_basic_basic &subs_dict, bool cache)
552 {
553  SubsVisitor b(subs_dict, cache);
554  return b.apply(x);
555 }
556 
557 } // namespace SymEngine
558 
559 #endif // SYMENGINE_SUBS_H
T begin(T... args)
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Create an appropriate instance from dictionary quickly.
Definition: add.cpp:140
static void coef_dict_add_term(const Ptr< RCP< const Number >> &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
const RCP< const Number > & get_coef() const
Definition: add.h:142
The lowest unit of symbolic representation.
Definition: basic.h:95
RCP< const Basic > subs(const map_basic_basic &subs_dict) const
Substitutes 'subs_dict' into 'self'.
Definition: basic.cpp:27
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
virtual RCP< const Basic > create(const vec_basic &x) const
Method to construct classes with canonicalization.
Definition: functions.cpp:1862
static void as_base_exp(const RCP< const Basic > &self, const Ptr< RCP< const Basic >> &exp, const Ptr< RCP< const Basic >> &base)
Convert to a base and exponent form.
Definition: mul.cpp:315
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:116
virtual RCP< const Basic > create(const vec_basic &v) const =0
Method to construct classes with canonicalization.
virtual vec_basic get_args() const
Returns the list of arguments.
Definition: functions.h:159
virtual RCP< const Basic > create(const RCP< const Basic > &arg) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg() const
Definition: functions.h:36
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_arg1() const
Definition: functions.h:91
RCP< const Basic > get_arg2() const
Definition: functions.h:96
T emplace_back(T... args)
T end(T... args)
T find(T... args)
T insert(T... args)
T make_pair(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
Definition: mul.cpp:426
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > msubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Substitution which treats f(t) and Derivative(f(t), t) as independent variables.
Definition: subs.h:535
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:270
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:347
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
int factor(const Ptr< RCP< const Integer >> &f, const Integer &n, double B1)
Definition: ntheory.cpp:368
RCP< const Basic > xreplace(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Mappings in the subs_dict are applied to the expression tree of x
Definition: subs.h:331
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > ssubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
SymPy-compatible subs.
Definition: subs.h:543
T push_back(T... args)
T reserve(T... args)
T size(T... args)