Program Listing for File pow.cpp¶
↰ Return to documentation for file (symengine/symengine/pow.cpp
)
#include <symengine/pow.h>
#include <symengine/add.h>
#include <symengine/complex.h>
#include <symengine/symengine_exception.h>
#include <symengine/test_visitors.h>
namespace SymEngine
{
Pow::Pow(const RCP<const Basic> &base, const RCP<const Basic> &exp)
: base_{base}, exp_{exp}
{
SYMENGINE_ASSIGN_TYPEID()
SYMENGINE_ASSERT(is_canonical(*base, *exp))
}
bool Pow::is_canonical(const Basic &base, const Basic &exp) const
{
// e.g. 0**x
if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_zero()) {
if (is_a_Number(exp)) {
return false;
} else {
return true;
}
}
// e.g. 1**x
if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_one())
return false;
// e.g. x**0.0
if (is_number_and_zero(exp))
return false;
// e.g. x**1
if (is_a<Integer>(exp) and down_cast<const Integer &>(exp).is_one())
return false;
// e.g. 2**3, (2/3)**4
if ((is_a<Integer>(base) or is_a<Rational>(base)) and is_a<Integer>(exp))
return false;
// e.g. (x*y)**2, should rather be x**2*y**2
if (is_a<Mul>(base) and is_a<Integer>(exp))
return false;
// e.g. (x**y)**2, should rather be x**(2*y)
if (is_a<Pow>(base) and is_a<Integer>(exp))
return false;
// If exp is a rational, it should be between 0 and 1, i.e. we don't
// allow things like 2**(-1/2) or 2**(3/2)
if ((is_a<Rational>(base) or is_a<Integer>(base)) and is_a<Rational>(exp)
and (down_cast<const Rational &>(exp).as_rational_class() < 0
or down_cast<const Rational &>(exp).as_rational_class() > 1))
return false;
// Purely Imaginary complex numbers with integral powers are expanded
// e.g (2I)**3
if (is_a<Complex>(base) and down_cast<const Complex &>(base).is_re_zero()
and is_a<Integer>(exp))
return false;
// e.g. 0.5^2.0 should be represented as 0.25
if (is_a_Number(base) and not down_cast<const Number &>(base).is_exact()
and is_a_Number(exp) and not down_cast<const Number &>(exp).is_exact())
return false;
return true;
}
hash_t Pow::__hash__() const
{
hash_t seed = SYMENGINE_POW;
hash_combine<Basic>(seed, *base_);
hash_combine<Basic>(seed, *exp_);
return seed;
}
bool Pow::__eq__(const Basic &o) const
{
if (is_a<Pow>(o) and eq(*base_, *(down_cast<const Pow &>(o).base_))
and eq(*exp_, *(down_cast<const Pow &>(o).exp_)))
return true;
return false;
}
int Pow::compare(const Basic &o) const
{
SYMENGINE_ASSERT(is_a<Pow>(o))
const Pow &s = down_cast<const Pow &>(o);
int base_cmp = base_->__cmp__(*s.base_);
if (base_cmp == 0)
return exp_->__cmp__(*s.exp_);
else
return base_cmp;
}
RCP<const Basic> pow(const RCP<const Basic> &a, const RCP<const Basic> &b)
{
if (is_number_and_zero(*b)) {
// addnum is used for converting to the type of `b`.
return addnum(one, rcp_static_cast<const Number>(b));
}
if (eq(*b, *one))
return a;
if (eq(*a, *zero)) {
if (is_a_Number(*b)
and rcp_static_cast<const Number>(b)->is_positive()) {
return zero;
} else if (is_a_Number(*b)
and rcp_static_cast<const Number>(b)->is_negative()) {
return ComplexInf;
} else {
return make_rcp<const Pow>(a, b);
}
}
if (eq(*a, *one) and not is_a_Number(*b))
return one;
if (eq(*a, *minus_one)) {
if (is_a<Integer>(*b)) {
return is_a<Integer>(*div(b, integer(2))) ? one : minus_one;
} else if (is_a<Rational>(*b) and eq(*b, *rational(1, 2))) {
return I;
}
}
if (is_a_Number(*b)) {
if (is_a_Number(*a)) {
if (is_a<Integer>(*b)) {
return down_cast<const Number &>(*a).pow(
*rcp_static_cast<const Number>(b));
} else if (is_a<Rational>(*b)) {
if (is_a<Rational>(*a)) {
return down_cast<const Rational &>(*a).powrat(
down_cast<const Rational &>(*b));
} else if (is_a<Integer>(*a)) {
return down_cast<const Rational &>(*b).rpowrat(
down_cast<const Integer &>(*a));
} else if (is_a<Complex>(*a)) {
return make_rcp<const Pow>(a, b);
} else {
return down_cast<const Number &>(*a).pow(
*rcp_static_cast<const Number>(b));
}
} else if (is_a<Complex>(*b)
and down_cast<const Number &>(*a).is_exact()) {
return make_rcp<const Pow>(a, b);
} else {
return down_cast<const Number &>(*a).pow(
*rcp_static_cast<const Number>(b));
}
} else if (eq(*a, *E)) {
RCP<const Number> p = rcp_static_cast<const Number>(b);
if (not p->is_exact()) {
// Evaluate E**0.2, but not E**2
return p->get_eval().exp(*p);
}
} else if (is_a<Mul>(*a)) {
// Expand (x*y)**b = x**b*y**b
map_basic_basic d;
RCP<const Number> coef = one;
down_cast<const Mul &>(*a).power_num(
outArg(coef), d, rcp_static_cast<const Number>(b));
return Mul::from_dict(coef, std::move(d));
}
}
if (is_a<Pow>(*a) and is_a<Integer>(*b)) {
// Convert (x**y)**b = x**(b*y), where 'b' is an integer. This holds for
// any complex 'x', 'y' and integer 'b'.
RCP<const Pow> A = rcp_static_cast<const Pow>(a);
return pow(A->get_base(), mul(A->get_exp(), b));
}
if (is_a<Pow>(*a)
and eq(*down_cast<const Pow &>(*a).get_exp(), *minus_one)) {
// Convert (x**-1)**b = x**(-b)
RCP<const Pow> A = rcp_static_cast<const Pow>(a);
return pow(A->get_base(), neg(b));
}
return make_rcp<const Pow>(a, b);
}
// This function can overflow, but it is fast.
// TODO: figure out condition for (m, n) when it overflows and raise an
// exception.
void multinomial_coefficients(unsigned m, unsigned n, map_vec_uint &r)
{
vec_uint t;
unsigned j, tj, start, k;
unsigned long long int v;
if (m < 2)
throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
t.assign(m, 0);
t[0] = n;
r[t] = 1;
if (n == 0)
return;
j = 0;
while (j < m - 1) {
tj = t[j];
if (j) {
t[j] = 0;
t[0] = tj;
}
if (tj > 1) {
t[j + 1] += 1;
j = 0;
start = 1;
v = 0;
} else {
j += 1;
start = j + 1;
v = r[t];
t[j] += 1;
}
for (k = start; k < m; k++) {
if (t[k]) {
t[k] -= 1;
v += r[t];
t[k] += 1;
}
}
t[0] -= 1;
r[t] = (v * tj) / (n - t[0]);
}
}
// Slower, but returns exact (possibly large) integers (as mpz)
void multinomial_coefficients_mpz(unsigned m, unsigned n, map_vec_mpz &r)
{
vec_uint t;
unsigned j, tj, start, k;
integer_class v;
if (m < 2)
throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
t.assign(m, 0);
t[0] = n;
r[t] = 1;
if (n == 0)
return;
j = 0;
while (j < m - 1) {
tj = t[j];
if (j) {
t[j] = 0;
t[0] = tj;
}
if (tj > 1) {
t[j + 1] += 1;
j = 0;
start = 1;
v = 0;
} else {
j += 1;
start = j + 1;
v = r[t];
t[j] += 1;
}
for (k = start; k < m; k++) {
if (t[k]) {
t[k] -= 1;
v += r[t];
t[k] += 1;
}
}
t[0] -= 1;
r[t] = (v * tj) / (n - t[0]);
}
}
vec_basic Pow::get_args() const
{
return {base_, exp_};
}
RCP<const Basic> exp(const RCP<const Basic> &x)
{
return pow(E, x);
}
} // SymEngine