pow.cpp
1 #include <symengine/pow.h>
2 #include <symengine/add.h>
3 #include <symengine/complex.h>
4 #include <symengine/symengine_exception.h>
5 #include <symengine/test_visitors.h>
6 
7 namespace SymEngine
8 {
9 
10 Pow::Pow(const RCP<const Basic> &base, const RCP<const Basic> &exp)
11  : base_{base}, exp_{exp}
12 {
13  SYMENGINE_ASSIGN_TYPEID()
14  SYMENGINE_ASSERT(is_canonical(*base, *exp))
15 }
16 
17 bool Pow::is_canonical(const Basic &base, const Basic &exp) const
18 {
19  // e.g. 0**x
20  if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_zero()) {
21  if (is_a_Number(exp)) {
22  return false;
23  } else {
24  return true;
25  }
26  }
27  // e.g. 1**x
28  if (is_a<Integer>(base) and down_cast<const Integer &>(base).is_one())
29  return false;
30  // e.g. x**0.0
32  return false;
33  // e.g. x**1
34  if (is_a<Integer>(exp) and down_cast<const Integer &>(exp).is_one())
35  return false;
36  // e.g. 2**3, (2/3)**4
37  if ((is_a<Integer>(base) or is_a<Rational>(base)) and is_a<Integer>(exp))
38  return false;
39  // e.g. (x*y)**2, should rather be x**2*y**2
40  if (is_a<Mul>(base) and is_a<Integer>(exp))
41  return false;
42  // e.g. (x**y)**2, should rather be x**(2*y)
43  if (is_a<Pow>(base) and is_a<Integer>(exp))
44  return false;
45  // If exp is a rational, it should be between 0 and 1, i.e. we don't
46  // allow things like 2**(-1/2) or 2**(3/2)
47  if ((is_a<Rational>(base) or is_a<Integer>(base)) and is_a<Rational>(exp)
48  and (down_cast<const Rational &>(exp).as_rational_class() < 0
49  or down_cast<const Rational &>(exp).as_rational_class() > 1))
50  return false;
51  // Purely Imaginary complex numbers with integral powers are expanded
52  // e.g (2I)**3
53  if (is_a<Complex>(base) and down_cast<const Complex &>(base).is_re_zero()
54  and is_a<Integer>(exp))
55  return false;
56  // e.g. 0.5^2.0 should be represented as 0.25
57  if (is_a_Number(base) and is_a_Number(exp)
58  and (not down_cast<const Number &>(base).is_exact()
59  or not down_cast<const Number &>(exp).is_exact()))
60  return false;
61  return true;
62 }
63 
64 hash_t Pow::__hash__() const
65 {
66  hash_t seed = SYMENGINE_POW;
67  hash_combine<Basic>(seed, *base_);
68  hash_combine<Basic>(seed, *exp_);
69  return seed;
70 }
71 
72 bool Pow::__eq__(const Basic &o) const
73 {
74  if (is_a<Pow>(o) and eq(*base_, *(down_cast<const Pow &>(o).base_))
75  and eq(*exp_, *(down_cast<const Pow &>(o).exp_)))
76  return true;
77 
78  return false;
79 }
80 
81 int Pow::compare(const Basic &o) const
82 {
83  SYMENGINE_ASSERT(is_a<Pow>(o))
84  const Pow &s = down_cast<const Pow &>(o);
85  int base_cmp = base_->__cmp__(*s.base_);
86  if (base_cmp == 0)
87  return exp_->__cmp__(*s.exp_);
88  else
89  return base_cmp;
90 }
91 
/// Construct `a**b`, applying the automatic simplifications below in order.
///
/// A new `Pow` node is created only when no rule produces a simpler result,
/// or when a rule deliberately keeps the power unevaluated.
///
/// @param a the base
/// @param b the exponent
/// @return the simplified power, or a `Pow(a, b)` node
RCP<const Basic> pow(const RCP<const Basic> &a, const RCP<const Basic> &b)
{
    // x**0 -> 1 for any numeric zero exponent.
    if (is_number_and_zero(*b)) {
        // addnum is used for converting to the type of `b`.
        return addnum(one, rcp_static_cast<const Number>(b));
    }
    // x**1 -> x
    if (eq(*b, *one))
        return a;

    // Base zero: 0**(positive number) -> 0, 0**(negative number) ->
    // ComplexInf; any other exponent leaves the power unevaluated.
    if (eq(*a, *zero)) {
        if (is_a_Number(*b)
            and rcp_static_cast<const Number>(b)->is_positive()) {
            return zero;
        } else if (is_a_Number(*b)
                   and rcp_static_cast<const Number>(b)->is_negative()) {
            return ComplexInf;
        } else {
            return make_rcp<const Pow>(a, b);
        }
    }

    // 1**x -> 1 for a non-numeric exponent; numeric exponents are handled
    // by the Number branch further down.
    if (eq(*a, *one) and not is_a_Number(*b))
        return one;
    // (-1)**n for integer n: 1 when n/2 is still an Integer (n even),
    // otherwise -1. (-1)**(1/2) -> I. Other exponents fall through.
    if (eq(*a, *minus_one)) {
        if (is_a<Integer>(*b)) {
            return is_a<Integer>(*div(b, integer(2))) ? one : minus_one;
        } else if (is_a<Rational>(*b) and eq(*b, *rational(1, 2))) {
            return I;
        }
    }

    if (is_a_Number(*b)) {
        if (is_a_Number(*a)) {
            // Both operands numeric: evaluate through the Number interface,
            // except for the combinations kept unevaluated below.
            if (is_a<Integer>(*b)) {
                return down_cast<const Number &>(*a).pow(
                    *rcp_static_cast<const Number>(b));
            } else if (is_a<Rational>(*b)) {
                // Rational exponent: dispatch on the type of the base.
                if (is_a<Rational>(*a)) {
                    return down_cast<const Rational &>(*a).powrat(
                        down_cast<const Rational &>(*b));
                } else if (is_a<Integer>(*a)) {
                    // Integer**Rational is delegated to the exponent.
                    return down_cast<const Rational &>(*b).rpowrat(
                        down_cast<const Integer &>(*a));
                } else if (is_a<Complex>(*a)) {
                    // Complex**Rational is kept unevaluated.
                    return make_rcp<const Pow>(a, b);
                } else {
                    // NOTE(review): presumably an inexact numeric base
                    // (e.g. a floating-point type) — confirm.
                    return down_cast<const Number &>(*a).pow(
                        *rcp_static_cast<const Number>(b));
                }
            } else if (is_a<Complex>(*b)
                       and down_cast<const Number &>(*a).is_exact()) {
                // Exact base with a Complex exponent is kept unevaluated.
                return make_rcp<const Pow>(a, b);
            } else {
                return down_cast<const Number &>(*a).pow(
                    *rcp_static_cast<const Number>(b));
            }
        } else if (eq(*a, *E)) {
            RCP<const Number> p = rcp_static_cast<const Number>(b);
            if (not p->is_exact()) {
                // Evaluate E**0.2, but not E**2
                return p->get_eval().exp(*p);
            }
            // Exact exponents fall through and stay unevaluated.
        } else if (is_a<Mul>(*a)) {
            // Expand (x*y)**b = x**b*y**b
            map_basic_basic d;
            RCP<const Number> coef = one;
            down_cast<const Mul &>(*a).power_num(
                outArg(coef), d, rcp_static_cast<const Number>(b));
            return Mul::from_dict(coef, std::move(d));
        }
    }
    if (is_a<Pow>(*a) and is_a<Integer>(*b)) {
        // Convert (x**y)**b = x**(b*y), where 'b' is an integer. This holds for
        // any complex 'x', 'y' and integer 'b'.
        RCP<const Pow> A = rcp_static_cast<const Pow>(a);
        return pow(A->get_base(), mul(A->get_exp(), b));
    }
    if (is_a<Pow>(*a)
        and eq(*down_cast<const Pow &>(*a).get_exp(), *minus_one)) {
        // Convert (x**-1)**b = x**(-b)
        // Integer b was already handled above, so b is non-integer here.
        // NOTE(review): with principal branches this identity can fail for
        // negative real x and non-integer b; confirm the rewrite is intended.
        RCP<const Pow> A = rcp_static_cast<const Pow>(a);
        return pow(A->get_base(), neg(b));
    }
    // No simplification applies: keep the power as an unevaluated node.
    return make_rcp<const Pow>(a, b);
}
177 
// This function can overflow, but it is fast.
// TODO: figure out condition for (m, n) when it overflows and raise an
// exception.
//
// Fill `r` with the multinomial coefficients for m terms raised to power n,
// i.e. the coefficients in the expansion of (x_1 + ... + x_m)**n. Each key is
// an exponent vector of length m; each value is its coefficient, accumulated
// in `unsigned long long` (hence the overflow warning above).
//
// @param m number of terms; must be >= 2
// @param n the power
// @param r output map (exponent vector -> coefficient); entries are written
//          into it
// @throws SymEngineException if m < 2
void multinomial_coefficients(unsigned m, unsigned n, map_vec_uint &r)
{
    vec_uint t;
    unsigned j, tj, start, k;
    unsigned long long int v;
    if (m < 2)
        throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
    // Start from the exponent vector (n, 0, ..., 0), whose coefficient is 1.
    t.assign(m, 0);
    t[0] = n;
    r[t] = 1;
    if (n == 0)
        return;
    // Each iteration moves `t` to the next exponent vector and computes its
    // coefficient from entries already stored in `r`; the loop ends once the
    // position index `j` reaches the last component.
    j = 0;
    while (j < m - 1) {
        tj = t[j];
        if (j) {
            t[j] = 0;
            t[0] = tj;
        }
        if (tj > 1) {
            t[j + 1] += 1;
            j = 0;
            start = 1;
            v = 0;
        } else {
            j += 1;
            start = j + 1;
            v = r[t];
            t[j] += 1;
        }
        // Add coefficients of the vectors obtained by decrementing one
        // nonzero component k >= start (t is restored after each lookup).
        for (k = start; k < m; k++) {
            if (t[k]) {
                t[k] -= 1;
                v += r[t];
                t[k] += 1;
            }
        }
        t[0] -= 1;
        // v * tj is computed in unsigned long long and may overflow.
        r[t] = (v * tj) / (n - t[0]);
    }
}
222 
// Slower, but returns exact (possibly large) integers (as mpz)
//
// Same recurrence as multinomial_coefficients(), but the accumulator `v` is
// an arbitrary-precision `integer_class`, so the coefficients stored in `r`
// are exact.
//
// @param m number of terms; must be >= 2
// @param n the power
// @param r output map (exponent vector -> coefficient); entries are written
//          into it
// @throws SymEngineException if m < 2
void multinomial_coefficients_mpz(unsigned m, unsigned n, map_vec_mpz &r)
{
    vec_uint t;
    unsigned j, tj, start, k;
    integer_class v;
    if (m < 2)
        throw SymEngineException("multinomial_coefficients: m >= 2 must hold.");
    // Start from the exponent vector (n, 0, ..., 0), whose coefficient is 1.
    t.assign(m, 0);
    t[0] = n;
    r[t] = 1;
    if (n == 0)
        return;
    // Each iteration moves `t` to the next exponent vector and computes its
    // coefficient from entries already stored in `r`; the loop ends once the
    // position index `j` reaches the last component.
    j = 0;
    while (j < m - 1) {
        tj = t[j];
        if (j) {
            t[j] = 0;
            t[0] = tj;
        }
        if (tj > 1) {
            t[j + 1] += 1;
            j = 0;
            start = 1;
            v = 0;
        } else {
            j += 1;
            start = j + 1;
            v = r[t];
            t[j] += 1;
        }
        // Add coefficients of the vectors obtained by decrementing one
        // nonzero component k >= start (t is restored after each lookup).
        for (k = start; k < m; k++) {
            if (t[k]) {
                t[k] -= 1;
                v += r[t];
                t[k] += 1;
            }
        }
        t[0] -= 1;
        r[t] = (v * tj) / (n - t[0]);
    }
}
265 
267 {
268  return {base_, exp_};
269 }
270 
271 RCP<const Basic> exp(const RCP<const Basic> &x)
272 {
273  return pow(E, x);
274 }
275 
276 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
The lowest unit of symbolic representation.
Definition: basic.h:97
static RCP< const Basic > from_dict(const RCP< const Number > &coef, map_basic_basic &&d)
Create a Mul from a dict.
Definition: mul.cpp:115
hash_t __hash__() const override
Definition: pow.cpp:64
Pow(const RCP< const Basic > &base, const RCP< const Basic > &exp)
Pow Constructor.
Definition: pow.cpp:10
bool __eq__(const Basic &o) const override
Definition: pow.cpp:72
vec_basic get_args() const override
Returns the list of arguments.
Definition: pow.cpp:266
bool is_canonical(const Basic &base, const Basic &exp) const
Definition: pow.cpp:17
int compare(const Basic &o) const override
Definition: pow.cpp:81
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
RCP< const Basic > div(const RCP< const Basic > &a, const RCP< const Basic > &b)
Division.
Definition: mul.cpp:431
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
bool is_number_and_zero(const Basic &b)
Definition: number.h:139
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
RCP< const Number > rational(long n, long d)
convenience creator from two longs
Definition: rational.h:328
RCP< const Number > addnum(const RCP< const Number > &self, const RCP< const Number > &other)
Add self and other
Definition: number.h:81
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
Definition: mul.cpp:443