Searching...
No Matches
pow.cpp
1#include <symengine/pow.h>
3#include <symengine/complex.h>
4#include <symengine/symengine_exception.h>
5#include <symengine/test_visitors.h>
6
7namespace SymEngine
8{
9
10Pow::Pow(const RCP<const Basic> &base, const RCP<const Basic> &exp)
11 : base_{base}, exp_{exp}
12{
13 SYMENGINE_ASSIGN_TYPEID()
14 SYMENGINE_ASSERT(is_canonical(*base, *exp))
15}
16
17bool Pow::is_canonical(const Basic &base, const Basic &exp) const
18{
19 // e.g. 0**x
20 if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_zero()) {
21 if (is_a_Number(exp)) {
22 return false;
23 } else {
24 return true;
25 }
26 }
27 // e.g. 1**x
28 if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_one())
29 return false;
30 // e.g. x**0.0
32 return false;
33 // e.g. x**1
34 if (is_a<Integer>(exp) and down_cast<const Integer &>(exp).is_one())
35 return false;
36 // e.g. 2**3, (2/3)**4
37 if ((is_a<Integer>(base) or is_a<Rational>(base)) and is_a<Integer>(exp))
38 return false;
39 // e.g. (x*y)**2, should rather be x**2*y**2
40 if (is_a<Mul>(base) and is_a<Integer>(exp))
41 return false;
42 // e.g. (x**y)**2, should rather be x**(2*y)
43 if (is_a<Pow>(base) and is_a<Integer>(exp))
44 return false;
45 // If exp is a rational, it should be between 0 and 1, i.e. we don't
46 // allow things like 2**(-1/2) or 2**(3/2)
47 if ((is_a<Rational>(base) or is_a<Integer>(base)) and is_a<Rational>(exp)
48 and (down_cast<const Rational &>(exp).as_rational_class() < 0
49 or down_cast<const Rational &>(exp).as_rational_class() > 1))
50 return false;
51 // Purely Imaginary complex numbers with integral powers are expanded
52 // e.g (2I)**3
53 if (is_a<Complex>(base) and down_cast<const Complex &>(base).is_re_zero()
54 and is_a<Integer>(exp))
55 return false;
56 // e.g. 0.5^2.0 should be represented as 0.25
57 if (is_a_Number(base) and is_a_Number(exp)
58 and (not down_cast<const Number &>(base).is_exact()
59 or not down_cast<const Number &>(exp).is_exact()))
60 return false;
61 return true;
62}
63
64hash_t Pow::__hash__() const
65{
66 hash_t seed = SYMENGINE_POW;
67 hash_combine<Basic>(seed, *base_);
68 hash_combine<Basic>(seed, *exp_);
69 return seed;
70}
71
72bool Pow::__eq__(const Basic &o) const
73{
74 if (is_a<Pow>(o) and eq(*base_, *(down_cast<const Pow &>(o).base_))
75 and eq(*exp_, *(down_cast<const Pow &>(o).exp_)))
76 return true;
77
78 return false;
79}
80
81int Pow::compare(const Basic &o) const
82{
83 SYMENGINE_ASSERT(is_a<Pow>(o))
84 const Pow &s = down_cast<const Pow &>(o);
85 int base_cmp = base_->__cmp__(*s.base_);
86 if (base_cmp == 0)
87 return exp_->__cmp__(*s.exp_);
88 else
89 return base_cmp;
90}
91
92RCP<const Basic> pow(const RCP<const Basic> &a, const RCP<const Basic> &b)
93{
94 if (is_number_and_zero(*b)) {
95 // addnum is used for converting to the type of b.
97 }
98 if (eq(*b, *one))
99 return a;
100
101 if (eq(*a, *zero)) {
102 if (is_a_Number(*b)
103 and rcp_static_cast<const Number>(b)->is_positive()) {
104 return zero;
105 } else if (is_a_Number(*b)
106 and rcp_static_cast<const Number>(b)->is_negative()) {
107 return ComplexInf;
108 } else {
109 return make_rcp<const Pow>(a, b);
110 }
111 }
112
113 if (eq(*a, *one) and not is_a_Number(*b))
114 return one;
115 if (eq(*a, *minus_one)) {
116 if (is_a<Integer>(*b)) {
117 return is_a<Integer>(*div(b, integer(2))) ? one : minus_one;
118 } else if (is_a<Rational>(*b) and eq(*b, *rational(1, 2))) {
119 return I;
120 }
121 }
122
123 if (is_a_Number(*b)) {
124 if (is_a_Number(*a)) {
125 if (is_a<Integer>(*b)) {
126 return down_cast<const Number &>(*a).pow(
127 *rcp_static_cast<const Number>(b));
128 } else if (is_a<Rational>(*b)) {
129 if (is_a<Rational>(*a)) {
130 return down_cast<const Rational &>(*a).powrat(
131 down_cast<const Rational &>(*b));
132 } else if (is_a<Integer>(*a)) {
133 return down_cast<const Rational &>(*b).rpowrat(
134 down_cast<const Integer &>(*a));
135 } else if (is_a<Complex>(*a)) {
136 return make_rcp<const Pow>(a, b);
137 } else {
138 return down_cast<const Number &>(*a).pow(
139 *rcp_static_cast<const Number>(b));
140 }
141 } else if (is_a<Complex>(*b)
142 and down_cast<const Number &>(*a).is_exact()) {
143 return make_rcp<const Pow>(a, b);
144 } else {
145 return down_cast<const Number &>(*a).pow(
146 *rcp_static_cast<const Number>(b));
147 }
148 } else if (eq(*a, *E)) {
149 RCP<const Number> p = rcp_static_cast<const Number>(b);
150 if (not p->is_exact()) {
151 // Evaluate E**0.2, but not E**2
152 return p->get_eval().exp(*p);
153 }
154 } else if (is_a<Mul>(*a)) {
155 // Expand (x*y)**b = x**b*y**b
157 RCP<const Number> coef = one;
158 down_cast<const Mul &>(*a).power_num(
159 outArg(coef), d, rcp_static_cast<const Number>(b));
160 return Mul::from_dict(coef, std::move(d));
161 }
162 }
163 if (is_a<Pow>(*a) and is_a<Integer>(*b)) {
164 // Convert (x**y)**b = x**(b*y), where 'b' is an integer. This holds for
165 // any complex 'x', 'y' and integer 'b'.
166 RCP<const Pow> A = rcp_static_cast<const Pow>(a);
167 return pow(A->get_base(), mul(A->get_exp(), b));
168 }
169 if (is_a<Pow>(*a)
170 and eq(*down_cast<const Pow &>(*a).get_exp(), *minus_one)) {
171 // Convert (x**-1)**b = x**(-b)
172 RCP<const Pow> A = rcp_static_cast<const Pow>(a);
173 return pow(A->get_base(), neg(b));
174 }
175 return make_rcp<const Pow>(a, b);
176}
177
178// This function can overflow, but it is fast.
179// TODO: figure out condition for (m, n) when it overflows and raise an
180// exception.
181void multinomial_coefficients(unsigned m, unsigned n, map_vec_uint &r)
182{
183 vec_uint t;
184 unsigned j, tj, start, k;
185 unsigned long long int v;
186 if (m < 2)
187 throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
188 t.assign(m, 0);
189 t[0] = n;
190 r[t] = 1;
191 if (n == 0)
192 return;
193 j = 0;
194 while (j < m - 1) {
195 tj = t[j];
196 if (j) {
197 t[j] = 0;
198 t[0] = tj;
199 }
200 if (tj > 1) {
201 t[j + 1] += 1;
202 j = 0;
203 start = 1;
204 v = 0;
205 } else {
206 j += 1;
207 start = j + 1;
208 v = r[t];
209 t[j] += 1;
210 }
211 for (k = start; k < m; k++) {
212 if (t[k]) {
213 t[k] -= 1;
214 v += r[t];
215 t[k] += 1;
216 }
217 }
218 t[0] -= 1;
219 r[t] = (v * tj) / (n - t[0]);
220 }
221}
222
223// Slower, but returns exact (possibly large) integers (as mpz)
224void multinomial_coefficients_mpz(unsigned m, unsigned n, map_vec_mpz &r)
225{
226 vec_uint t;
227 unsigned j, tj, start, k;
228 integer_class v;
229 if (m < 2)
230 throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
231 t.assign(m, 0);
232 t[0] = n;
233 r[t] = 1;
234 if (n == 0)
235 return;
236 j = 0;
237 while (j < m - 1) {
238 tj = t[j];
239 if (j) {
240 t[j] = 0;
241 t[0] = tj;
242 }
243 if (tj > 1) {
244 t[j + 1] += 1;
245 j = 0;
246 start = 1;
247 v = 0;
248 } else {
249 j += 1;
250 start = j + 1;
251 v = r[t];
252 t[j] += 1;
253 }
254 for (k = start; k < m; k++) {
255 if (t[k]) {
256 t[k] -= 1;
257 v += r[t];
258 t[k] += 1;
259 }
260 }
261 t[0] -= 1;
262 r[t] = (v * tj) / (n - t[0]);
263 }
264}
265
267{
268 return {base_, exp_};
269}
270
271RCP<const Basic> exp(const RCP<const Basic> &x)
272{
273 return pow(E, x);
274}
275
276} // namespace SymEngine
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
Definition: basic.h:97
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
hash_t __hash__() const override
Definition: pow.cpp:64
Pow(const RCP< const Basic > &base, const RCP< const Basic > &exp)
Pow Constructor.
Definition: pow.cpp:10
bool __eq__(const Basic &o) const override
Definition: pow.cpp:72
vec_basic get_args() const override
Returns the list of arguments.
Definition: pow.cpp:266
bool is_canonical(const Basic &base, const Basic &exp) const
Definition: pow.cpp:17
int compare(const Basic &o) const override
Definition: pow.cpp:81
T move(T... args)
Main namespace for SymEngine package.
bool is_a_Number(const Basic &b)
Definition: number.h:130
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
Definition: mul.cpp:431
bool is_number_and_zero(const Basic &b)
Definition: number.h:139
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Number > addnum(const RCP< const Number > &self, const RCP< const Number > &other)
Definition: number.h:81
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
Definition: mul.cpp:443