Program Listing for File subs.h

Return to documentation for file (symengine/symengine/subs.h)

#ifndef SYMENGINE_SUBS_H
#define SYMENGINE_SUBS_H

#include <symengine/logic.h>
#include <symengine/visitor.h>

namespace SymEngine
{
// xreplace replaces subtrees of a node in the expression tree
// with a new subtree: every subtree of x that is a key of subs_dict is
// replaced by the corresponding value (see XReplaceVisitor for the per-node
// rules, e.g. matching of whole Add terms and Mul factors).
//
// For all four functions below:
//   x         - expression to substitute into.
//   subs_dict - substitutions old -> new.
//   cache     - when true, every subtree visited during the call is memoized
//               together with its result (the memo table starts as a copy of
//               subs_dict); when false, only subs_dict itself is consulted.
RCP<const Basic> xreplace(const RCP<const Basic> &x,
                          const map_basic_basic &subs_dict, bool cache = true);
// subs substitutes expressions similar to xreplace, but keeps
// the mathematical meaning of Derivative and Subs objects (substitution into
// a derivative may be deferred into an unevaluated Subs) and can rewrite
// powers of a single power pattern; see SubsVisitor.
RCP<const Basic> subs(const RCP<const Basic> &x,
                      const map_basic_basic &subs_dict, bool cache = true);
// port of sympy.physics.mechanics.msubs where f'(x) and f(x)
// are considered independent: a derivative is replaced only when it is itself
// a key of subs_dict and is never substituted into; see MSubsVisitor.
RCP<const Basic> msubs(const RCP<const Basic> &x,
                       const map_basic_basic &subs_dict, bool cache = true);
// port of sympy's subs where subs inside derivatives are done: the
// derivative's argument and variables are substituted and the result is
// rebuilt as a Derivative; see SSubsVisitor.
RCP<const Basic> ssubs(const RCP<const Basic> &x,
                       const map_basic_basic &subs_dict, bool cache = true);

// Visitor behind xreplace(): walks an expression tree and replaces every
// subtree that is a key of the substitution dictionary with its value.
// Each bvisit overload stores its output in result_; apply() is the entry
// point and returns that result. SubsVisitor, MSubsVisitor and SSubsVisitor
// derive from this class and override selected overloads.
//
// NOTE: subs_dict_ is stored by reference, so the dictionary passed to the
// constructor must outlive the visitor.
class XReplaceVisitor : public BaseVisitor<XReplaceVisitor>
{

protected:
    // Output of the most recent bvisit/apply call.
    RCP<const Basic> result_;
    // Substitutions to perform (borrowed, not copied).
    const map_basic_basic &subs_dict_;
    // Memo table node -> result; only used when `cache` is true.
    map_basic_basic visited;
    // Whether apply() memoizes its results in `visited`.
    bool cache;

public:
    // subs_dict: substitutions old -> new (must outlive this visitor).
    // cache:     when true, memoize every visited node in `visited`.
    XReplaceVisitor(const map_basic_basic &subs_dict, bool cache = true)
        : subs_dict_(subs_dict), cache(cache)
    {
        if (cache) {
            // Seed the memo table with the substitutions themselves so that
            // apply() resolves direct hits and already-seen subtrees with a
            // single lookup.
            visited = subs_dict;
        }
    }
    // TODO : Polynomials, Series, Sets
    // Fallback for node types without a dedicated overload: keep unchanged.
    void bvisit(const Basic &x)
    {
        result_ = x.rcp_from_this();
    }

    // Sums. The numeric constant, each whole term coef*term, and each term's
    // coefficient alone may be keys; otherwise recurse into the term.
    void bvisit(const Add &x)
    {
        SymEngine::umap_basic_num d;
        RCP<const Number> coef;

        auto it = subs_dict_.find(x.get_coef());
        if (it != subs_dict_.end()) {
            // The constant is a key: start from zero and add its replacement
            // as an ordinary term.
            coef = zero;
            Add::coef_dict_add_term(outArg(coef), d, one, it->second);
        } else {
            coef = x.get_coef();
        }

        for (const auto &p : x.get_dict()) {
            // First try the whole term p.second*p.first.
            auto term_it
                = subs_dict_.find(Add::from_dict(zero, {{p.first, p.second}}));
            if (term_it != subs_dict_.end()) {
                Add::coef_dict_add_term(outArg(coef), d, one, term_it->second);
            } else {
                // Then the term's coefficient on its own.
                term_it = subs_dict_.find(p.second);
                if (term_it != subs_dict_.end()) {
                    Add::coef_dict_add_term(
                        outArg(coef), d, one,
                        mul(term_it->second, apply(p.first)));
                } else {
                    Add::coef_dict_add_term(outArg(coef), d, p.second,
                                            apply(p.first));
                }
            }
        }
        result_ = Add::from_dict(coef, std::move(d));
    }

    // Products. Each factor base**exp (or just base when exp == 1) is looked
    // up / recursed into as a whole, then the numeric coefficient is
    // processed the same way.
    void bvisit(const Mul &x)
    {
        RCP<const Number> coef = one;
        map_basic_basic d;
        for (const auto &p : x.get_dict()) {
            RCP<const Basic> factor_old;
            if (eq(*p.second, *one)) {
                factor_old = p.first;
            } else {
                factor_old = make_rcp<Pow>(p.first, p.second);
            }
            RCP<const Basic> factor = apply(factor_old);
            if (factor == factor_old) {
                // Unchanged factor: re-add the original base/exponent pair.
                // TODO: Check if Mul::dict_add_term is enough
                Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
            } else if (is_a_Number(*factor)) {
                // A zero factor annihilates the whole product.
                if (down_cast<const Number &>(*factor).is_zero()) {
                    result_ = factor;
                    return;
                }
                imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
            } else if (is_a<Mul>(*factor)) {
                // Flatten a Mul replacement into this product.
                RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
                imulnum(outArg(coef), tmp->get_coef());
                for (const auto &q : tmp->get_dict()) {
                    Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
                }
            } else {
                RCP<const Basic> exp, t;
                Mul::as_base_exp(factor, outArg(exp), outArg(t));
                Mul::dict_add_term_new(outArg(coef), d, exp, t);
            }
        }

        // Replace the coefficient
        RCP<const Basic> factor = apply(x.get_coef());
        RCP<const Basic> exp, t;
        Mul::as_base_exp(factor, outArg(exp), outArg(t));
        Mul::dict_add_term_new(outArg(coef), d, exp, t);

        result_ = Mul::from_dict(coef, std::move(d));
    }

    // Powers: substitute base and exponent; reuse the node if neither moved.
    void bvisit(const Pow &x)
    {
        RCP<const Basic> base_new = apply(x.get_base());
        RCP<const Basic> exp_new = apply(x.get_exp());
        if (base_new == x.get_base() and exp_new == x.get_exp()) {
            result_ = x.rcp_from_this();
        } else {
            result_ = pow(base_new, exp_new);
        }
    }

    void bvisit(const OneArgFunction &x)
    {
        apply(x.get_arg());
        if (result_ == x.get_arg()) {
            result_ = x.rcp_from_this();
        } else {
            result_ = x.create(result_);
        }
    }

    template <class T>
    void bvisit(const TwoArgBasic<T> &x)
    {
        RCP<const Basic> a = apply(x.get_arg1());
        RCP<const Basic> b = apply(x.get_arg2());
        if (a == x.get_arg1() and b == x.get_arg2())
            result_ = x.rcp_from_this();
        else
            result_ = x.create(a, b);
    }

    void bvisit(const MultiArgFunction &x)
    {
        vec_basic v = x.get_args();
        for (auto &elem : v) {
            elem = apply(elem);
        }
        result_ = x.create(v);
    }

    void bvisit(const FunctionSymbol &x)
    {
        vec_basic v = x.get_args();
        for (auto &elem : v) {
            elem = apply(elem);
        }
        result_ = x.create(v);
    }

    // Throws SymEngineException if the set operand stops being a Set.
    void bvisit(const Contains &x)
    {
        RCP<const Basic> a = apply(x.get_expr());
        auto c = apply(x.get_set());
        if (not is_a_Set(*c))
            throw SymEngineException("expected an object of type Set");
        RCP<const Set> b = rcp_static_cast<const Set>(c);
        if (a == x.get_expr() and b == x.get_set())
            result_ = x.rcp_from_this();
        else
            result_ = x.create(a, b);
    }

    // Throws SymEngineException if an operand stops being a Boolean.
    void bvisit(const And &x)
    {
        set_boolean v;
        for (const auto &elem : x.get_container()) {
            auto a = apply(elem);
            if (not is_a_Boolean(*a))
                throw SymEngineException("expected an object of type Boolean");
            v.insert(rcp_static_cast<const Boolean>(a));
        }
        result_ = x.create(v);
    }

    void bvisit(const FiniteSet &x)
    {
        set_basic v;
        for (const auto &elem : x.get_container()) {
            v.insert(apply(elem));
        }
        result_ = x.create(v);
    }

    // Throws SymEngineException if the base set stops being a Set.
    void bvisit(const ImageSet &x)
    {
        RCP<const Basic> s = apply(x.get_symbol());
        RCP<const Basic> expr = apply(x.get_expr());
        auto bs_ = apply(x.get_baseset());
        if (not is_a_Set(*bs_))
            throw SymEngineException("expected an object of type Set");
        RCP<const Set> bs = rcp_static_cast<const Set>(bs_);
        if (s == x.get_symbol() and expr == x.get_expr()
            and bs == x.get_baseset()) {
            result_ = x.rcp_from_this();
        } else {
            result_ = x.create(s, expr, bs);
        }
    }

    // Throws SymEngineException if an operand stops being a Set.
    void bvisit(const Union &x)
    {
        set_set v;
        for (const auto &elem : x.get_container()) {
            auto a = apply(elem);
            if (not is_a_Set(*a))
                throw SymEngineException("expected an object of type Set");
            v.insert(rcp_static_cast<const Set>(a));
        }
        result_ = x.create(v);
    }

    // Substitute into every (expression, condition) pair. Throws
    // SymEngineException if a condition stops being a Boolean, matching the
    // And/Contains handling above instead of performing an unchecked
    // downcast.
    void bvisit(const Piecewise &pw)
    {
        PiecewiseVec pwv;
        pwv.reserve(pw.get_vec().size());
        for (const auto &expr_pred : pw.get_vec()) {
            const auto expr = apply(*expr_pred.first);
            const auto pred = apply(*expr_pred.second);
            if (not is_a_Boolean(*pred))
                throw SymEngineException("expected an object of type Boolean");
            pwv.emplace_back(
                std::make_pair(expr, rcp_static_cast<const Boolean>(pred)));
        }
        result_ = piecewise(std::move(pwv));
    }

    // Substitute into the argument and the variables, then re-differentiate.
    // Throws SymEngineException if a variable stops being a Symbol.
    void bvisit(const Derivative &x)
    {
        auto expr = apply(x.get_arg());
        for (const auto &sym : x.get_symbols()) {
            auto s = apply(sym);
            if (not is_a<Symbol>(*s)) {
                throw SymEngineException("expected an object of type Symbol");
            }
            expr = expr->diff(rcp_static_cast<const Symbol>(s));
        }
        result_ = expr;
    }

    // Substitute into the argument and into both sides of the Subs
    // dictionary, then evaluate the resulting substitution with subs(),
    // forwarding this visitor's cache setting.
    void bvisit(const Subs &x)
    {
        auto expr = apply(x.get_arg());
        map_basic_basic new_subs_dict;
        for (const auto &sym : x.get_dict()) {
            insert(new_subs_dict, apply(sym.first), apply(sym.second));
        }
        result_ = subs(expr, new_subs_dict, cache);
    }

    RCP<const Basic> apply(const Basic &x)
    {
        return apply(x.rcp_from_this());
    }

    // Entry point: returns the replacement for x. A dictionary hit wins over
    // visiting; otherwise the node is visited (and memoized when caching).
    RCP<const Basic> apply(const RCP<const Basic> &x)
    {
        if (cache) {
            auto it = visited.find(x);
            if (it != visited.end()) {
                result_ = it->second;
            } else {
                x->accept(*this);
                insert(visited, x, result_);
            }
        } else {
            auto it = subs_dict_.find(x);
            if (it != subs_dict_.end()) {
                result_ = it->second;
            } else {
                x->accept(*this);
            }
        }
        return result_;
    }
};

inline RCP<const Basic> xreplace(const RCP<const Basic> &x,
                                 const map_basic_basic &subs_dict, bool cache)
{
    // A throwaway visitor per call: its memo table lives only for this call.
    return XReplaceVisitor(subs_dict, cache).apply(x);
}

// Visitor behind subs(): extends XReplaceVisitor so that substitution stays
// mathematically meaningful for powers, Derivative and Subs nodes.
class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
{
public:
    // Keep every XReplaceVisitor overload not redefined below.
    using XReplaceVisitor::bvisit;

    SubsVisitor(const map_basic_basic &dict, bool cache = true)
        : BaseVisitor<SubsVisitor, XReplaceVisitor>(dict, cache)
    {
    }

    // Powers. Besides substituting base and exponent, handle a dictionary
    // holding exactly one entry b**e -> v whose exponent e is not an Add:
    // when the substituted base equals b and exp/e is a Number or Constant,
    // the result is v**(exp/e).
    void bvisit(const Pow &x)
    {
        RCP<const Basic> new_base = apply(x.get_base());
        RCP<const Basic> new_exp = apply(x.get_exp());
        if (subs_dict_.size() == 1) {
            const auto &entry = *subs_dict_.begin();
            if (is_a<Pow>(*entry.first)) {
                const auto &pattern = down_cast<const Pow &>(*entry.first);
                if (not is_a<Add>(*pattern.get_exp())
                    and eq(*pattern.get_base(), *new_base)) {
                    RCP<const Basic> ratio = div(new_exp, pattern.get_exp());
                    if (is_a_Number(*ratio) or is_a<Constant>(*ratio)) {
                        result_ = pow(entry.second, ratio);
                        return;
                    }
                }
            }
        }
        const bool untouched
            = (new_base == x.get_base()) and (new_exp == x.get_exp());
        if (untouched) {
            result_ = x.rcp_from_this();
        } else {
            result_ = pow(new_base, new_exp);
        }
    }

    // Derivatives. If the whole argument is a key, differentiate its
    // replacement. Otherwise every entry old -> new that changes the
    // argument is classified as:
    //  - direct:   substituted inside the derivative (old and new are
    //              symbols and the argument does not depend on new, or
    //              neither old nor new depends on any differentiation
    //              variable);
    //  - deferred: applied after differentiation via a Subs wrapper.
    // A non-Symbol differentiation variable turns the whole node into an
    // unevaluated Subs of this derivative.
    void bvisit(const Derivative &x)
    {
        for (const auto &entry : subs_dict_) {
            if (not eq(*x.get_arg(), *entry.first)) {
                continue;
            }
            RCP<const Basic> replaced = entry.second;
            for (const auto &var : x.get_symbols()) {
                if (not is_a<Symbol>(*var)) {
                    throw SymEngineException("Error, expected a Symbol.");
                }
                replaced = replaced->diff(rcp_static_cast<const Symbol>(var));
            }
            result_ = replaced;
            return;
        }

        map_basic_basic deferred, direct;
        for (const auto &entry : subs_dict_) {
            // Entries that leave the argument unchanged are irrelevant.
            if (eq(*x.get_arg()->subs({{entry.first, entry.second}}),
                   *x.get_arg())) {
                continue;
            }

            if (is_a<Symbol>(*entry.first) and is_a<Symbol>(*entry.second)
                and eq(*x.get_arg()->diff(
                           rcp_static_cast<const Symbol>(entry.second)),
                       *zero)) {
                insert(direct, entry.first, entry.second);
                continue;
            }

            bool independent = true;
            for (const auto &var : x.get_symbols()) {
                if (not is_a<Symbol>(*var)) {
                    result_
                        = make_rcp<const Subs>(x.rcp_from_this(), subs_dict_);
                    return;
                }
                const RCP<const Symbol> sym
                    = rcp_static_cast<const Symbol>(var);
                if (neq(*zero, *(entry.first->diff(sym)))
                    || neq(*zero, *(entry.second->diff(sym)))) {
                    independent = false;
                    break;
                }
            }
            if (independent) {
                insert(direct, entry.first, entry.second);
            } else {
                insert(deferred, entry.first, entry.second);
            }
        }

        RCP<const Basic> expr = x.get_arg()->subs(direct);
        for (const auto &var : x.get_symbols()) {
            RCP<const Basic> new_var = var->subs(direct);
            if (not is_a<Symbol>(*new_var)) {
                throw SymEngineException("Error, expected a Symbol.");
            }
            expr = expr->diff(rcp_static_cast<const Symbol>(new_var));
        }
        if (deferred.empty()) {
            result_ = expr;
        } else {
            result_ = make_rcp<const Subs>(expr, deferred);
        }
    }

    // Subs objects. Outer entries whose key would alter one of the Subs
    // object's own keys are not pushed into its argument. The Subs values are
    // substituted, and the combined substitution is then evaluated (one
    // nested Subs level produced by the inner step is flattened).
    void bvisit(const Subs &x)
    {
        map_basic_basic outer, inner;
        for (const auto &entry : subs_dict_) {
            bool clashes = false;
            for (const auto &own : x.get_dict()) {
                if (neq(*(own.first->subs({{entry.first, entry.second}})),
                        *(own.first))) {
                    clashes = true;
                    break;
                }
            }
            if (not clashes) {
                insert(inner, entry.first, entry.second);
            }
        }
        for (const auto &own : x.get_dict()) {
            insert(outer, own.first, apply(own.second));
        }

        RCP<const Basic> presub = x.get_arg()->subs(inner);
        RCP<const Basic> body = presub;
        if (is_a<Subs>(*presub)) {
            const auto &nested = down_cast<const Subs &>(*presub);
            for (const auto &q : nested.get_dict()) {
                insert(outer, q.first, q.second);
            }
            body = nested.get_arg();
        }
        result_ = body->subs(outer);
    }
};

// Visitor behind msubs(): xreplace-like substitution in which a derivative
// is treated as independent of the function it differentiates. A Derivative
// is replaced only when it is itself a key of the dictionary (apply() checks
// the dictionary before visiting) and is never substituted into.
class MSubsVisitor : public BaseVisitor<MSubsVisitor, XReplaceVisitor>
{
public:
    using XReplaceVisitor::bvisit;

    MSubsVisitor(const map_basic_basic &d, bool cache = true)
        : BaseVisitor<MSubsVisitor, XReplaceVisitor>(d, cache)
    {
    }

    // Derivatives are left untouched: f'(x) and f(x) are independent.
    void bvisit(const Derivative &x)
    {
        result_ = x.rcp_from_this();
    }

    // Merge this visitor's dictionary into the Subs object's own dictionary
    // (this visitor's entries win on conflicting keys) and msubs the Subs
    // argument with the merged dictionary, forwarding this visitor's cache
    // setting instead of silently reverting to the default.
    void bvisit(const Subs &x)
    {
        map_basic_basic m = x.get_dict();
        for (const auto &p : subs_dict_) {
            m[p.first] = p.second;
        }
        result_ = msubs(x.get_arg(), m, cache);
    }
};

// Visitor behind ssubs(): SubsVisitor semantics (including the power-pattern
// rewrite in SubsVisitor::bvisit(const Pow &)), except that substitution is
// carried out inside derivatives instead of being deferred.
class SSubsVisitor : public BaseVisitor<SSubsVisitor, SubsVisitor>
{
public:
    // Bring in SubsVisitor's overload set. SubsVisitor's own using-declaration
    // re-exports the remaining XReplaceVisitor overloads. Naming
    // XReplaceVisitor here instead would hide SubsVisitor::bvisit(const Pow &)
    // and make the SubsVisitor base pointless.
    using SubsVisitor::bvisit;

    SSubsVisitor(const map_basic_basic &d, bool cache = true)
        : BaseVisitor<SSubsVisitor, SubsVisitor>(d, cache)
    {
    }

    // Substitute into the argument and each differentiation variable and
    // rebuild the Derivative from the results.
    void bvisit(const Derivative &x)
    {
        apply(x.get_arg());
        auto t = result_;
        multiset_basic m;
        for (auto &p : x.get_symbols()) {
            apply(p);
            m.insert(result_);
        }
        result_ = Derivative::create(t, m);
    }

    // Merge this visitor's dictionary into the Subs object's own dictionary
    // (this visitor's entries win on conflicting keys) and ssubs the Subs
    // argument with the merged dictionary, forwarding the cache setting.
    void bvisit(const Subs &x)
    {
        map_basic_basic m = x.get_dict();
        for (const auto &p : subs_dict_) {
            m[p.first] = p.second;
        }
        result_ = ssubs(x.get_arg(), m, cache);
    }
};

inline RCP<const Basic> msubs(const RCP<const Basic> &x,
                              const map_basic_basic &subs_dict, bool cache)
{
    // Derivatives are treated as independent symbols; see MSubsVisitor.
    return MSubsVisitor(subs_dict, cache).apply(x);
}

inline RCP<const Basic> ssubs(const RCP<const Basic> &x,
                              const map_basic_basic &subs_dict, bool cache)
{
    // SymPy-style subs that also substitutes inside derivatives.
    return SSubsVisitor(subs_dict, cache).apply(x);
}

inline RCP<const Basic> subs(const RCP<const Basic> &x,
                             const map_basic_basic &subs_dict, bool cache)
{
    // Mathematically-aware substitution; see SubsVisitor.
    return SubsVisitor(subs_dict, cache).apply(x);
}

} // namespace SymEngine

#endif // SYMENGINE_SUBS_H