2 #include <symengine/polys/basic_conversions.h>
5 #include <symengine/as_real_imag.cpp>
10 RCP<const Set> solve_poly_linear(
const vec_basic &coeffs,
11 const RCP<const Set> &domain)
13 if (coeffs.size() != 2) {
14 throw SymEngineException(
"Expected a polynomial of degree 1. Try with "
15 "solve() or solve_poly()");
17 auto root =
neg(
div(coeffs[0], coeffs[1]));
21 RCP<const Set> solve_poly_quadratic(
const vec_basic &coeffs,
22 const RCP<const Set> &domain)
24 if (coeffs.size() != 3) {
25 throw SymEngineException(
"Expected a polynomial of degree 2. Try with "
26 "solve() or solve_poly()");
30 auto b =
div(coeffs[1], a), c =
div(coeffs[0], a);
31 RCP<const Basic> root1, root2;
35 }
else if (
eq(*b, *zero)) {
42 root1 =
add(lterm, rterm);
43 root2 =
sub(lterm, rterm);
48 RCP<const Set> solve_poly_cubic(
const vec_basic &coeffs,
49 const RCP<const Set> &domain)
51 if (coeffs.size() != 4) {
52 throw SymEngineException(
"Expected a polynomial of degree 3. Try with "
53 "solve() or solve_poly()");
57 auto b =
div(coeffs[2], a), c =
div(coeffs[1], a), d =
div(coeffs[0], a);
64 RCP<const Basic> root1, root2, root3;
67 auto fset = solve_poly_quadratic({c, b, one}, domain);
68 SYMENGINE_ASSERT(is_a<FiniteSet>(*fset));
69 auto cont = down_cast<const FiniteSet &>(*fset).get_container();
70 if (cont.size() == 2) {
71 root2 = *cont.begin();
74 root2 = root3 = *cont.begin();
81 if (
eq(*delta, *zero)) {
82 if (
eq(*delta0, *zero)) {
83 root1 = root2 = root3 =
div(
neg(b), i3);
92 auto Cexpr =
div(
add(delta1, temp), i2);
93 if (
eq(*Cexpr, *zero)) {
94 Cexpr =
div(
sub(delta1, temp), i2);
96 auto C =
pow(Cexpr,
div(one, i3));
100 auto cbrt1 =
add(temp, coef);
101 auto cbrt2 =
sub(temp, coef);
111 RCP<const Set> solve_poly_quartic(
const vec_basic &coeffs,
112 const RCP<const Set> &domain)
114 if (coeffs.size() != 5) {
115 throw SymEngineException(
"Expected a polynomial of degree 4. Try with "
116 "solve() or solve_poly()");
124 auto a =
div(coeffs[3], lc), b =
div(coeffs[2], lc), c =
div(coeffs[1], lc),
125 d =
div(coeffs[0], lc);
129 vec_basic newcoeffs(4);
130 newcoeffs[0] = c, newcoeffs[1] = b, newcoeffs[2] = a,
132 auto rcubic = solve_poly_cubic(newcoeffs, domain);
133 SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
134 roots = down_cast<const FiniteSet &>(*rcubic).get_container();
139 auto sqa =
mul(a, a);
140 auto cba =
mul(sqa, a);
141 auto aby4 =
div(a, i4);
149 vec_basic newcoeffs(4);
150 newcoeffs[0] = ff, newcoeffs[1] = e, newcoeffs[2] = zero,
152 auto rcubic = solve_poly_cubic(newcoeffs, domain);
153 SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
154 auto rtemp = down_cast<const FiniteSet &>(*rcubic).get_container();
155 SYMENGINE_ASSERT(rtemp.size() > 0 and rtemp.size() <= 3);
156 for (
auto &r : rtemp) {
157 roots.insert(
sub(r, aby4));
159 roots.insert(
neg(aby4));
160 }
else if (
eq(*ff, *zero)) {
161 vec_basic newcoeffs(3);
162 newcoeffs[0] = g, newcoeffs[1] = e, newcoeffs[2] = one;
163 auto rquad = solve_poly_quadratic(newcoeffs, domain);
164 SYMENGINE_ASSERT(is_a<FiniteSet>(*rquad));
165 auto rtemp = down_cast<const FiniteSet &>(*rquad).get_container();
166 SYMENGINE_ASSERT(rtemp.size() > 0 and rtemp.size() <= 2);
167 for (
auto &r : rtemp) {
168 auto sqrtr =
sqrt(r);
169 roots.insert(
sub(sqrtr, aby4));
170 roots.insert(
sub(
neg(sqrtr), aby4));
174 vec_basic newcoeffs(4);
175 newcoeffs[0] =
neg(
div(
mul(ff, ff), i64)),
177 newcoeffs[2] =
div(e, i2);
180 auto rcubic = solve_poly_cubic(newcoeffs);
181 SYMENGINE_ASSERT(is_a<FiniteSet>(*rcubic));
182 roots = down_cast<const FiniteSet &>(*rcubic).get_container();
183 SYMENGINE_ASSERT(roots.size() > 0 and roots.size() <= 3);
184 auto p =
sqrt(*roots.begin());
186 if (roots.size() > 1) {
191 roots.insert(
add({p, q, r,
neg(aby4)}));
200 RCP<const Set> solve_poly_heuristics(
const vec_basic &coeffs,
201 const RCP<const Set> &domain)
203 auto degree = coeffs.size() - 1;
206 if (
eq(*coeffs[0], *zero)) {
213 return solve_poly_linear(coeffs, domain);
215 return solve_poly_quadratic(coeffs, domain);
217 return solve_poly_cubic(coeffs, domain);
219 return solve_poly_quartic(coeffs, domain);
221 throw SymEngineException(
222 "expected a polynomial of order between 0 to 4");
226 inline RCP<const Basic> get_coeff_basic(
const integer_class &i)
231 inline RCP<const Basic> get_coeff_basic(
const Expression &i)
233 return i.get_basic();
236 template <
typename Poly>
237 inline vec_basic extract_coeffs(
const RCP<const Poly> &f)
239 int degree = f->get_degree();
241 for (
int i = 0; i <= degree; i++)
242 coeffs.
push_back(get_coeff_basic(f->get_coeff(i)));
246 RCP<const Set> solve_poly(
const RCP<const Basic> &f,
247 const RCP<const Symbol> &sym,
248 const RCP<const Set> &domain)
251 #if defined(HAVE_SYMENGINE_FLINT) && __FLINT_RELEASE > 20502
253 auto poly = from_basic<UIntPolyFlint>(f, sym);
254 auto fac = factors(*poly);
256 for (
const auto &elem : fac) {
257 auto uip = UIntPoly::from_poly(*elem.first);
258 auto degree = uip->get_poly().degree();
261 solve_poly_heuristics(extract_coeffs(uip), domain));
265 domain->contains(sym)})));
268 return SymEngine::set_union(solns);
269 }
catch (SymEngineException &x) {
273 RCP<const Basic> gen = rcp_static_cast<const Basic>(sym);
274 auto uexp = from_basic<UExprPoly>(f, gen);
275 auto degree = uexp->get_degree();
277 return solve_poly_heuristics(extract_coeffs(uexp), domain);
280 logical_and({
Eq(f, zero), domain->contains(sym)}));
284 RCP<const Set> solve_rational(
const RCP<const Basic> &f,
285 const RCP<const Symbol> &sym,
286 const RCP<const Set> &domain)
288 RCP<const Basic> num, den;
289 as_numer_denom(f, outArg(num), outArg(den));
290 if (has_symbol(*den, *sym)) {
291 auto numsoln = solve(num, sym, domain);
292 auto densoln = solve(den, sym, domain);
293 return set_complement(numsoln, densoln);
295 return solve_poly(num, sym, domain);
301 :
public BaseVisitor<IsALinearArgTrigVisitor, LocalStopVisitor>
304 Ptr<const Symbol> x_;
310 bool apply(
const Basic &b)
314 preorder_traversal_local_stop(b, *
this);
318 bool apply(
const RCP<const Basic> &b)
323 void bvisit(
const Basic &x)
328 void bvisit(
const Symbol &x)
336 template <
typename T,
338 = enable_if_t<std::is_base_of<TrigFunction, T>::value
340 void bvisit(
const T &x)
342 is_ = (from_basic<UExprPoly>(x.get_args()[0], (*x_).rcp_from_this())
351 bool is_a_LinearArgTrigEquation(
const Basic &b,
const Symbol &x)
360 RCP<const Set> result_;
362 RCP<const Dummy> nD_;
363 RCP<const Symbol> sym_;
364 RCP<const Set> domain_;
368 RCP<const Symbol> sym, RCP<const Set> domain)
369 : gY_(gY), nD_(nD), sym_(sym), domain_(domain)
373 void bvisit(
const Basic &x)
378 void bvisit(
const Add &x)
382 if (has_symbol(*elem, *sym_)) {
388 auto depX =
add(f1X), indepX =
add(f2X);
389 if (not
eq(*indepX, *zero)) {
390 gY_ = imageset(nD_,
sub(nD_, indepX), gY_);
391 result_ = apply(*depX);
397 void bvisit(
const Mul &x)
401 if (has_symbol(*elem, *sym_)) {
407 auto depX =
mul(f1X), indepX =
mul(f2X);
408 if (not
eq(*indepX, *one)) {
409 if (
eq(*indepX, *NegInf) or
eq(*indepX, *Inf)
410 or
eq(*indepX, *ComplexInf)) {
413 gY_ = imageset(nD_,
div(nD_, indepX), gY_);
414 result_ = apply(*depX);
421 void bvisit(
const Pow &x)
423 if (
eq(*x.
get_base(), *E) and is_a<FiniteSet>(*gY_)) {
425 for (
const auto &elem :
426 down_cast<const FiniteSet &>(*gY_).get_container()) {
427 if (
eq(*elem, *zero))
429 RCP<const Basic> re, im;
430 as_real_imag(elem, outArg(re), outArg(im));
432 auto logarg =
atan2(im, re);
441 gY_ = set_union(inv);
448 RCP<const Set> apply(
const Basic &b)
452 return set_intersection({domain_, result_});
456 RCP<const Set> invertComplex(
const RCP<const Basic> &fX,
457 const RCP<const Set> &gY,
458 const RCP<const Symbol> &sym,
459 const RCP<const Dummy> &nD,
460 const RCP<const Set> &domain)
466 RCP<const Set> solve_trig(
const RCP<const Basic> &f,
467 const RCP<const Symbol> &sym,
468 const RCP<const Set> &domain)
471 auto exp_f = rewrite_as_exp(f);
472 RCP<const Basic> num, den;
473 as_numer_denom(exp_f, outArg(num), outArg(den));
475 auto xD =
dummy(
"x");
477 auto temp =
exp(
mul(I, sym));
483 if (has_symbol(*num, *sym) or has_symbol(*den, *sym)) {
487 auto soln = set_complement(solve(num, xD), solve(den, xD));
490 else if (is_a<FiniteSet>(*soln)) {
498 for (
const auto &elem :
499 down_cast<const FiniteSet &>(*soln).get_container()) {
504 if (not is_a_Set(*ans))
505 throw SymEngineException(
"Expected an object of type Set");
508 return conditionset(sym, logical_and({
Eq(f, zero), domain->contains(sym)}));
511 RCP<const Set> solve(
const RCP<const Basic> &f,
const RCP<const Symbol> &sym,
512 const RCP<const Set> &domain)
514 if (
eq(*f, *boolTrue))
516 if (
eq(*f, *boolFalse))
518 if (is_a<Equality>(*f)) {
519 return solve(
sub(down_cast<const Relational &>(*f).get_arg1(),
520 down_cast<const Relational &>(*f).get_arg2()),
522 }
else if (is_a<Unequality>(*f)) {
523 auto soln = solve(
sub(down_cast<const Relational &>(*f).get_arg1(),
524 down_cast<const Relational &>(*f).get_arg2()),
526 return set_complement(domain, soln);
527 }
else if (is_a_Relational(*f)) {
529 return conditionset(sym, logical_and({rcp_static_cast<const Boolean>(f),
530 domain->contains(sym)}));
541 if (not has_symbol(*f, *sym))
544 if (is_a_LinearArgTrigEquation(*f, *sym)) {
545 return solve_trig(f, sym, domain);
549 auto args = f->get_args();
551 for (
auto &a : args) {
552 solns.
insert(solve(a, sym, domain));
554 return SymEngine::set_union(solns);
557 return solve_rational(f, sym, domain);
560 vec_basic linsolve_helper(
const DenseMatrix &A,
const DenseMatrix &b)
562 DenseMatrix res(A.nrows(), 1);
563 fraction_free_gauss_jordan_solve(A, b, res);
565 for (
unsigned i = 0; i < res.nrows(); i++) {
571 vec_basic linsolve(
const DenseMatrix &system,
const vec_sym &syms)
577 return linsolve_helper(A, b);
580 vec_basic linsolve(
const vec_basic &system,
const vec_sym &syms)
582 auto mat = linear_eqns_to_matrix(system, syms);
583 DenseMatrix A = mat.first, b = mat.second;
584 return linsolve_helper(A, b);
587 set_basic get_set_from_vec(
const vec_sym &syms)
596 linear_eqns_to_matrix(
const vec_basic &equations,
const vec_sym &syms)
598 auto size = numeric_cast<unsigned int>(syms.size());
599 DenseMatrix A(numeric_cast<unsigned int>(equations.size()), size);
604 auto gens = get_set_from_vec(syms);
605 umap_basic_uint index_of_sym;
606 for (
unsigned int i = 0; i < size; i++) {
607 index_of_sym[syms[i]] = i;
609 for (
const auto &eqn : equations) {
611 if (is_a<Equality>(*eqn)) {
612 neqn =
sub(down_cast<const Equality &>(*eqn).get_arg2(),
613 down_cast<const Equality &>(*eqn).get_arg1());
616 auto mpoly = from_basic<MExprPoly>(neqn, gens);
617 RCP<const Basic> rem = zero;
618 for (
const auto &p : mpoly->get_poly().dict_) {
619 RCP<const Basic> res = (p.second.get_basic());
620 int whichvar = 0, non_zero = 0;
621 RCP<const Basic> cursim;
622 for (
auto &sym : gens) {
623 if (0 != p.first[whichvar]) {
626 if (p.first[whichvar] != 1 or non_zero == 2) {
627 throw SymEngineException(
"Expected a linear equation.");
635 A.set(row, index_of_sym[cursim], res);
638 bvec.push_back(
neg(rem));
642 A, DenseMatrix(numeric_cast<unsigned int>(equations.size()), 1, bvec));
The base class for representing addition in symbolic expressions.
vec_basic get_args() const override
Returns the arguments of the Add.
The lowest unit of symbolic representation.
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
RCP< const Basic > get_base() const
Main namespace for SymEngine package.
bool is_a_Number(const Basic &b)
RCP< const Set > interval(const RCP< const Number > &start, const RCP< const Number > &end, const bool left_open=false, const bool right_open=false)
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
RCP< const Dummy > dummy()
inline version to return Dummy
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
RCP< const Symbol > symbol(const std::string &name)
inline version to return Symbol
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
RCP< const EmptySet > emptyset()
RCP< const Basic > atan2(const RCP< const Basic > &num, const RCP< const Basic > &den)
Canonicalize ATan2:
RCP< const Basic > sub(const RCP< const Basic > &a, const RCP< const Basic > &b)
Substracts b from a.
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
RCP< const Set > conditionset(const RCP< const Basic > &sym, const RCP< const Boolean > &condition)
RCP< const Basic > log(const RCP< const Basic > &arg)
Returns the Natural Logarithm from argument arg
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
RCP< const Boolean > Eq(const RCP< const Basic > &lhs)
Returns the canonicalized Equality object from a single argument.
RCP< const Set > finiteset(const set_basic &container)
RCP< const Basic > expand(const RCP< const Basic > &self, bool deep=true)
Expands self
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
T set_intersection(T... args)