Loading...
Searching...
No Matches
msymenginepoly.h
1#ifndef SYMENGINE_POLYNOMIALS_MULTIVARIATE
2#define SYMENGINE_POLYNOMIALS_MULTIVARIATE
3
7#include <symengine/polys/uexprpoly.h>
8#include <symengine/symengine_casts.h>
9
10namespace SymEngine
11{
12
13template <typename Vec, typename Value, typename Wrapper>
15{
16public:
18 Dict dict_;
19 unsigned int vec_size;
20
21 typedef Vec vec_type;
22 typedef Value coef_type;
23 typedef Dict dict_type;
24
25 UDictWrapper(unsigned int s) SYMENGINE_NOEXCEPT
26 {
27 vec_size = s;
28 }
29
30 UDictWrapper() SYMENGINE_NOEXCEPT {}
31
32 ~UDictWrapper() SYMENGINE_NOEXCEPT {}
33
34 UDictWrapper(Dict &&p, unsigned int sz)
35 {
36 auto iter = p.begin();
37 while (iter != p.end()) {
38 if (iter->second == 0) {
39 auto toErase = iter;
40 iter++;
41 p.erase(toErase);
42 } else {
43 iter++;
44 }
45 }
46
47 dict_ = p;
48 vec_size = sz;
49 }
50
51 UDictWrapper(const Dict &p, unsigned int sz)
52 {
53 for (auto &iter : p) {
54 if (iter.second != Value(0))
55 dict_[iter.first] = iter.second;
56 }
57 vec_size = sz;
58 }
59
60 Wrapper &operator=(Wrapper &&other)
61 {
62 if (this != &other)
63 dict_ = std::move(other.dict_);
64 return static_cast<Wrapper &>(*this);
65 }
66
67 friend Wrapper operator+(const Wrapper &a, const Wrapper &b)
68 {
69 SYMENGINE_ASSERT(a.vec_size == b.vec_size)
70 Wrapper c = a;
71 c += b;
72 return c;
73 }
74
75 // both wrappers must have "aligned" vectors, ie same size
76 // and vector positions refer to the same generators
77 Wrapper &operator+=(const Wrapper &other)
78 {
79 SYMENGINE_ASSERT(vec_size == other.vec_size)
80
81 for (auto &iter : other.dict_) {
82 auto t = dict_.find(iter.first);
83 if (t != dict_.end()) {
84 t->second += iter.second;
85 if (t->second == 0)
86 dict_.erase(t);
87 } else {
88 dict_.insert(t, {iter.first, iter.second});
89 }
90 }
91 return static_cast<Wrapper &>(*this);
92 }
93
94 friend Wrapper operator-(const Wrapper &a, const Wrapper &b)
95 {
96 SYMENGINE_ASSERT(a.vec_size == b.vec_size)
97
98 Wrapper c = a;
99 c -= b;
100 return c;
101 }
102
103 Wrapper operator-() const
104 {
105 auto c = *this;
106 for (auto &iter : c.dict_)
107 iter.second *= -1;
108 return static_cast<Wrapper &>(c);
109 }
110
111 // both wrappers must have "aligned" vectors, ie same size
112 // and vector positions refer to the same generators
113 Wrapper &operator-=(const Wrapper &other)
114 {
115 SYMENGINE_ASSERT(vec_size == other.vec_size)
116
117 for (auto &iter : other.dict_) {
118 auto t = dict_.find(iter.first);
119 if (t != dict_.end()) {
120 t->second -= iter.second;
121 if (t->second == 0)
122 dict_.erase(t);
123 } else {
124 dict_.insert(t, {iter.first, -iter.second});
125 }
126 }
127 return static_cast<Wrapper &>(*this);
128 }
129
130 static Wrapper mul(const Wrapper &a, const Wrapper &b)
131 {
132 SYMENGINE_ASSERT(a.vec_size == b.vec_size)
133
134 Wrapper p(a.vec_size);
135 for (auto const &a_ : a.dict_) {
136 for (auto const &b_ : b.dict_) {
137
138 Vec target(a.vec_size, 0);
139 for (unsigned int i = 0; i < a.vec_size; i++)
140 target[i] = a_.first[i] + b_.first[i];
141
142 if (p.dict_.find(target) == p.dict_.end()) {
143 p.dict_.insert({target, a_.second * b_.second});
144 } else {
145 p.dict_.find(target)->second += a_.second * b_.second;
146 }
147 }
148 }
149
150 for (auto it = p.dict_.begin(); it != p.dict_.end();) {
151 if (it->second == 0) {
152 p.dict_.erase(it++);
153 } else {
154 ++it;
155 }
156 }
157 return p;
158 }
159
160 static Wrapper pow(const Wrapper &a, unsigned int p)
161 {
162 Wrapper tmp = a, res(a.vec_size);
163
164 Vec zero_v(a.vec_size, 0);
165 res.dict_[zero_v] = 1_z;
166
167 while (p != 1) {
168 if (p % 2 == 0) {
169 tmp = tmp * tmp;
170 } else {
171 res = res * tmp;
172 tmp = tmp * tmp;
173 }
174 p >>= 1;
175 }
176
177 return (res * tmp);
178 }
179
180 friend Wrapper operator*(const Wrapper &a, const Wrapper &b)
181 {
182 SYMENGINE_ASSERT(a.vec_size == b.vec_size)
183 return Wrapper::mul(a, b);
184 }
185
186 Wrapper &operator*=(const Wrapper &other)
187 {
188 SYMENGINE_ASSERT(vec_size == other.vec_size)
189
190 if (dict_.empty())
191 return static_cast<Wrapper &>(*this);
192
193 if (other.dict_.empty()) {
194 dict_.clear();
195 return static_cast<Wrapper &>(*this);
196 }
197
198 Vec zero_v(vec_size, 0);
199 // ! other is a just constant term
200 if (other.dict_.size() == 1
201 and other.dict_.find(zero_v) != other.dict_.end()) {
202 auto t = other.dict_.begin();
203 for (auto &i1 : dict_)
204 i1.second *= t->second;
205 return static_cast<Wrapper &>(*this);
206 }
207
208 Wrapper res = Wrapper::mul(static_cast<Wrapper &>(*this), other);
209 res.dict_.swap(this->dict_);
210 return static_cast<Wrapper &>(*this);
211 }
212
213 bool operator==(const Wrapper &other) const
214 {
215 return dict_ == other.dict_;
216 }
217
218 bool operator!=(const Wrapper &other) const
219 {
220 return not(*this == other);
221 }
222
223 const Dict &get_dict() const
224 {
225 return dict_;
226 }
227
228 bool empty() const
229 {
230 return dict_.empty();
231 }
232
233 Value get_coeff(Vec &x) const
234 {
235 auto ite = dict_.find(x);
236 if (ite != dict_.end())
237 return ite->second;
238 return Value(0);
239 }
240
241 Wrapper translate(const vec_uint &translator, unsigned int size) const
242 {
243 SYMENGINE_ASSERT(translator.size() == vec_size)
244 SYMENGINE_ASSERT(size >= vec_size)
245
246 Dict d;
247
248 for (auto it : dict_) {
249 Vec changed;
250 changed.resize(size, 0);
251 for (unsigned int i = 0; i < vec_size; i++)
252 changed[translator[i]] = it.first[i];
253 d.insert({changed, it.second});
254 }
255
256 return Wrapper(std::move(d), size);
257 }
258};
259
260class MIntDict : public UDictWrapper<vec_uint, integer_class, MIntDict>
261{
262public:
263 MIntDict(unsigned int s) SYMENGINE_NOEXCEPT : UDictWrapper(s) {}
264
265 MIntDict() SYMENGINE_NOEXCEPT {}
266
267 ~MIntDict() SYMENGINE_NOEXCEPT {}
268
269 MIntDict(MIntDict &&other) SYMENGINE_NOEXCEPT
270 : UDictWrapper(std::move(other))
271 {
272 }
273
274 MIntDict(umap_uvec_mpz &&p, unsigned int sz)
275 : UDictWrapper(std::move(p), sz)
276 {
277 }
278
279 MIntDict(const umap_uvec_mpz &p, unsigned int sz) : UDictWrapper(p, sz) {}
280
281 MIntDict(const MIntDict &) = default;
282
283 MIntDict &operator=(const MIntDict &) = default;
284};
285
286class MExprDict : public UDictWrapper<vec_int, Expression, MExprDict>
287{
288public:
289 MExprDict(unsigned int s) SYMENGINE_NOEXCEPT : UDictWrapper(s) {}
290
291 MExprDict() SYMENGINE_NOEXCEPT {}
292
293 ~MExprDict() SYMENGINE_NOEXCEPT {}
294
295 MExprDict(MExprDict &&other) SYMENGINE_NOEXCEPT
296 : UDictWrapper(std::move(other))
297 {
298 }
299
300 MExprDict(umap_vec_expr &&p, unsigned int sz)
301 : UDictWrapper(std::move(p), sz)
302 {
303 }
304
305 MExprDict(const umap_vec_expr &p, unsigned int sz) : UDictWrapper(p, sz) {}
306
307 MExprDict(const MExprDict &) = default;
308
309 MExprDict &operator=(const MExprDict &) = default;
310};
311
312template <typename Container, typename Poly>
313class MSymEnginePoly : public Basic
314{
315private:
316 Container poly_;
317 set_basic vars_;
318
319public:
320 typedef Container container_type;
321 typedef typename Container::coef_type coef_type;
322
323 MSymEnginePoly(const set_basic &vars, Container &&dict)
324 : poly_{dict}, vars_{vars}
325 {
326 }
327
328 static RCP<const Poly> from_container(const set_basic &vars, Container &&d)
329 {
330 return make_rcp<const Poly>(vars, std::move(d));
331 }
332
333 int compare(const Basic &o) const override
334 {
335 SYMENGINE_ASSERT(is_a<Poly>(o))
336
337 const Poly &s = down_cast<const Poly &>(o);
338
339 if (vars_.size() != s.vars_.size())
340 return vars_.size() < s.vars_.size() ? -1 : 1;
341 if (poly_.dict_.size() != s.poly_.dict_.size())
342 return poly_.dict_.size() < s.poly_.dict_.size() ? -1 : 1;
343
344 int cmp = unified_compare(vars_, s.vars_);
345 if (cmp != 0)
346 return cmp;
347
348 return unified_compare(poly_.dict_, s.poly_.dict_);
349 }
350
351 template <typename FromPoly>
352 static enable_if_t<is_a_UPoly<FromPoly>::value, RCP<const Poly>>
353 from_poly(const FromPoly &p)
354 {
355 Container c;
356 for (auto it = p.begin(); it != p.end(); ++it)
357 c.dict_[{it->first}] = it->second;
358 c.vec_size = 1;
359
360 return Poly::from_container({p.get_var()}, std::move(c));
361 }
362
363 static RCP<const Poly> from_dict(const vec_basic &v,
364 typename Container::dict_type &&d)
365 {
366 set_basic s;
367 std::map<RCP<const Basic>, unsigned int, RCPBasicKeyLess> m;
368 // Symbols in the vector are sorted by placeing them in an map image
369 // of the symbols in the map is their original location in the vector
370
371 for (unsigned int i = 0; i < v.size(); i++) {
372 m.insert({v[i], i});
373 s.insert(v[i]);
374 }
375
376 // vec_uint translator represents the permutation of the exponents
377 vec_uint trans(s.size());
378 auto mptr = m.begin();
379 for (unsigned int i = 0; i < s.size(); i++) {
380 trans[mptr->second] = i;
381 mptr++;
382 }
383
384 Container x(std::move(d), numeric_cast<unsigned>(s.size()));
385 return Poly::from_container(
386 s, std::move(x.translate(trans, numeric_cast<unsigned>(s.size()))));
387 }
388
389 static Container container_from_dict(const set_basic &s,
390 typename Container::dict_type &&d)
391 {
392 return Container(std::move(d), numeric_cast<unsigned>(s.size()));
393 }
394
395 inline vec_basic get_args() const override
396 {
397 return {};
398 }
399
400 inline const Container &get_poly() const
401 {
402 return poly_;
403 }
404
405 inline const set_basic &get_vars() const
406 {
407 return vars_;
408 }
409
410 bool __eq__(const Basic &o) const override
411 {
412 // TODO : fix for when vars are different, but there is an intersection
413 if (not is_a<Poly>(o))
414 return false;
415 const Poly &o_ = down_cast<const Poly &>(o);
416 // compare constants without regards to vars
417 if (1 == poly_.dict_.size() && 1 == o_.poly_.dict_.size()) {
418 if (poly_.dict_.begin()->second != o_.poly_.dict_.begin()->second)
419 return false;
420 if (poly_.dict_.begin()->first == o_.poly_.dict_.begin()->first
421 && unified_eq(vars_, o_.vars_))
422 return true;
423 typename Container::vec_type v1, v2;
424 v1.resize(vars_.size(), 0);
425 v2.resize(o_.vars_.size(), 0);
426 if (poly_.dict_.begin()->first == v1
427 || o_.poly_.dict_.begin()->first == v2)
428 return true;
429 return false;
430 } else if (0 == poly_.dict_.size() && 0 == o_.poly_.dict_.size()) {
431 return true;
432 } else {
433 return (unified_eq(vars_, o_.vars_)
434 && unified_eq(poly_.dict_, o_.poly_.dict_));
435 }
436 }
437};
438
439class MIntPoly : public MSymEnginePoly<MIntDict, MIntPoly>
440{
441public:
442 MIntPoly(const set_basic &vars, MIntDict &&dict)
443 : MSymEnginePoly(vars, std::move(dict)){SYMENGINE_ASSIGN_TYPEID()}
444
445 IMPLEMENT_TYPEID(SYMENGINE_MINTPOLY)
446
447 hash_t __hash__() const override;
448 RCP<const Basic> as_symbolic() const;
449
450 integer_class eval(
451 std::map<RCP<const Basic>, integer_class, RCPBasicKeyLess> &vals) const;
452};
453
454class MExprPoly : public MSymEnginePoly<MExprDict, MExprPoly>
455{
456public:
457 MExprPoly(const set_basic &vars, MExprDict &&dict)
458 : MSymEnginePoly(vars, std::move(dict)){SYMENGINE_ASSIGN_TYPEID()}
459
460 IMPLEMENT_TYPEID(SYMENGINE_MEXPRPOLY)
461
462 hash_t __hash__() const override;
463 RCP<const Basic> as_symbolic() const;
465 eval(std::map<RCP<const Basic>, Expression, RCPBasicKeyLess> &vals) const;
466};
467
468// reconciles the positioning of the exponents in the vectors in the
469// Dict dict_ of the arguments with the positioning of the exponents in
470// the correspondng vectors of the output of the function. f1 and f2 are
471// vectors whose indices are the positions in the arguments and whose values
472// are the positions in the output. set_basic s is the set of symbols of
473// the output, and s1 and s2 are the sets of the symbols of the inputs.
474unsigned int reconcile(vec_uint &v1, vec_uint &v2, set_basic &s,
475 const set_basic &s1, const set_basic &s2);
476
477template <typename Poly, typename Container>
478set_basic get_translated_container(Container &x, Container &y, const Poly &a,
479 const Poly &b)
480{
481 vec_uint v1, v2;
482 set_basic s;
483
484 unsigned int sz = reconcile(v1, v2, s, a.get_vars(), b.get_vars());
485 x = a.get_poly().translate(v1, sz);
486 y = b.get_poly().translate(v2, sz);
487
488 return s;
489}
490
491template <typename Poly>
492RCP<const Poly> add_mpoly(const Poly &a, const Poly &b)
493{
494 typename Poly::container_type x, y;
495 set_basic s = get_translated_container(x, y, a, b);
496 x += y;
497 return Poly::from_container(s, std::move(x));
498}
499
500template <typename Poly>
501RCP<const Poly> sub_mpoly(const Poly &a, const Poly &b)
502{
503 typename Poly::container_type x, y;
504 set_basic s = get_translated_container(x, y, a, b);
505 x -= y;
506 return Poly::from_container(s, std::move(x));
507}
508
509template <typename Poly>
510RCP<const Poly> mul_mpoly(const Poly &a, const Poly &b)
511{
512 typename Poly::container_type x, y;
513 set_basic s = get_translated_container(x, y, a, b);
514 x *= y;
515 return Poly::from_container(s, std::move(x));
516}
517
518template <typename Poly>
519RCP<const Poly> neg_mpoly(const Poly &a)
520{
521 auto x = a.get_poly();
522 return Poly::from_container(a.get_vars(), std::move(-x));
523}
524
525template <typename Poly>
526RCP<const Poly> pow_mpoly(const Poly &a, unsigned int n)
527{
528 auto x = a.get_poly();
529 return Poly::from_container(a.get_vars(), Poly::container_type::pow(x, n));
530}
531} // namespace SymEngine
532
533#endif
#define IMPLEMENT_TYPEID(SYMENGINE_ID)
Inline members and functions.
Definition: basic.h:340
T begin(T... args)
The lowest unit of symbolic representation.
Definition: basic.h:97
hash_t __hash__() const override
hash_t __hash__() const override
vec_basic get_args() const override
Returns the list of arguments.
bool __eq__(const Basic &o) const override
Test equality.
int compare(const Basic &o) const override
T clear(T... args)
T empty(T... args)
T end(T... args)
T erase(T... args)
T find(T... args)
T insert(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
int unified_compare(const T &a, const T &b)
Definition: dict.h:205
T size(T... args)
Our less operator (<):
Definition: basic.h:228