Program Listing for File solve.cpp

Return to documentation for file (symengine/symengine/solve.cpp)

#include <symengine/solve.h>
#include <symengine/polys/basic_conversions.h>
#include <symengine/logic.h>
#include <symengine/mul.h>
#include <symengine/as_real_imag.cpp>

namespace SymEngine
{

// Solves c0 + c1*x = 0, where `coeffs` is {c0, c1} (ascending powers), and
// returns the single root intersected with `domain`.
// Throws SymEngineException unless exactly two coefficients are supplied.
RCP<const Set> solve_poly_linear(const vec_basic &coeffs,
                                 const RCP<const Set> &domain)
{
    const bool has_two_coeffs = (coeffs.size() == 2);
    if (not has_two_coeffs) {
        throw SymEngineException("Expected a polynomial of degree 1. Try with "
                                 "solve() or solve_poly()");
    }
    // x = -(c0 / c1)
    const RCP<const Basic> ratio = div(coeffs[0], coeffs[1]);
    const RCP<const Set> candidate = finiteset({neg(ratio)});
    return set_intersection({domain, candidate});
}

// Solves c0 + c1*x + c2*x**2 = 0, where `coeffs` is {c0, c1, c2}, and returns
// the roots intersected with `domain`.
// Throws SymEngineException unless exactly three coefficients are supplied.
RCP<const Set> solve_poly_quadratic(const vec_basic &coeffs,
                                    const RCP<const Set> &domain)
{
    if (coeffs.size() != 3) {
        throw SymEngineException("Expected a polynomial of degree 2. Try with "
                                 "solve() or solve_poly()");
    }

    // Work with the monic form x**2 + p*x + q.
    const RCP<const Basic> lead = coeffs[2];
    const RCP<const Basic> p = div(coeffs[1], lead);
    const RCP<const Basic> q = div(coeffs[0], lead);

    // Every branch ends the same way: pack both roots and apply the domain.
    const auto within_domain
        = [&domain](const RCP<const Basic> &first,
                    const RCP<const Basic> &second) -> RCP<const Set> {
        return set_intersection({domain, finiteset({first, second})});
    };

    if (eq(*q, *zero)) {
        // x * (x + p) = 0
        return within_domain(neg(p), zero);
    }
    if (eq(*p, *zero)) {
        // x**2 = -q
        const RCP<const Basic> positive = sqrt(neg(q));
        return within_domain(positive, neg(positive));
    }

    // General case: x = -p/2 +- sqrt(p**2 - 4*q)/2
    const RCP<const Basic> disc = sub(mul(p, p), mul(integer(4), q));
    const RCP<const Basic> centre = div(neg(p), integer(2));
    const RCP<const Basic> offset = div(sqrt(disc), integer(2));
    return within_domain(add(centre, offset), sub(centre, offset));
}

// Solves c0 + c1*x + c2*x**2 + c3*x**3 = 0 in closed form and returns the
// roots intersected with `domain`.
//
// `coeffs` must hold exactly four coefficients in ascending order of power,
// otherwise SymEngineException is thrown. The polynomial is made monic
// (x**3 + b*x**2 + c*x + d) by dividing through by the leading coefficient.
RCP<const Set> solve_poly_cubic(const vec_basic &coeffs,
                                const RCP<const Set> &domain)
{
    if (coeffs.size() != 4) {
        throw SymEngineException("Expected a polynomial of degree 3. Try with "
                                 "solve() or solve_poly()");
    }

    auto a = coeffs[3];
    auto b = div(coeffs[2], a), c = div(coeffs[1], a), d = div(coeffs[0], a);

    // ref :
    // https://en.wikipedia.org/wiki/Cubic_function#General_solution_to_the_cubic_equation_with_real_coefficients
    auto i2 = integer(2), i3 = integer(3), i4 = integer(4), i9 = integer(9),
         i27 = integer(27);

    RCP<const Basic> root1, root2, root3;
    if (eq(*d, *zero)) {
        // x * (x**2 + b*x + c) = 0: one root is 0, the others come from the
        // quadratic factor. The quadratic is deliberately solved over the
        // universal set: restricting it to `domain` at this point could
        // produce an empty (or otherwise non-finite) set, which would make
        // the down_cast and the container access below invalid. `domain` is
        // applied once, to the complete root set, at the end.
        root1 = zero;
        auto fset = solve_poly_quadratic({c, b, one}, universalset());
        SYMENGINE_ASSERT(is_a<FiniteSet>(*fset));
        const set_basic &cont
            = down_cast<const FiniteSet &>(*fset).get_container();
        SYMENGINE_ASSERT(not cont.empty());
        if (cont.size() == 2) {
            root2 = *cont.begin();
            root3 = *std::next(cont.begin());
        } else {
            // The quadratic has a repeated root.
            root2 = root3 = *cont.begin();
        }
    } else {
        auto delta0 = sub(mul(b, b), mul(i3, c));
        auto delta1
            = add(sub(mul(pow(b, i3), i2), mul({i9, b, c})), mul(i27, d));
        // Discriminant of the monic cubic.
        auto delta = div(sub(mul(i4, pow(delta0, i3)), pow(delta1, i2)), i27);
        if (eq(*delta, *zero)) {
            if (eq(*delta0, *zero)) {
                // Triple root.
                root1 = root2 = root3 = div(neg(b), i3);
            } else {
                // One double root (root1 == root2) and one simple root.
                root1 = root2
                    = div(sub(mul(i9, d), mul(b, c)), mul(i2, delta0));
                root3 = div(sub(mul({i4, b, c}), add(mul(d, i9), pow(b, i3))),
                            delta0);
            }
        } else {
            // temp = sqrt(delta1**2 - 4*delta0**3)
            auto temp = sqrt(mul(neg(i27), delta));
            auto Cexpr = div(add(delta1, temp), i2);
            if (eq(*Cexpr, *zero)) {
                // Pick the other sign so that C is non-zero (C is used as a
                // divisor below).
                Cexpr = div(sub(delta1, temp), i2);
            }
            auto C = pow(Cexpr, div(one, i3));
            root1 = neg(div(add(b, add(C, div(delta0, C))), i3));
            // The remaining roots use C multiplied by the two non-real cube
            // roots of unity, -1/2 +- sqrt(3)/2 * I.
            auto coef = div(mul(I, sqrt(i3)), i2);
            temp = neg(div(one, i2));
            auto cbrt1 = add(temp, coef);
            auto cbrt2 = sub(temp, coef);
            root2 = neg(div(
                add(b, add(mul(cbrt1, C), div(delta0, mul(cbrt1, C)))), i3));
            root3 = neg(div(
                add(b, add(mul(cbrt2, C), div(delta0, mul(cbrt2, C)))), i3));
        }
    }
    return set_intersection({domain, finiteset({root1, root2, root3})});
}

// Solves c0 + c1*x + ... + c4*x**4 = 0 in closed form and returns the roots
// intersected with `domain`.
//
// `coeffs` must hold exactly five coefficients in ascending order of power,
// otherwise SymEngineException is thrown. The polynomial is made monic
// (x**4 + a*x**3 + b*x**2 + c*x + d) before solving.
//
// All auxiliary equations (deflated cubic, resolvent cubic, quadratic in
// y**2) are solved over the universal set. Their unknowns are not `x` itself
// (they are y = x + a/4, or y**2), so filtering them with `domain` would
// discard valid roots, and a filtered result need not be a FiniteSet.
// `domain` is applied once, to the final roots.
RCP<const Set> solve_poly_quartic(const vec_basic &coeffs,
                                  const RCP<const Set> &domain)
{
    if (coeffs.size() != 5) {
        throw SymEngineException("Expected a polynomial of degree 4. Try with "
                                 "solve() or solve_poly()");
    }

    auto i2 = integer(2), i3 = integer(3), i4 = integer(4), i8 = integer(8),
         i16 = integer(16), i64 = integer(64), i256 = integer(256);

    // Domain used for every intermediate solve (see the note above).
    const RCP<const Set> unrestricted = universalset();

    // ref : http://mathforum.org/dr.math/faq/faq.cubic.equations.html
    auto lc = coeffs[4];
    auto a = div(coeffs[3], lc), b = div(coeffs[2], lc), c = div(coeffs[1], lc),
         d = div(coeffs[0], lc);
    set_basic roots;

    if (eq(*d, *zero)) {
        // x * (x**3 + a*x**2 + b*x + c) = 0
        vec_basic newcoeffs(4);
        newcoeffs[0] = c, newcoeffs[1] = b, newcoeffs[2] = a,
        newcoeffs[3] = one;
        auto rcubic = solve_poly_cubic(newcoeffs, unrestricted);
        SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
        roots = down_cast<const FiniteSet &>(*rcubic).get_container();
        roots.insert(zero);
    } else {
        // substitute x = y-a/4 to get equation of the form y**4 + e*y**2 + f*y
        // + g = 0
        auto sqa = mul(a, a);
        auto cba = mul(sqa, a);
        auto aby4 = div(a, i4);
        auto e = sub(b, div(mul(i3, sqa), i8));
        auto ff = sub(add(c, div(cba, i8)), div(mul(a, b), i2));
        auto g = sub(add(d, div(mul(sqa, b), i16)),
                     add(div(mul(a, c), i4), div(mul({i3, cba, a}), i256)));

        // two special cases
        if (eq(*g, *zero)) {
            // y * (y**3 + e*y + f) = 0
            vec_basic newcoeffs(4);
            newcoeffs[0] = ff, newcoeffs[1] = e, newcoeffs[2] = zero,
            newcoeffs[3] = one;
            auto rcubic = solve_poly_cubic(newcoeffs, unrestricted);
            SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
            const set_basic &rtemp
                = down_cast<const FiniteSet &>(*rcubic).get_container();
            SYMENGINE_ASSERT(rtemp.size() > 0 and rtemp.size() <= 3);
            for (const auto &r : rtemp) {
                roots.insert(sub(r, aby4));
            }
            // y = 0  =>  x = -a/4
            roots.insert(neg(aby4));
        } else if (eq(*ff, *zero)) {
            // Biquadratic: z**2 + e*z + g = 0 with z = y**2.
            vec_basic newcoeffs(3);
            newcoeffs[0] = g, newcoeffs[1] = e, newcoeffs[2] = one;
            auto rquad = solve_poly_quadratic(newcoeffs, unrestricted);
            SYMENGINE_ASSERT(is_a<FiniteSet>(*rquad));
            const set_basic &rtemp
                = down_cast<const FiniteSet &>(*rquad).get_container();
            SYMENGINE_ASSERT(rtemp.size() > 0 and rtemp.size() <= 2);
            for (const auto &r : rtemp) {
                auto sqrtr = sqrt(r);
                roots.insert(sub(sqrtr, aby4));
                roots.insert(sub(neg(sqrtr), aby4));
            }
        } else {
            // Leonhard Euler's method
            vec_basic newcoeffs(4);
            newcoeffs[0] = neg(div(mul(ff, ff), i64)),
            newcoeffs[1] = div(sub(mul(e, e), mul(i4, g)), i16),
            newcoeffs[2] = div(e, i2);
            newcoeffs[3] = one;

            auto rcubic = solve_poly_cubic(newcoeffs, unrestricted);
            SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
            const set_basic &resolvent
                = down_cast<const FiniteSet &>(*rcubic).get_container();
            SYMENGINE_ASSERT(resolvent.size() > 0 and resolvent.size() <= 3);
            // p and q are square roots of two resolvent roots (the same root
            // twice when only one distinct root exists); r is then fixed by
            // p*q*r = -f/8.
            auto p = sqrt(*resolvent.begin());
            auto q = p;
            if (resolvent.size() > 1) {
                q = sqrt(*std::next(resolvent.begin()));
            }
            auto r = div(neg(ff), mul({i8, p, q}));
            roots.insert(add({p, q, r, neg(aby4)}));
            roots.insert(add({p, neg(q), neg(r), neg(aby4)}));
            roots.insert(add({neg(p), q, neg(r), neg(aby4)}));
            roots.insert(add({neg(p), neg(q), r, neg(aby4)}));
        }
    }
    return set_intersection({domain, finiteset(roots)});
}

// Dispatches to the closed-form solver matching the number of coefficients
// (degree = coeffs.size() - 1). A constant polynomial yields `domain` when it
// is zero and the empty set otherwise. Any other size, including an empty
// vector, raises SymEngineException.
RCP<const Set> solve_poly_heuristics(const vec_basic &coeffs,
                                     const RCP<const Set> &domain)
{
    switch (coeffs.size()) {
        case 1: {
            // Degree 0: either every point of the domain solves it, or none.
            if (not eq(*coeffs[0], *zero)) {
                return emptyset();
            }
            return domain;
        }
        case 2:
            return solve_poly_linear(coeffs, domain);
        case 3:
            return solve_poly_quadratic(coeffs, domain);
        case 4:
            return solve_poly_cubic(coeffs, domain);
        case 5:
            return solve_poly_quartic(coeffs, domain);
        default:
            throw SymEngineException(
                "expected a polynomial of order between 0 to 4");
    }
}

// Converts an integer_class coefficient into a symbolic object via integer().
inline RCP<const Basic> get_coeff_basic(const integer_class &i)
{
    const RCP<const Basic> as_basic = integer(i);
    return as_basic;
}

// Returns the symbolic object held by an Expression coefficient.
inline RCP<const Basic> get_coeff_basic(const Expression &i)
{
    const RCP<const Basic> as_basic = i.get_basic();
    return as_basic;
}

// Collects the coefficients of `f` for powers 0 .. degree, in that order,
// converting each one with get_coeff_basic().
template <typename Poly>
inline vec_basic extract_coeffs(const RCP<const Poly> &f)
{
    vec_basic ascending;
    const int highest = f->get_degree();
    int power = 0;
    while (power <= highest) {
        ascending.push_back(get_coeff_basic(f->get_coeff(power)));
        ++power;
    }
    return ascending;
}

// Solves the polynomial equation f = 0 for `sym` within `domain`.
//
// When built with a sufficiently recent FLINT, `f` is first converted to an
// integer polynomial and factored; each factor of degree <= 4 is solved in
// closed form and higher-degree factors are returned as condition sets. If
// that path raises SymEngineException, or FLINT is unavailable, `f` is
// treated as a polynomial with symbolic coefficients instead.
RCP<const Set> solve_poly(const RCP<const Basic> &f,
                          const RCP<const Symbol> &sym,
                          const RCP<const Set> &domain)
{

#if defined(HAVE_SYMENGINE_FLINT) && __FLINT_RELEASE > 20502
    try {
        const auto int_poly = from_basic<UIntPolyFlint>(f, sym);
        const auto factor_list = factors(*int_poly);
        set_set partial_solutions;
        for (const auto &factor : factor_list) {
            const auto piece = UIntPoly::from_poly(*factor.first);
            if (piece->get_poly().degree() > 4) {
                // No closed form is attempted: keep the factor as a condition.
                partial_solutions.insert(conditionset(
                    sym, logical_and({Eq(piece->as_symbolic(), zero),
                                      domain->contains(sym)})));
                continue;
            }
            partial_solutions.insert(
                solve_poly_heuristics(extract_coeffs(piece), domain));
        }
        return SymEngine::set_union(partial_solutions);
    } catch (SymEngineException &) {
        // Try next
    }
#endif
    const RCP<const Basic> gen = rcp_static_cast<const Basic>(sym);
    const auto expr_poly = from_basic<UExprPoly>(f, gen);
    if (expr_poly->get_degree() > 4) {
        return conditionset(sym,
                            logical_and({Eq(f, zero), domain->contains(sym)}));
    }
    return solve_poly_heuristics(extract_coeffs(expr_poly), domain);
}

// Solves f = 0 after splitting `f` into numerator and denominator.
// If the denominator does not involve `sym`, only the numerator is solved as
// a polynomial. Otherwise the zeros of the denominator are removed from the
// zeros of the numerator.
RCP<const Set> solve_rational(const RCP<const Basic> &f,
                              const RCP<const Symbol> &sym,
                              const RCP<const Set> &domain)
{
    RCP<const Basic> numer, denom;
    as_numer_denom(f, outArg(numer), outArg(denom));

    if (not has_symbol(*denom, *sym)) {
        return solve_poly(numer, sym, domain);
    }

    const RCP<const Set> numer_roots = solve(numer, sym, domain);
    const RCP<const Set> denom_roots = solve(denom, sym, domain);
    return set_complement(numer_roots, denom_roots);
}

/* Helper Visitors for solve_trig */

// Visitor used by is_a_LinearArgTrigEquation().
//
// apply() starts from "true" and walks the expression with
// preorder_traversal_local_stop(). The answer becomes false when
//   * the tracked symbol itself is visited, or
//   * a trigonometric / hyperbolic function is visited whose first argument,
//     read as a polynomial in the tracked symbol, has degree greater than 1.
// Trig/hyperbolic nodes set local_stop_ (presumably so that the traversal
// does not descend into their arguments -- confirm against
// preorder_traversal_local_stop).
class IsALinearArgTrigVisitor
    : public BaseVisitor<IsALinearArgTrigVisitor, LocalStopVisitor>
{
protected:
    Ptr<const Symbol> x_; // symbol the equation is being solved for
    bool is_;             // running answer, reset by apply()

public:
    IsALinearArgTrigVisitor(Ptr<const Symbol> x) : x_(x)
    {
    }

    // Runs the check on `b` and returns the verdict.
    bool apply(const Basic &b)
    {
        is_ = true;
        stop_ = false;
        preorder_traversal_local_stop(b, *this);
        return is_;
    }

    bool apply(const RCP<const Basic> &b)
    {
        const Basic &expr = *b;
        return apply(expr);
    }

    // Generic node: allow the traversal to continue into its arguments.
    void bvisit(const Basic &x)
    {
        local_stop_ = false;
    }

    // Visiting the tracked symbol itself fails the check and stops the walk.
    void bvisit(const Symbol &x)
    {
        if (not x_->__eq__(x)) {
            return;
        }
        stop_ = true;
        is_ = false;
    }

    // Trigonometric / hyperbolic node: inspect the degree of its argument.
    template <typename T,
              typename
              = enable_if_t<std::is_base_of<TrigFunction, T>::value
                            or std::is_base_of<HyperbolicFunction, T>::value>>
    void bvisit(const T &x)
    {
        const auto arg_poly
            = from_basic<UExprPoly>(x.get_args()[0], x_->rcp_from_this());
        is_ = (arg_poly->get_degree() <= 1);
        if (not is_) {
            stop_ = true;
        }
        local_stop_ = true;
    }
};

bool is_a_LinearArgTrigEquation(const Basic &b, const Symbol &x)
{
    IsALinearArgTrigVisitor v(ptrFromRef(x));
    return v.apply(b);
}

// Visitor behind invertComplex(): given an expression f(sym) and a set gY_,
// it peels off the outer operations of f (sums, products and powers of E)
// and transforms gY_ accordingly. Node types without a dedicated handler
// leave gY_ unchanged. apply() intersects the outcome with domain_.
class InvertComplexVisitor : public BaseVisitor<InvertComplexVisitor>
{
protected:
    RCP<const Set> result_;
    RCP<const Set> gY_;
    RCP<const Dummy> nD_;
    RCP<const Symbol> sym_;
    RCP<const Set> domain_;

    // Distributes `args` into those that contain sym_ and those that do not.
    void split_on_symbol(const vec_basic &args, vec_basic &with_sym,
                         vec_basic &without_sym) const
    {
        for (const auto &arg : args) {
            vec_basic &bucket = has_symbol(*arg, *sym_) ? with_sym
                                                        : without_sym;
            bucket.push_back(arg);
        }
    }

public:
    InvertComplexVisitor(RCP<const Set> gY, RCP<const Dummy> nD,
                         RCP<const Symbol> sym, RCP<const Set> domain)
        : gY_(gY), nD_(nD), sym_(sym), domain_(domain)
    {
    }

    // Fallback for node types without a dedicated handler.
    void bvisit(const Basic &x)
    {
        result_ = gY_;
    }

    // dep + indep = y  ->  dep = y - indep
    void bvisit(const Add &x)
    {
        vec_basic dependent, independent;
        split_on_symbol(x.get_args(), dependent, independent);
        const auto depX = add(dependent);
        const auto indepX = add(independent);

        if (eq(*indepX, *zero)) {
            result_ = gY_;
            return;
        }
        gY_ = imageset(nD_, sub(nD_, indepX), gY_);
        result_ = apply(*depX);
    }

    // dep * indep = y  ->  dep = y / indep (no solutions for infinite indep)
    void bvisit(const Mul &x)
    {
        vec_basic dependent, independent;
        split_on_symbol(x.get_args(), dependent, independent);
        const auto depX = mul(dependent);
        const auto indepX = mul(independent);

        if (eq(*indepX, *one)) {
            result_ = gY_;
            return;
        }
        const bool infinite_factor = eq(*indepX, *NegInf) or eq(*indepX, *Inf)
                                     or eq(*indepX, *ComplexInf);
        if (infinite_factor) {
            result_ = emptyset();
            return;
        }
        gY_ = imageset(nD_, div(nD_, indepX), gY_);
        result_ = apply(*depX);
    }

    // E**g = y with y drawn from a finite set: replace every non-zero y by
    // the image set of its complex logarithms, then continue with g.
    void bvisit(const Pow &x)
    {
        if (not eq(*x.get_base(), *E) or not is_a<FiniteSet>(*gY_)) {
            result_ = gY_;
            return;
        }

        set_set log_branches;
        const auto &targets
            = down_cast<const FiniteSet &>(*gY_).get_container();
        for (const auto &target : targets) {
            if (eq(*target, *zero)) {
                continue;
            }
            RCP<const Basic> re, im;
            as_real_imag(target, outArg(re), outArg(im));
            const auto logabs = log(add(mul(re, re), mul(im, im)));
            const auto logarg = atan2(im, re);
            // (2*n*pi + arg(y))*I + log(|y|**2)/2
            const auto branch
                = add(mul(add(mul({integer(2), nD_, pi}), logarg), I),
                      div(logabs, integer(2)));
            // TODO : replace interval(-oo,oo) with
            // Set of Integers once Class for Range is implemented.
            log_branches.insert(
                imageset(nD_, branch, interval(NegInf, Inf, true, true)));
        }
        gY_ = set_union(log_branches);
        apply(*x.get_exp());
    }

    // Visits `b` and returns the resulting set restricted to domain_.
    RCP<const Set> apply(const Basic &b)
    {
        result_ = gY_;
        b.accept(*this);
        return set_intersection({domain_, result_});
    }
};

// Builds an InvertComplexVisitor from the given arguments and applies it to
// `fX`, returning the visitor's result.
RCP<const Set> invertComplex(const RCP<const Basic> &fX,
                             const RCP<const Set> &gY,
                             const RCP<const Symbol> &sym,
                             const RCP<const Dummy> &nD,
                             const RCP<const Set> &domain)
{
    InvertComplexVisitor inverter(gY, nD, sym, domain);
    const Basic &expr = *fX;
    return inverter.apply(expr);
}

// Solves the trigonometric equation f = 0 for `sym` within `domain`.
//
// `f` is rewritten in terms of exponentials and split into numerator and
// denominator, and exp(I*sym) is replaced by a dummy variable. The resulting
// equations are solved for the dummy, and each solution is mapped back to
// `sym` with invertComplex(). If the substitution does not eliminate `sym`,
// or the dummy solutions do not form a finite set, a condition set restricted
// to `domain` is returned instead.
//
// Throws SymEngineException if the substituted union of solutions is not a
// Set.
RCP<const Set> solve_trig(const RCP<const Basic> &f,
                          const RCP<const Symbol> &sym,
                          const RCP<const Set> &domain)
{
    // TODO : first simplify f using `fu`.
    auto exp_f = rewrite_as_exp(f);
    RCP<const Basic> num, den;
    as_numer_denom(exp_f, outArg(num), outArg(den));

    // Replace exp(I*sym) by the dummy xD in numerator and denominator.
    auto xD = dummy("x");
    auto exp_isym = exp(mul(I, sym));
    map_basic_basic to_dummy;
    to_dummy[exp_isym] = xD;
    num = expand(num), den = expand(den);
    num = num->subs(to_dummy);
    den = den->subs(to_dummy);

    if (has_symbol(*num, *sym) or has_symbol(*den, *sym)) {
        // `sym` survives the substitution, so this approach cannot solve the
        // equation. Keep the domain restriction in the condition, like every
        // other condition-set fallback in this file.
        return conditionset(sym,
                            logical_and({Eq(f, zero), domain->contains(sym)}));
    }

    auto soln = set_complement(solve(num, xD), solve(den, xD));
    if (eq(*soln, *emptyset()))
        return emptyset();
    else if (is_a<FiniteSet>(*soln)) {
        set_set res;
        auto nD
            = dummy("n"); // use the same dummy for finding every solution set.
        auto n = symbol(
            "n"); // replaces the above dummy in final set of solutions.
        map_basic_basic dummy_to_n;
        dummy_to_n[nD] = n;
        for (const auto &elem :
             down_cast<const FiniteSet &>(*soln).get_container()) {
            // Solve exp(I*sym) = elem for sym.
            res.insert(invertComplex(exp_isym, finiteset({elem}), sym, nD));
        }
        auto ans = set_union(res)->subs(dummy_to_n);
        if (not is_a_Set(*ans))
            throw SymEngineException("Expected an object of type Set");
        return set_intersection({rcp_static_cast<const Set>(ans), domain});
    }
    return conditionset(sym, logical_and({Eq(f, zero), domain->contains(sym)}));
}

// Solves f = 0 (or the relation `f`) for `sym` within `domain`.
//
// Handled in this order: boolean constants, Equality / Unequality (reduced to
// "lhs - rhs = 0"), other relationals (returned as a condition set), numeric
// constants, expressions free of `sym` (empty set), trigonometric equations
// accepted by is_a_LinearArgTrigEquation(), products (union of the solutions
// of every factor) and finally rational expressions.
RCP<const Set> solve(const RCP<const Basic> &f, const RCP<const Symbol> &sym,
                     const RCP<const Set> &domain)
{
    if (eq(*f, *boolTrue)) {
        return domain;
    }
    if (eq(*f, *boolFalse)) {
        return emptyset();
    }

    const bool is_equation = is_a<Equality>(*f);
    if (is_equation or is_a<Unequality>(*f)) {
        const auto &relation = down_cast<const Relational &>(*f);
        const RCP<const Set> roots = solve(
            sub(relation.get_arg1(), relation.get_arg2()), sym, domain);
        if (is_equation) {
            return roots;
        }
        // lhs != rhs: everything in the domain except the roots.
        return set_complement(domain, roots);
    }
    if (is_a_Relational(*f)) {
        // Solving inequalities is not implemented yet.
        return conditionset(sym, logical_and({rcp_static_cast<const Boolean>(f),
                                              domain->contains(sym)}));
    }

    if (is_a_Number(*f)) {
        if (not eq(*f, *zero)) {
            return emptyset();
        }
        return domain;
    }

    if (not has_symbol(*f, *sym)) {
        return emptyset();
    }

    if (is_a_LinearArgTrigEquation(*f, *sym)) {
        return solve_trig(f, sym, domain);
    }

    if (is_a<Mul>(*f)) {
        // A product vanishes where any of its factors does.
        set_set per_factor;
        const vec_basic parts = f->get_args();
        for (const auto &part : parts) {
            per_factor.insert(solve(part, sym, domain));
        }
        return SymEngine::set_union(per_factor);
    }

    return solve_rational(f, sym, domain);
}

// Solves A*x = b with fraction_free_gauss_jordan_solve() and returns the
// entries of the resulting column as a vector.
vec_basic linsolve_helper(const DenseMatrix &A, const DenseMatrix &b)
{
    DenseMatrix solution(A.nrows(), 1);
    fraction_free_gauss_jordan_solve(A, b, solution);

    vec_basic entries;
    unsigned r = 0;
    while (r < solution.nrows()) {
        entries.push_back(solution.get(r, 0));
        ++r;
    }
    return entries;
}

// Solves a linear system given as an augmented matrix: all columns but the
// last form the coefficient matrix, the last column is the right-hand side.
// `syms` is not used by this overload.
vec_basic linsolve(const DenseMatrix &system, const vec_sym &syms)
{
    const unsigned last_row = system.nrows() - 1;
    const unsigned rhs_col = system.ncols() - 1;

    DenseMatrix A(system.nrows(), rhs_col);
    DenseMatrix b(system.nrows(), 1);
    system.submatrix(A, 0, 0, last_row, rhs_col - 1);
    system.submatrix(b, 0, rhs_col, last_row, rhs_col);
    return linsolve_helper(A, b);
}

// Solves a system of linear equations in `syms`: the equations are converted
// to matrix form with linear_eqns_to_matrix() and handed to
// linsolve_helper().
vec_basic linsolve(const vec_basic &system, const vec_sym &syms)
{
    const std::pair<DenseMatrix, DenseMatrix> matrices
        = linear_eqns_to_matrix(system, syms);
    return linsolve_helper(matrices.first, matrices.second);
}

// Copies the symbols of `syms` into a set_basic.
set_basic get_set_from_vec(const vec_sym &syms)
{
    set_basic collected;
    for (auto it = syms.begin(); it != syms.end(); ++it) {
        collected.insert(*it);
    }
    return collected;
}

std::pair<DenseMatrix, DenseMatrix>
linear_eqns_to_matrix(const vec_basic &equations, const vec_sym &syms)
{
    auto size = numeric_cast<unsigned int>(syms.size());
    DenseMatrix A(numeric_cast<unsigned int>(equations.size()), size);
    zeros(A);
    vec_basic bvec;

    int row = 0;
    auto gens = get_set_from_vec(syms);
    umap_basic_uint index_of_sym;
    for (unsigned int i = 0; i < size; i++) {
        index_of_sym[syms[i]] = i;
    }
    for (const auto &eqn : equations) {
        auto neqn = eqn;
        if (is_a<Equality>(*eqn)) {
            neqn = sub(down_cast<const Equality &>(*eqn).get_arg2(),
                       down_cast<const Equality &>(*eqn).get_arg1());
        }

        auto mpoly = from_basic<MExprPoly>(neqn, gens);
        RCP<const Basic> rem = zero;
        for (const auto &p : mpoly->get_poly().dict_) {
            RCP<const Basic> res = (p.second.get_basic());
            int whichvar = 0, non_zero = 0;
            RCP<const Basic> cursim;
            for (auto &sym : gens) {
                if (0 != p.first[whichvar]) {
                    non_zero++;
                    cursim = sym;
                    if (p.first[whichvar] != 1 or non_zero == 2) {
                        throw SymEngineException("Expected a linear equation.");
                    }
                }
                whichvar++;
            }
            if (not non_zero) {
                rem = res;
            } else {
                A.set(row, index_of_sym[cursim], res);
            }
        }
        bvec.push_back(neg(rem));
        ++row;
    }
    return std::make_pair(
        A, DenseMatrix(numeric_cast<unsigned int>(equations.size()), 1, bvec));
}
}