msymenginepoly.h
1 #ifndef SYMENGINE_POLYNOMIALS_MULTIVARIATE
2 #define SYMENGINE_POLYNOMIALS_MULTIVARIATE
3 
4 #include <symengine/expression.h>
5 #include <symengine/monomials.h>
7 #include <symengine/polys/uexprpoly.h>
8 #include <symengine/symengine_casts.h>
9 
10 namespace SymEngine
11 {
12 
13 template <typename Vec, typename Value, typename Wrapper>
15 {
16 public:
18  Dict dict_;
19  unsigned int vec_size;
20 
21  typedef Vec vec_type;
22  typedef Value coef_type;
23  typedef Dict dict_type;
24 
25  UDictWrapper(unsigned int s) SYMENGINE_NOEXCEPT
26  {
27  vec_size = s;
28  }
29 
30  UDictWrapper() SYMENGINE_NOEXCEPT {}
31 
32  ~UDictWrapper() SYMENGINE_NOEXCEPT {}
33 
34  UDictWrapper(Dict &&p, unsigned int sz)
35  {
36  auto iter = p.begin();
37  while (iter != p.end()) {
38  if (iter->second == 0) {
39  auto toErase = iter;
40  iter++;
41  p.erase(toErase);
42  } else {
43  iter++;
44  }
45  }
46 
47  dict_ = p;
48  vec_size = sz;
49  }
50 
51  UDictWrapper(const Dict &p, unsigned int sz)
52  {
53  for (auto &iter : p) {
54  if (iter.second != Value(0))
55  dict_[iter.first] = iter.second;
56  }
57  vec_size = sz;
58  }
59 
60  Wrapper &operator=(Wrapper &&other)
61  {
62  if (this != &other)
63  dict_ = std::move(other.dict_);
64  return static_cast<Wrapper &>(*this);
65  }
66 
67  friend Wrapper operator+(const Wrapper &a, const Wrapper &b)
68  {
69  SYMENGINE_ASSERT(a.vec_size == b.vec_size)
70  Wrapper c = a;
71  c += b;
72  return c;
73  }
74 
75  // both wrappers must have "aligned" vectors, ie same size
76  // and vector positions refer to the same generators
77  Wrapper &operator+=(const Wrapper &other)
78  {
79  SYMENGINE_ASSERT(vec_size == other.vec_size)
80 
81  for (auto &iter : other.dict_) {
82  auto t = dict_.find(iter.first);
83  if (t != dict_.end()) {
84  t->second += iter.second;
85  if (t->second == 0)
86  dict_.erase(t);
87  } else {
88  dict_.insert(t, {iter.first, iter.second});
89  }
90  }
91  return static_cast<Wrapper &>(*this);
92  }
93 
94  friend Wrapper operator-(const Wrapper &a, const Wrapper &b)
95  {
96  SYMENGINE_ASSERT(a.vec_size == b.vec_size)
97 
98  Wrapper c = a;
99  c -= b;
100  return c;
101  }
102 
103  Wrapper operator-() const
104  {
105  auto c = *this;
106  for (auto &iter : c.dict_)
107  iter.second *= -1;
108  return static_cast<Wrapper &>(c);
109  }
110 
111  // both wrappers must have "aligned" vectors, ie same size
112  // and vector positions refer to the same generators
113  Wrapper &operator-=(const Wrapper &other)
114  {
115  SYMENGINE_ASSERT(vec_size == other.vec_size)
116 
117  for (auto &iter : other.dict_) {
118  auto t = dict_.find(iter.first);
119  if (t != dict_.end()) {
120  t->second -= iter.second;
121  if (t->second == 0)
122  dict_.erase(t);
123  } else {
124  dict_.insert(t, {iter.first, -iter.second});
125  }
126  }
127  return static_cast<Wrapper &>(*this);
128  }
129 
130  static Wrapper mul(const Wrapper &a, const Wrapper &b)
131  {
132  SYMENGINE_ASSERT(a.vec_size == b.vec_size)
133 
134  Wrapper p(a.vec_size);
135  for (auto const &a_ : a.dict_) {
136  for (auto const &b_ : b.dict_) {
137 
138  Vec target(a.vec_size, 0);
139  for (unsigned int i = 0; i < a.vec_size; i++)
140  target[i] = a_.first[i] + b_.first[i];
141 
142  if (p.dict_.find(target) == p.dict_.end()) {
143  p.dict_.insert({target, a_.second * b_.second});
144  } else {
145  p.dict_.find(target)->second += a_.second * b_.second;
146  }
147  }
148  }
149 
150  for (auto it = p.dict_.begin(); it != p.dict_.end();) {
151  if (it->second == 0) {
152  p.dict_.erase(it++);
153  } else {
154  ++it;
155  }
156  }
157  return p;
158  }
159 
160  static Wrapper pow(const Wrapper &a, unsigned int p)
161  {
162  Wrapper tmp = a, res(a.vec_size);
163 
164  Vec zero_v(a.vec_size, 0);
165  res.dict_[zero_v] = 1_z;
166 
167  while (p != 1) {
168  if (p % 2 == 0) {
169  tmp = tmp * tmp;
170  } else {
171  res = res * tmp;
172  tmp = tmp * tmp;
173  }
174  p >>= 1;
175  }
176 
177  return (res * tmp);
178  }
179 
180  friend Wrapper operator*(const Wrapper &a, const Wrapper &b)
181  {
182  SYMENGINE_ASSERT(a.vec_size == b.vec_size)
183  return Wrapper::mul(a, b);
184  }
185 
186  Wrapper &operator*=(const Wrapper &other)
187  {
188  SYMENGINE_ASSERT(vec_size == other.vec_size)
189 
190  if (dict_.empty())
191  return static_cast<Wrapper &>(*this);
192 
193  if (other.dict_.empty()) {
194  dict_.clear();
195  return static_cast<Wrapper &>(*this);
196  }
197 
198  Vec zero_v(vec_size, 0);
199  // ! other is a just constant term
200  if (other.dict_.size() == 1
201  and other.dict_.find(zero_v) != other.dict_.end()) {
202  auto t = other.dict_.begin();
203  for (auto &i1 : dict_)
204  i1.second *= t->second;
205  return static_cast<Wrapper &>(*this);
206  }
207 
208  Wrapper res = Wrapper::mul(static_cast<Wrapper &>(*this), other);
209  res.dict_.swap(this->dict_);
210  return static_cast<Wrapper &>(*this);
211  }
212 
213  bool operator==(const Wrapper &other) const
214  {
215  return dict_ == other.dict_;
216  }
217 
218  bool operator!=(const Wrapper &other) const
219  {
220  return not(*this == other);
221  }
222 
223  const Dict &get_dict() const
224  {
225  return dict_;
226  }
227 
228  bool empty() const
229  {
230  return dict_.empty();
231  }
232 
233  Value get_coeff(Vec &x) const
234  {
235  auto ite = dict_.find(x);
236  if (ite != dict_.end())
237  return ite->second;
238  return Value(0);
239  }
240 
241  Wrapper translate(const vec_uint &translator, unsigned int size) const
242  {
243  SYMENGINE_ASSERT(translator.size() == vec_size)
244  SYMENGINE_ASSERT(size >= vec_size)
245 
246  Dict d;
247 
248  for (auto it : dict_) {
249  Vec changed;
250  changed.resize(size, 0);
251  for (unsigned int i = 0; i < vec_size; i++)
252  changed[translator[i]] = it.first[i];
253  d.insert({changed, it.second});
254  }
255 
256  return Wrapper(std::move(d), size);
257  }
258 };
259 
260 class MIntDict : public UDictWrapper<vec_uint, integer_class, MIntDict>
261 {
262 public:
263  MIntDict(unsigned int s) SYMENGINE_NOEXCEPT : UDictWrapper(s) {}
264 
265  MIntDict() SYMENGINE_NOEXCEPT {}
266 
267  ~MIntDict() SYMENGINE_NOEXCEPT {}
268 
269  MIntDict(MIntDict &&other) SYMENGINE_NOEXCEPT
270  : UDictWrapper(std::move(other))
271  {
272  }
273 
274  MIntDict(umap_uvec_mpz &&p, unsigned int sz)
275  : UDictWrapper(std::move(p), sz)
276  {
277  }
278 
279  MIntDict(const umap_uvec_mpz &p, unsigned int sz) : UDictWrapper(p, sz) {}
280 
281  MIntDict(const MIntDict &) = default;
282 
283  MIntDict &operator=(const MIntDict &) = default;
284 };
285 
286 class MExprDict : public UDictWrapper<vec_int, Expression, MExprDict>
287 {
288 public:
289  MExprDict(unsigned int s) SYMENGINE_NOEXCEPT : UDictWrapper(s) {}
290 
291  MExprDict() SYMENGINE_NOEXCEPT {}
292 
293  ~MExprDict() SYMENGINE_NOEXCEPT {}
294 
295  MExprDict(MExprDict &&other) SYMENGINE_NOEXCEPT
296  : UDictWrapper(std::move(other))
297  {
298  }
299 
300  MExprDict(umap_vec_expr &&p, unsigned int sz)
301  : UDictWrapper(std::move(p), sz)
302  {
303  }
304 
305  MExprDict(const umap_vec_expr &p, unsigned int sz) : UDictWrapper(p, sz) {}
306 
307  MExprDict(const MExprDict &) = default;
308 
309  MExprDict &operator=(const MExprDict &) = default;
310 };
311 
312 template <typename Container, typename Poly>
313 class MSymEnginePoly : public Basic
314 {
315 private:
316  Container poly_;
317  set_basic vars_;
318 
319 public:
320  typedef Container container_type;
321  typedef typename Container::coef_type coef_type;
322 
323  MSymEnginePoly(const set_basic &vars, Container &&dict)
324  : poly_{dict}, vars_{vars}
325  {
326  }
327 
328  static RCP<const Poly> from_container(const set_basic &vars, Container &&d)
329  {
330  return make_rcp<const Poly>(vars, std::move(d));
331  }
332 
333  int compare(const Basic &o) const override
334  {
335  SYMENGINE_ASSERT(is_a<Poly>(o))
336 
337  const Poly &s = down_cast<const Poly &>(o);
338 
339  if (vars_.size() != s.vars_.size())
340  return vars_.size() < s.vars_.size() ? -1 : 1;
341  if (poly_.dict_.size() != s.poly_.dict_.size())
342  return poly_.dict_.size() < s.poly_.dict_.size() ? -1 : 1;
343 
344  int cmp = unified_compare(vars_, s.vars_);
345  if (cmp != 0)
346  return cmp;
347 
348  return unified_compare(poly_.dict_, s.poly_.dict_);
349  }
350 
351  template <typename FromPoly>
352  static enable_if_t<is_a_UPoly<FromPoly>::value, RCP<const Poly>>
353  from_poly(const FromPoly &p)
354  {
355  Container c;
356  for (auto it = p.begin(); it != p.end(); ++it)
357  c.dict_[{it->first}] = it->second;
358  c.vec_size = 1;
359 
360  return Poly::from_container({p.get_var()}, std::move(c));
361  }
362 
363  static RCP<const Poly> from_dict(const vec_basic &v,
364  typename Container::dict_type &&d)
365  {
366  set_basic s;
367  std::map<RCP<const Basic>, unsigned int, RCPBasicKeyLess> m;
368  // Symbols in the vector are sorted by placeing them in an map image
369  // of the symbols in the map is their original location in the vector
370 
371  for (unsigned int i = 0; i < v.size(); i++) {
372  m.insert({v[i], i});
373  s.insert(v[i]);
374  }
375 
376  // vec_uint translator represents the permutation of the exponents
377  vec_uint trans(s.size());
378  auto mptr = m.begin();
379  for (unsigned int i = 0; i < s.size(); i++) {
380  trans[mptr->second] = i;
381  mptr++;
382  }
383 
384  Container x(std::move(d), numeric_cast<unsigned>(s.size()));
385  return Poly::from_container(
386  s, std::move(x.translate(trans, numeric_cast<unsigned>(s.size()))));
387  }
388 
389  static Container container_from_dict(const set_basic &s,
390  typename Container::dict_type &&d)
391  {
392  return Container(std::move(d), numeric_cast<unsigned>(s.size()));
393  }
394 
395  inline vec_basic get_args() const override
396  {
397  return {};
398  }
399 
400  inline const Container &get_poly() const
401  {
402  return poly_;
403  }
404 
405  inline const set_basic &get_vars() const
406  {
407  return vars_;
408  }
409 
410  bool __eq__(const Basic &o) const override
411  {
412  // TODO : fix for when vars are different, but there is an intersection
413  if (not is_a<Poly>(o))
414  return false;
415  const Poly &o_ = down_cast<const Poly &>(o);
416  // compare constants without regards to vars
417  if (1 == poly_.dict_.size() && 1 == o_.poly_.dict_.size()) {
418  if (poly_.dict_.begin()->second != o_.poly_.dict_.begin()->second)
419  return false;
420  if (poly_.dict_.begin()->first == o_.poly_.dict_.begin()->first
421  && unified_eq(vars_, o_.vars_))
422  return true;
423  typename Container::vec_type v1, v2;
424  v1.resize(vars_.size(), 0);
425  v2.resize(o_.vars_.size(), 0);
426  if (poly_.dict_.begin()->first == v1
427  || o_.poly_.dict_.begin()->first == v2)
428  return true;
429  return false;
430  } else if (0 == poly_.dict_.size() && 0 == o_.poly_.dict_.size()) {
431  return true;
432  } else {
433  return (unified_eq(vars_, o_.vars_)
434  && unified_eq(poly_.dict_, o_.poly_.dict_));
435  }
436  }
437 };
438 
439 class MIntPoly : public MSymEnginePoly<MIntDict, MIntPoly>
440 {
441 public:
442  MIntPoly(const set_basic &vars, MIntDict &&dict)
443  : MSymEnginePoly(vars, std::move(dict)){SYMENGINE_ASSIGN_TYPEID()}
444 
445  IMPLEMENT_TYPEID(SYMENGINE_MINTPOLY)
446 
447  hash_t __hash__() const override;
448  RCP<const Basic> as_symbolic() const;
449 
450  integer_class eval(
451  std::map<RCP<const Basic>, integer_class, RCPBasicKeyLess> &vals) const;
452 };
453 
454 class MExprPoly : public MSymEnginePoly<MExprDict, MExprPoly>
455 {
456 public:
457  MExprPoly(const set_basic &vars, MExprDict &&dict)
458  : MSymEnginePoly(vars, std::move(dict)){SYMENGINE_ASSIGN_TYPEID()}
459 
460  IMPLEMENT_TYPEID(SYMENGINE_MEXPRPOLY)
461 
462  hash_t __hash__() const override;
463  RCP<const Basic> as_symbolic() const;
464  Expression
465  eval(std::map<RCP<const Basic>, Expression, RCPBasicKeyLess> &vals) const;
466 };
467 
468 // reconciles the positioning of the exponents in the vectors in the
469 // Dict dict_ of the arguments with the positioning of the exponents in
470 // the correspondng vectors of the output of the function. f1 and f2 are
471 // vectors whose indices are the positions in the arguments and whose values
472 // are the positions in the output. set_basic s is the set of symbols of
473 // the output, and s1 and s2 are the sets of the symbols of the inputs.
474 unsigned int reconcile(vec_uint &v1, vec_uint &v2, set_basic &s,
475  const set_basic &s1, const set_basic &s2);
476 
477 template <typename Poly, typename Container>
478 set_basic get_translated_container(Container &x, Container &y, const Poly &a,
479  const Poly &b)
480 {
481  vec_uint v1, v2;
482  set_basic s;
483 
484  unsigned int sz = reconcile(v1, v2, s, a.get_vars(), b.get_vars());
485  x = a.get_poly().translate(v1, sz);
486  y = b.get_poly().translate(v2, sz);
487 
488  return s;
489 }
490 
491 template <typename Poly>
492 RCP<const Poly> add_mpoly(const Poly &a, const Poly &b)
493 {
494  typename Poly::container_type x, y;
495  set_basic s = get_translated_container(x, y, a, b);
496  x += y;
497  return Poly::from_container(s, std::move(x));
498 }
499 
500 template <typename Poly>
501 RCP<const Poly> sub_mpoly(const Poly &a, const Poly &b)
502 {
503  typename Poly::container_type x, y;
504  set_basic s = get_translated_container(x, y, a, b);
505  x -= y;
506  return Poly::from_container(s, std::move(x));
507 }
508 
509 template <typename Poly>
510 RCP<const Poly> mul_mpoly(const Poly &a, const Poly &b)
511 {
512  typename Poly::container_type x, y;
513  set_basic s = get_translated_container(x, y, a, b);
514  x *= y;
515  return Poly::from_container(s, std::move(x));
516 }
517 
518 template <typename Poly>
519 RCP<const Poly> neg_mpoly(const Poly &a)
520 {
521  auto x = a.get_poly();
522  return Poly::from_container(a.get_vars(), std::move(-x));
523 }
524 
525 template <typename Poly>
526 RCP<const Poly> pow_mpoly(const Poly &a, unsigned int n)
527 {
528  auto x = a.get_poly();
529  return Poly::from_container(a.get_vars(), Poly::container_type::pow(x, n));
530 }
531 } // namespace SymEngine
532 
533 #endif
#define IMPLEMENT_TYPEID(SYMENGINE_ID)
Inline members and functions.
Definition: basic.h:340
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
hash_t __hash__() const override
hash_t __hash__() const override
vec_basic get_args() const override
Returns the list of arguments.
bool __eq__(const Basic &o) const override
Test equality.
int compare(const Basic &o) const override
T clear(T... args)
T empty(T... args)
T end(T... args)
T erase(T... args)
T find(T... args)
T insert(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T size(T... args)
Our less operator (<):
Definition: basic.h:228