Loading...
Searching...
No Matches
subs.h
1#ifndef SYMENGINE_SUBS_H
2#define SYMENGINE_SUBS_H
3
4#include <symengine/logic.h>
5#include <symengine/visitor.h>
6
7namespace SymEngine
8{
// xreplace: replaces every subexpression of `x` that is a key of `subs_dict`
// by the mapped value and rebuilds the enclosing nodes (see XReplaceVisitor).
// When `cache` is true, results for already-visited subexpressions are
// memoized for the duration of the call.
RCP<const Basic> xreplace(const RCP<const Basic> &x,
                          const map_basic_basic &subs_dict, bool cache = true);
// subs: like xreplace, but tries to keep the result mathematically
// equivalent. Substitutions into Derivative nodes that would not be valid
// are kept as an unevaluated Subs. A single power key also rewrites
// compatible powers, e.g. x**4 with {x**2: y} gives y**2 (see SubsVisitor).
RCP<const Basic> subs(const RCP<const Basic> &x,
                      const map_basic_basic &subs_dict, bool cache = true);
// msubs: port of sympy.physics.mechanics.msubs. Derivative nodes are not
// substituted into, so f(x) and f'(x) are treated as independent symbols
// (see MSubsVisitor).
RCP<const Basic> msubs(const RCP<const Basic> &x,
                       const map_basic_basic &subs_dict, bool cache = true);
// ssubs: port of sympy's subs. Substitution is also carried out inside
// derivatives, both in their argument and in their differentiation
// variables (see SSubsVisitor).
RCP<const Basic> ssubs(const RCP<const Basic> &x,
                       const map_basic_basic &subs_dict, bool cache = true);
24
25class XReplaceVisitor : public BaseVisitor<XReplaceVisitor>
26{
27
28protected:
29 RCP<const Basic> result_;
30 const map_basic_basic &subs_dict_;
31 map_basic_basic visited;
32 bool cache;
33
34public:
35 XReplaceVisitor(const map_basic_basic &subs_dict, bool cache = true)
36 : subs_dict_(subs_dict), cache(cache)
37 {
38 if (cache) {
39 visited = subs_dict;
40 }
41 }
42 // TODO : Polynomials, Series, Sets
43 void bvisit(const Basic &x)
44 {
45 result_ = x.rcp_from_this();
46 }
47
48 void bvisit(const Add &x)
49 {
51 RCP<const Number> coef;
52
53 auto it = subs_dict_.find(x.get_coef());
54 if (it != subs_dict_.end()) {
55 coef = zero;
56 Add::coef_dict_add_term(outArg(coef), d, one, it->second);
57 } else {
58 coef = x.get_coef();
59 }
60
61 for (const auto &p : x.get_dict()) {
62 auto it
63 = subs_dict_.find(Add::from_dict(zero, {{p.first, p.second}}));
64 if (it != subs_dict_.end()) {
65 Add::coef_dict_add_term(outArg(coef), d, one, it->second);
66 } else {
67 it = subs_dict_.find(p.second);
68 if (it != subs_dict_.end()) {
69 Add::coef_dict_add_term(outArg(coef), d, one,
70 mul(it->second, apply(p.first)));
71 } else {
72 Add::coef_dict_add_term(outArg(coef), d, p.second,
73 apply(p.first));
74 }
75 }
76 }
77 result_ = Add::from_dict(coef, std::move(d));
78 }
79
80 void bvisit(const Mul &x)
81 {
82 RCP<const Number> coef = one;
84 for (const auto &p : x.get_dict()) {
85 RCP<const Basic> factor_old;
86 if (eq(*p.second, *one)) {
87 factor_old = p.first;
88 } else {
89 factor_old = make_rcp<Pow>(p.first, p.second);
90 }
91 RCP<const Basic> factor = apply(factor_old);
92 if (factor == factor_old) {
93 // TODO: Check if Mul::dict_add_term is enough
94 Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
95 } else if (is_a_Number(*factor)) {
96 imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
97 } else if (is_a<Mul>(*factor)) {
98 RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
99 imulnum(outArg(coef), tmp->get_coef());
100 for (const auto &q : tmp->get_dict()) {
101 Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
102 }
103 } else {
104 RCP<const Basic> exp, t;
105 Mul::as_base_exp(factor, outArg(exp), outArg(t));
106 Mul::dict_add_term_new(outArg(coef), d, exp, t);
107 }
108 }
109
110 // Replace the coefficient
111 RCP<const Basic> factor = apply(x.get_coef());
112 RCP<const Basic> exp, t;
113 Mul::as_base_exp(factor, outArg(exp), outArg(t));
114 Mul::dict_add_term_new(outArg(coef), d, exp, t);
115
116 result_ = Mul::from_dict(coef, std::move(d));
117 }
118
119 void bvisit(const Pow &x)
120 {
121 RCP<const Basic> base_new = apply(x.get_base());
122 RCP<const Basic> exp_new = apply(x.get_exp());
123 if (base_new == x.get_base() and exp_new == x.get_exp()) {
124 result_ = x.rcp_from_this();
125 } else {
126 result_ = pow(base_new, exp_new);
127 }
128 }
129
130 void bvisit(const OneArgFunction &x)
131 {
132 apply(x.get_arg());
133 if (result_ == x.get_arg()) {
134 result_ = x.rcp_from_this();
135 } else {
136 result_ = x.create(result_);
137 }
138 }
139
140 template <class T>
141 void bvisit(const TwoArgBasic<T> &x)
142 {
143 RCP<const Basic> a = apply(x.get_arg1());
144 RCP<const Basic> b = apply(x.get_arg2());
145 if (a == x.get_arg1() and b == x.get_arg2())
146 result_ = x.rcp_from_this();
147 else
148 result_ = x.create(a, b);
149 }
150
151 void bvisit(const MultiArgFunction &x)
152 {
153 vec_basic v = x.get_args();
154 for (auto &elem : v) {
155 elem = apply(elem);
156 }
157 result_ = x.create(v);
158 }
159
160 void bvisit(const FunctionSymbol &x)
161 {
162 vec_basic v = x.get_args();
163 for (auto &elem : v) {
164 elem = apply(elem);
165 }
166 result_ = x.create(v);
167 }
168
169 void bvisit(const Contains &x)
170 {
171 RCP<const Basic> a = apply(x.get_expr());
172 auto c = apply(x.get_set());
173 if (not is_a_Set(*c))
174 throw SymEngineException("expected an object of type Set");
175 RCP<const Set> b = rcp_static_cast<const Set>(c);
176 if (a == x.get_expr() and b == x.get_set())
177 result_ = x.rcp_from_this();
178 else
179 result_ = x.create(a, b);
180 }
181
182 void bvisit(const And &x)
183 {
184 set_boolean v;
185 for (const auto &elem : x.get_container()) {
186 auto a = apply(elem);
187 if (not is_a_Boolean(*a))
188 throw SymEngineException("expected an object of type Boolean");
189 v.insert(rcp_static_cast<const Boolean>(a));
190 }
191 result_ = logical_and(v);
192 }
193
194 void bvisit(const Or &x)
195 {
196 set_boolean v;
197 for (const auto &elem : x.get_container()) {
198 auto a = apply(elem);
199 if (not is_a_Boolean(*a))
200 throw SymEngineException("expected an object of type Boolean");
201 v.insert(rcp_static_cast<const Boolean>(a));
202 }
203 result_ = logical_or(v);
204 }
205
206 void bvisit(const Not &x)
207 {
208 RCP<const Basic> a = apply(x.get_arg());
209 if (not is_a_Boolean(*a))
210 throw SymEngineException("expected an object of type Boolean");
211 result_ = logical_not(rcp_static_cast<const Boolean>(a));
212 }
213
214 void bvisit(const Xor &x)
215 {
216 vec_boolean v;
217 for (const auto &elem : x.get_container()) {
218 auto a = apply(elem);
219 if (not is_a_Boolean(*a))
220 throw SymEngineException("expected an object of type Boolean");
221 v.push_back(rcp_static_cast<const Boolean>(a));
222 }
223 result_ = logical_xor(v);
224 }
225
226 void bvisit(const FiniteSet &x)
227 {
228 set_basic v;
229 for (const auto &elem : x.get_container()) {
230 v.insert(apply(elem));
231 }
232 result_ = x.create(v);
233 }
234
235 void bvisit(const ImageSet &x)
236 {
237 RCP<const Basic> s = apply(x.get_symbol());
238 RCP<const Basic> expr = apply(x.get_expr());
239 auto bs_ = apply(x.get_baseset());
240 if (not is_a_Set(*bs_))
241 throw SymEngineException("expected an object of type Set");
242 RCP<const Set> bs = rcp_static_cast<const Set>(bs_);
243 if (s == x.get_symbol() and expr == x.get_expr()
244 and bs == x.get_baseset()) {
245 result_ = x.rcp_from_this();
246 } else {
247 result_ = x.create(s, expr, bs);
248 }
249 }
250
251 void bvisit(const Union &x)
252 {
253 set_set v;
254 for (const auto &elem : x.get_container()) {
255 auto a = apply(elem);
256 if (not is_a_Set(*a))
257 throw SymEngineException("expected an object of type Set");
258 v.insert(rcp_static_cast<const Set>(a));
259 }
260 result_ = x.create(v);
261 }
262
263 void bvisit(const Piecewise &pw)
264 {
265 PiecewiseVec pwv;
266 pwv.reserve(pw.get_vec().size());
267 for (const auto &expr_pred : pw.get_vec()) {
268 const auto expr = apply(*expr_pred.first);
269 const auto pred = apply(*expr_pred.second);
270 pwv.emplace_back(
271 std::make_pair(expr, rcp_static_cast<const Boolean>(pred)));
272 }
273 result_ = piecewise(std::move(pwv));
274 }
275
276 void bvisit(const Derivative &x)
277 {
278 auto expr = apply(x.get_arg());
279 for (const auto &sym : x.get_symbols()) {
280 auto s = apply(sym);
281 if (not is_a<Symbol>(*s)) {
282 throw SymEngineException("expected an object of type Symbol");
283 }
284 expr = expr->diff(rcp_static_cast<const Symbol>(s));
285 }
286 result_ = expr;
287 }
288
289 void bvisit(const Subs &x)
290 {
291 auto expr = apply(x.get_arg());
292 map_basic_basic new_subs_dict;
293 for (const auto &sym : x.get_dict()) {
294 insert(new_subs_dict, apply(sym.first), apply(sym.second));
295 }
296 result_ = subs(expr, new_subs_dict);
297 }
298
299 void bvisit(const ComplexBase &x)
300 {
301 auto it = subs_dict_.find(I);
302 if (it != subs_dict_.end()) {
303 result_ = add(apply(x.real_part()),
304 mul(apply(x.imaginary_part()), it->second));
305 } else {
306 result_ = x.rcp_from_this();
307 }
308 }
309
310 RCP<const Basic> apply(const Basic &x)
311 {
312 return apply(x.rcp_from_this());
313 }
314
315 RCP<const Basic> apply(const RCP<const Basic> &x)
316 {
317 if (cache) {
318 auto it = visited.find(x);
319 if (it != visited.end()) {
320 result_ = it->second;
321 } else {
322 x->accept(*this);
323 insert(visited, x, result_);
324 }
325 } else {
326 auto it = subs_dict_.find(x);
327 if (it != subs_dict_.end()) {
328 result_ = it->second;
329 } else {
330 x->accept(*this);
331 }
332 }
333 return result_;
334 }
335};
336
338inline RCP<const Basic> xreplace(const RCP<const Basic> &x,
339 const map_basic_basic &subs_dict, bool cache)
340{
341 XReplaceVisitor s(subs_dict, cache);
342 return s.apply(x);
343}
344
345class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
346{
347public:
348 using XReplaceVisitor::bvisit;
349
350 SubsVisitor(const map_basic_basic &subs_dict_, bool cache = true)
352 {
353 }
354
355 void bvisit(const Pow &x)
356 {
357 RCP<const Basic> base_new = apply(x.get_base());
358 RCP<const Basic> exp_new = apply(x.get_exp());
359 if (subs_dict_.size() == 1 and is_a<Pow>(*((*subs_dict_.begin()).first))
360 and not is_a<Add>(
361 *down_cast<const Pow &>(*(*subs_dict_.begin()).first)
362 .get_exp())) {
363 auto &subs_first
364 = down_cast<const Pow &>(*(*subs_dict_.begin()).first);
365 if (eq(*subs_first.get_base(), *base_new)) {
366 auto newexpo = div(exp_new, subs_first.get_exp());
367 if (is_a_Number(*newexpo) or is_a<Constant>(*newexpo)) {
368 result_ = pow((*subs_dict_.begin()).second, newexpo);
369 return;
370 }
371 }
372 }
373 if (base_new == x.get_base() and exp_new == x.get_exp()) {
374 result_ = x.rcp_from_this();
375 } else {
376 result_ = pow(base_new, exp_new);
377 }
378 }
379
380 void bvisit(const Derivative &x)
381 {
382 RCP<const Symbol> s;
383 map_basic_basic m, n;
384 bool subs;
385
386 for (const auto &p : subs_dict_) {
387 // If the derivative arg is to be replaced in its entirety, allow
388 // it.
389 if (eq(*x.get_arg(), *p.first)) {
390 RCP<const Basic> t = p.second;
391 for (auto &sym : x.get_symbols()) {
392 if (not is_a<Symbol>(*sym)) {
393 throw SymEngineException("Error, expected a Symbol.");
394 }
395 t = t->diff(rcp_static_cast<const Symbol>(sym));
396 }
397 result_ = t;
398 return;
399 }
400 }
401 for (const auto &p : subs_dict_) {
402 subs = true;
403 if (eq(*x.get_arg()->subs({{p.first, p.second}}), *x.get_arg()))
404 continue;
405
406 // If p.first and p.second are symbols and arg_ is
407 // independent of p.second, p.first can be replaced
408 if (is_a<Symbol>(*p.first) and is_a<Symbol>(*p.second)
409 and eq(
410 *x.get_arg()->diff(rcp_static_cast<const Symbol>(p.second)),
411 *zero)) {
412 insert(n, p.first, p.second);
413 continue;
414 }
415 for (const auto &d : x.get_symbols()) {
416 if (is_a<Symbol>(*d)) {
417 s = rcp_static_cast<const Symbol>(d);
418 // If p.first or p.second has non zero derivates wrt to s
419 // p.first cannot be replaced
420 if (neq(*zero, *(p.first->diff(s)))
421 || neq(*zero, *(p.second->diff(s)))) {
422 subs = false;
423 break;
424 }
425 } else {
426 result_
427 = make_rcp<const Subs>(x.rcp_from_this(), subs_dict_);
428 return;
429 }
430 }
431 if (subs) {
432 insert(n, p.first, p.second);
433 } else {
434 insert(m, p.first, p.second);
435 }
436 }
437 auto t = x.get_arg()->subs(n);
438 for (auto &p : x.get_symbols()) {
439 auto t2 = p->subs(n);
440 if (not is_a<Symbol>(*t2)) {
441 throw SymEngineException("Error, expected a Symbol.");
442 }
443 t = t->diff(rcp_static_cast<const Symbol>(t2));
444 }
445 if (m.empty()) {
446 result_ = t;
447 } else {
448 result_ = make_rcp<const Subs>(t, m);
449 }
450 }
451
452 void bvisit(const Subs &x)
453 {
454 map_basic_basic m, n;
455 for (const auto &p : subs_dict_) {
456 bool found = false;
457 for (const auto &s : x.get_dict()) {
458 if (neq(*(s.first->subs({{p.first, p.second}})), *(s.first))) {
459 found = true;
460 break;
461 }
462 }
463 // If p.first is not replaced in arg_ by dict_,
464 // store p.first in n to replace in arg_
465 if (not found) {
466 insert(n, p.first, p.second);
467 }
468 }
469 for (const auto &s : x.get_dict()) {
470 insert(m, s.first, apply(s.second));
471 }
472 RCP<const Basic> presub = x.get_arg()->subs(n);
473 if (is_a<Subs>(*presub)) {
474 for (auto &q : down_cast<const Subs &>(*presub).get_dict()) {
475 insert(m, q.first, q.second);
476 }
477 result_ = down_cast<const Subs &>(*presub).get_arg()->subs(m);
478 } else {
479 result_ = presub->subs(m);
480 }
481 }
482};
483
484class MSubsVisitor : public BaseVisitor<MSubsVisitor, XReplaceVisitor>
485{
486public:
487 using XReplaceVisitor::bvisit;
488
489 MSubsVisitor(const map_basic_basic &d, bool cache = true)
491 {
492 }
493
494 void bvisit(const Derivative &x)
495 {
496 result_ = x.rcp_from_this();
497 }
498
499 void bvisit(const Subs &x)
500 {
501 map_basic_basic m = x.get_dict();
502 for (const auto &p : subs_dict_) {
503 m[p.first] = p.second;
504 }
505 result_ = msubs(x.get_arg(), m);
506 }
507};
508
509class SSubsVisitor : public BaseVisitor<SSubsVisitor, SubsVisitor>
510{
511public:
512 using XReplaceVisitor::bvisit;
513
514 SSubsVisitor(const map_basic_basic &d, bool cache = true)
516 {
517 }
518
519 void bvisit(const Derivative &x)
520 {
521 apply(x.get_arg());
522 auto t = result_;
524 for (auto &p : x.get_symbols()) {
525 apply(p);
526 m.insert(result_);
527 }
528 result_ = Derivative::create(t, m);
529 }
530
531 void bvisit(const Subs &x)
532 {
533 map_basic_basic m = x.get_dict();
534 for (const auto &p : subs_dict_) {
535 m[p.first] = p.second;
536 }
537 result_ = ssubs(x.get_arg(), m);
538 }
539};
540
542inline RCP<const Basic> msubs(const RCP<const Basic> &x,
543 const map_basic_basic &subs_dict, bool cache)
544{
545 MSubsVisitor s(subs_dict, cache);
546 return s.apply(x);
547}
548
550inline RCP<const Basic> ssubs(const RCP<const Basic> &x,
551 const map_basic_basic &subs_dict, bool cache)
552{
553 SSubsVisitor s(subs_dict, cache);
554 return s.apply(x);
555}
556
557inline RCP<const Basic> subs(const RCP<const Basic> &x,
558 const map_basic_basic &subs_dict, bool cache)
559{
560 SubsVisitor b(subs_dict, cache);
561 return b.apply(x);
562}
563
564} // namespace SymEngine
565
566#endif // SYMENGINE_SUBS_H
T begin(T... args)
The base class for representing addition in symbolic expressions.
Definition: add.h:27
static RCP< const Basic > from_dict(const RCP< const Number > &coef, umap_basic_num &&d)
Create an appropriate instance from dictionary quickly.
Definition: add.cpp:140
const RCP< const Number > & get_coef() const
Definition: add.h:142
static void coef_dict_add_term(const Ptr< RCP< const Number > > &coef, umap_basic_num &d, const RCP< const Number > &c, const RCP< const Basic > &term)
Updates the numerical coefficient and the dictionary.
Definition: add.cpp:261
The lowest unit of symbolic representation.
Definition: basic.h:97
RCP< const Basic > subs(const map_basic_basic &subs_dict) const
Substitutes 'subs_dict' into 'self'.
Definition: basic.cpp:80
ComplexBase Class for deriving all complex classes.
Definition: complex.h:16
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
Definition: functions.cpp:1905
static void as_base_exp(const RCP< const Basic > &self, const Ptr< RCP< const Basic > > &exp, const Ptr< RCP< const Basic > > &base)
Convert to a base and exponent form.
Definition: mul.cpp:320
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
virtual RCP< const Basic > create(const vec_basic &v) const =0
Method to construct classes with canonicalization.
vec_basic get_args() const override
Returns the list of arguments.
Definition: functions.h:159
RCP< const Basic > get_arg() const
Definition: functions.h:36
virtual RCP< const Basic > create(const RCP< const Basic > &arg) const =0
Method to construct classes with canonicalization.
RCP< const Basic > get_base() const
Definition: pow.h:37
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_arg2() const
Definition: functions.h:96
RCP< const Basic > get_arg1() const
Definition: functions.h:91
virtual RCP< const Basic > create(const RCP< const Basic > &a, const RCP< const Basic > &b) const =0
Method to construct classes with canonicalization.
T emplace_back(T... args)
T end(T... args)
T find(T... args)
T insert(T... args)
T make_pair(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
Definition: mul.cpp:431
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > msubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Substitution that treats f(t) and Derivative(f(t), t) as independent variables.
Definition: subs.h:542
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
void insert(T1 &m, const T2 &first, const T3 &second)
Definition: dict.h:83
int factor(const Ptr< RCP< const Integer > > &f, const Integer &n, double B1)
Definition: ntheory.cpp:370
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Basic > xreplace(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
Mappings in the subs_dict are applied to the expression tree of x
Definition: subs.h:338
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > ssubs(const RCP< const Basic > &x, const map_basic_basic &subs_dict, bool cache=true)
SymPy-compatible subs.
Definition: subs.h:550
T push_back(T... args)
T reserve(T... args)
T size(T... args)