Loading...
Searching...
No Matches
cse.cpp
1#include <symengine/basic.h>
2#include <symengine/add.h>
3#include <symengine/mul.h>
5#include <symengine/visitor.h>
6
7#include <queue>
8
9namespace SymEngine
10{
11umap_basic_basic opt_cse(const vec_basic &exprs);
12void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
13 const vec_basic &exprs, umap_basic_basic &opt_subs);
14
16{
17
18public:
20 value_numbers;
21 vec_basic value_number_to_value;
22 std::vector<std::set<unsigned>> arg_to_funcset;
23 std::vector<std::set<unsigned>> func_to_argset;
24
25public:
27 const std::vector<std::pair<RCP<const Basic>, vec_basic>> &funcs)
28 {
29 arg_to_funcset.resize(funcs.size());
30 for (unsigned func_i = 0; func_i < funcs.size(); func_i++) {
32 for (auto &func_arg : funcs[func_i].second) {
33 unsigned arg_number = get_or_add_value_number(func_arg);
34 func_argset.insert(arg_number);
35 arg_to_funcset[arg_number].insert(func_i);
36 }
37 func_to_argset.push_back(func_argset);
38 }
39 }
40
41 template <typename Container>
42 vec_basic get_args_in_value_order(Container &argset)
43 {
44 vec_basic v;
45 for (unsigned i : argset) {
46 v.push_back(value_number_to_value[i]);
47 }
48 return v;
49 }
50
51 unsigned get_or_add_value_number(RCP<const Basic> value)
52 {
53 unsigned nvalues = numeric_cast<unsigned>(value_numbers.size());
54 auto ret = value_numbers.insert(std::make_pair(value, nvalues));
55 bool inserted = ret.second;
56 if (inserted) {
57 value_number_to_value.push_back(value);
58 arg_to_funcset.push_back(std::set<unsigned>());
59 return nvalues;
60 } else {
61 return ret.first->second;
62 }
63 }
64
65 void stop_arg_tracking(unsigned func_i)
66 {
67 for (unsigned arg : func_to_argset[func_i]) {
68 arg_to_funcset[arg].erase(func_i);
69 }
70 }
71
72 /*
73 Return a dict whose keys are function numbers. The entries of the dict
74 are the number of arguments said function has in common with `argset`.
75 Entries have at least 2 items in common.
76 */
78 get_common_arg_candidates(std::set<unsigned> &argset, unsigned min_func_i)
79 {
82 for (unsigned arg : argset) {
83 funcsets.push_back(arg_to_funcset[arg]);
84 }
85 // Sorted by size to make best use of the performance hack below.
86 std::sort(funcsets.begin(), funcsets.end(),
87 [](const std::set<unsigned> &a, const std::set<unsigned> &b) {
88 return a.size() < b.size();
89 });
90
91 for (unsigned i = 0; i < funcsets.size(); i++) {
92 auto &funcset = funcsets[i];
93 for (unsigned func_i : funcset) {
94 if (func_i >= min_func_i) {
95 count_map[func_i] += 1;
96 }
97 }
98 }
99
100 /*auto &largest_funcset = funcsets[funcsets.size() - 1];
101
102 // We pick the smaller of the two containers to iterate over to
103 // reduce the number of items we have to look at.
104
105 if (largest_funcset.size() < count_map.size()) {
106 for (unsigned func_i : largest_funcset) {
107 if (count_map[func_i] < 1) {
108 continue;
109 }
110 if (count_map.find(func_i) != count_map.end()) {
111 count_map[func_i] += 1;
112 }
113 }
114 } else {
115 for (auto &count_map_pair : count_map) {
116 unsigned func_i = count_map_pair.first;
117 if (count_map[func_i] < 1) {
118 continue;
119 }
120 if (largest_funcset.find(func_i) != largest_funcset.end()) {
121 count_map[func_i] += 1;
122 }
123 }
124 }*/
125 auto iter = count_map.begin();
126 for (; iter != count_map.end();) {
127 if (iter->second >= 2) {
128 ++iter;
129 } else {
130 count_map.erase(iter++);
131 }
132 }
133 return count_map;
134 }
135
136 template <typename Container1, typename Container2>
138 get_subset_candidates(const Container1 &argset,
140 {
142 for (auto f : restrict_to_funcset) {
144 }
147 for (const auto &arg : argset) {
149 arg_to_funcset[arg].begin(),
150 arg_to_funcset[arg].end(),
153 intersect_result.clear();
154 }
155 return indices;
156 }
157
158 void update_func_argset(unsigned func_i,
160 {
161 // Update a function with a new set of arguments.
162 auto &old_args = func_to_argset[func_i];
163
165 std::set_difference(old_args.begin(), old_args.end(), new_args.begin(),
166 new_args.end(), std::inserter(diff, diff.begin()));
167
168 for (auto &deleted_arg : diff) {
169 arg_to_funcset[deleted_arg].erase(func_i);
170 }
171
172 diff.clear();
173 std::set_difference(new_args.begin(), new_args.end(), old_args.begin(),
174 old_args.end(), std::inserter(diff, diff.begin()));
175
176 for (auto &added_arg : diff) {
177 arg_to_funcset[added_arg].insert(func_i);
178 }
179
180 func_to_argset[func_i].clear();
181 func_to_argset[func_i].insert(new_args.begin(), new_args.end());
182 }
183};
184
186 const std::vector<unsigned> &b)
187{
189 std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
190 std::inserter(diff, diff.begin()));
191 return diff;
192}
193
/// Insert `number` into `vec`, keeping `vec` sorted, unless it is already
/// present.
///
/// @param vec    vector that must already be sorted in ascending order; it is
///               modified in place.
/// @param number value to insert.
void add_to_sorted_vec(std::vector<unsigned> &vec, unsigned number)
{
    // One binary search both detects an existing entry and yields the
    // insertion point. This replaces a linear std::find followed by a second
    // search with std::upper_bound. For a sorted vector the result is
    // identical: when `number` is absent, lower_bound == upper_bound.
    auto pos = std::lower_bound(vec.begin(), vec.end(), number);
    if (pos == vec.end() or *pos != number) {
        vec.insert(pos, number);
    }
}
201
202void match_common_args(const std::string &func_class, const vec_basic &funcs_,
203 umap_basic_basic &opt_subs)
204{
206 for (auto &b : funcs_) {
207 funcs.push_back(std::make_pair(b, b->get_args()));
208 }
209 std::sort(funcs.begin(), funcs.end(),
210 [](const std::pair<RCP<const Basic>, vec_basic> &a,
211 const std::pair<RCP<const Basic>, vec_basic> &b) {
212 return a.second.size() < b.second.size();
213 });
214
215 auto arg_tracker = FuncArgTracker(funcs);
216
219
220 for (unsigned i = 0; i < funcs.size(); i++) {
221 common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
222 arg_tracker.func_to_argset[i], i + 1);
223
225 for (auto it = common_arg_candidates_counts.begin();
226 it != common_arg_candidates_counts.end(); ++it) {
228 }
229
230 // Sort the candidates in order of match size.
231 // This makes us try combining smaller matches first.
233 [&](unsigned a, unsigned b) {
234 if (common_arg_candidates_counts[a]
235 == common_arg_candidates_counts[b]) {
236 return a < b;
237 }
240 });
241
242 while (common_arg_candidates.size() > 0) {
243 unsigned j = common_arg_candidates.front();
244 common_arg_candidates.pop_front();
246
247 std::set_intersection(arg_tracker.func_to_argset[i].begin(),
248 arg_tracker.func_to_argset[i].end(),
249 arg_tracker.func_to_argset[j].begin(),
250 arg_tracker.func_to_argset[j].end(),
252
253 if (com_args.size() <= 1) {
254 // This may happen if a set of common arguments was already
255 // combined in a previous iteration.
256 continue;
257 }
258
260 = set_diff(arg_tracker.func_to_argset[i], com_args);
261
262 unsigned com_func_number;
263
264 if (diff_i.size() > 0) {
265 // com_func needs to be unevaluated to allow for recursive
266 // matches.
267 auto com_func = function_symbol(
268 func_class, arg_tracker.get_args_in_value_order(com_args));
269 com_func_number = arg_tracker.get_or_add_value_number(com_func);
270 add_to_sorted_vec(diff_i, com_func_number);
271 arg_tracker.update_func_argset(i, diff_i);
272 changed.insert(i);
273
274 } else {
275 // Treat the whole expression as a CSE.
276 //
277 // The reason this needs to be done is somewhat subtle. Within
278 // tree_cse(), to_eliminate only contains expressions that are
279 // seen more than once. The problem is unevaluated expressions
280 // do not compare equal to the evaluated equivalent. So
281 // tree_cse() won't mark funcs[i] as a CSE if we use an
282 // unevaluated version.
284 = arg_tracker.get_or_add_value_number(funcs[i].first);
285 }
286
288 = set_diff(arg_tracker.func_to_argset[j], com_args);
289 add_to_sorted_vec(diff_j, com_func_number);
290 arg_tracker.update_func_argset(j, diff_j);
291 changed.insert(j);
292
293 for (unsigned k : arg_tracker.get_subset_candidates(
296 = set_diff(arg_tracker.func_to_argset[k], com_args);
297 add_to_sorted_vec(diff_k, com_func_number);
298 arg_tracker.update_func_argset(k, diff_k);
299 changed.insert(k);
300 }
301 }
302 if (std::find(changed.begin(), changed.end(), i) != changed.end()) {
303 opt_subs[funcs[i].first] = function_symbol(
304 func_class, arg_tracker.get_args_in_value_order(
305 arg_tracker.func_to_argset[i]));
306 }
307 arg_tracker.stop_arg_tracking(i);
308 }
309}
310
311class OptsCSEVisitor : public BaseVisitor<OptsCSEVisitor>
312{
313public:
314 umap_basic_basic &opt_subs;
315 set_basic adds;
316 set_basic muls;
317 set_basic seen_subexp;
318 OptsCSEVisitor(umap_basic_basic &opt_subs_) : opt_subs(opt_subs_) {}
319 bool is_seen(const Basic &expr)
320 {
321 return (seen_subexp.find(expr.rcp_from_this()) != seen_subexp.end());
322 }
323 void bvisit(const Derivative &x)
324 {
325 return;
326 }
327 void bvisit(const Subs &x)
328 {
329 return;
330 }
331 void bvisit(const Add &x)
332 {
333 if (not is_seen(x)) {
334 seen_subexp.insert(x.rcp_from_this());
335 for (const auto &p : x.get_args()) {
336 p->accept(*this);
337 }
338 adds.insert(x.rcp_from_this());
339 }
340 }
341 void bvisit(const Pow &x)
342 {
343 if (not is_seen(x)) {
344 auto expr = x.rcp_from_this();
345 seen_subexp.insert(expr);
346 for (const auto &p : x.get_args()) {
347 p->accept(*this);
348 }
349 auto ex = x.get_exp();
350 if (is_a<Mul>(*ex)) {
351 ex = static_cast<const Mul &>(*ex).get_coef();
352 }
353 if (is_a_Number(*ex)
354 and static_cast<const Number &>(*ex).is_negative()) {
355 vec_basic v({pow(x.get_base(), neg(x.get_exp())), integer(-1)});
356 opt_subs[expr] = function_symbol("pow", v);
357 }
358 }
359 }
360 void bvisit(const Mul &x)
361 {
362 if (not is_seen(x)) {
363 auto expr = x.rcp_from_this();
364 seen_subexp.insert(expr);
365 for (const auto &p : x.get_args()) {
366 p->accept(*this);
367 }
368 if (x.get_coef()->is_negative()) {
369 auto neg_expr = neg(x.rcp_from_this());
370 if (not is_a<Symbol>(*neg_expr)) {
371 opt_subs[expr]
372 = function_symbol("mul", {integer(-1), neg_expr});
373 seen_subexp.insert(neg_expr);
374 expr = neg_expr;
375 }
376 }
377 if (is_a<Mul>(*expr)) {
378 muls.insert(expr);
379 }
380 }
381 }
382 void bvisit(const Basic &x)
383 {
384 auto v = x.get_args();
385 if (v.size() > 0 and not is_seen(x)) {
386 auto expr = x.rcp_from_this();
387 seen_subexp.insert(expr);
388 for (const auto &p : v) {
389 p->accept(*this);
390 }
391 }
392 }
393};
394
395vec_basic set_as_vec(const set_basic &s)
396{
397 vec_basic result;
398 for (auto &u : s) {
399 result.push_back(u);
400 }
401 return result;
402}
403
404umap_basic_basic opt_cse(const vec_basic &exprs)
405{
406 // Find optimization opportunities in Adds, Muls, Pows and negative
407 // coefficient Muls
408 umap_basic_basic opt_subs;
409 OptsCSEVisitor visitor(opt_subs);
410 for (auto &e : exprs) {
411 e->accept(visitor);
412 }
413
414 match_common_args("add", set_as_vec(visitor.adds), opt_subs);
415 match_common_args("mul", set_as_vec(visitor.muls), opt_subs);
416
417 return opt_subs;
418}
419
420class RebuildVisitor : public BaseVisitor<RebuildVisitor, TransformVisitor>
421{
422private:
423 umap_basic_basic &subs;
424 umap_basic_basic &opt_subs;
425 set_basic &to_eliminate;
426 set_basic &excluded_symbols;
427 vec_pair &replacements;
428 unsigned next_symbol_index = 0;
429
430public:
431 using TransformVisitor::bvisit;
432 using TransformVisitor::result_;
434 set_basic &to_eliminate_, set_basic &excluded_symbols_,
435 vec_pair &replacements_)
436 : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
437 excluded_symbols(excluded_symbols_), replacements(replacements_)
438 {
439 }
440 RCP<const Basic> apply(const RCP<const Basic> &orig_expr) override
441 {
442 RCP<const Basic> expr = orig_expr;
443 if (is_a_Atom(*expr)) {
444 return expr;
445 }
446
447 auto iter = subs.find(expr);
448 if (iter != subs.end()) {
449 return iter->second;
450 }
451 auto iter2 = opt_subs.find(expr);
452 if (iter2 != opt_subs.end()) {
453 expr = iter2->second;
454 }
455 expr->accept(*this);
456 auto new_expr = result_;
457 if (to_eliminate.find(orig_expr) != to_eliminate.end()) {
458 auto sym = next_symbol();
459 subs[orig_expr] = sym;
460 replacements.push_back(
461 std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
462 return sym;
463 }
464 return new_expr;
465 }
466 RCP<const Basic> next_symbol()
467 {
468 RCP<const Basic> sym = symbol("x" + to_string(next_symbol_index));
469 next_symbol_index++;
470 if (excluded_symbols.find(sym) == excluded_symbols.end()) {
471 return sym;
472 } else {
473 return next_symbol();
474 }
475 };
476 void bvisit(const FunctionSymbol &x)
477 {
478 auto &fargs = x.get_vec();
479 vec_basic newargs;
480 for (const auto &a : fargs) {
481 newargs.push_back(apply(a));
482 }
483 if (x.get_name() == "add") {
484 result_ = add(newargs);
485 } else if (x.get_name() == "mul") {
486 result_ = mul(newargs);
487 } else if (x.get_name() == "pow") {
488 result_ = pow(newargs[0], newargs[1]);
489 } else {
490 result_ = x.create(newargs);
491 }
492 }
493};
494
495void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
496 const vec_basic &exprs, umap_basic_basic &opt_subs)
497{
498 set_basic to_eliminate;
499 set_basic seen_subexp;
500 set_basic excluded_symbols;
501
502 std::function<void(RCP<const Basic> & expr)> find_repeated;
503 find_repeated = [&](RCP<const Basic> expr) -> void {
504 // Do not replace atoms
505 if (is_a_Number(*expr) or is_a<BooleanAtom>(*expr)) {
506 return;
507 }
508
509 if (is_a<Symbol>(*expr)) {
510 excluded_symbols.insert(expr);
511 }
512
513 if (seen_subexp.find(expr) != seen_subexp.end()) {
514 to_eliminate.insert(expr);
515 return;
516 }
517
518 seen_subexp.insert(expr);
519
520 auto iter = opt_subs.find(expr);
521 if (iter != opt_subs.end()) {
522 expr = iter->second;
523 }
524
525 vec_basic args = expr->get_args();
526
527 for (auto &arg : args) {
528 find_repeated(arg);
529 }
530 };
531
532 for (auto e : exprs) {
533 find_repeated(e);
534 }
535
536 umap_basic_basic subs;
537
538 RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
539 excluded_symbols, replacements);
540
541 for (auto &e : exprs) {
542 auto reduced_e = rebuild_visitor.apply(e);
543 reduced_exprs.push_back(reduced_e);
544 }
545}
546
547void cse(vec_pair &replacements, vec_basic &reduced_exprs,
548 const vec_basic &exprs)
549{
550 // Find other optimization opportunities.
551 umap_basic_basic opt_subs = opt_cse(exprs);
552
553 // Main CSE algorithm.
554 tree_cse(replacements, reduced_exprs, exprs, opt_subs);
555}
556} // namespace SymEngine
Classes and functions relating to the binary operation of addition.
T back_inserter(T... args)
The base class for SymEngine.
T begin(T... args)
The base class for representing addition in symbolic expressions.
Definition add.h:27
vec_basic get_args() const override
Returns the arguments of the Add.
Definition add.cpp:397
The lowest unit of symbolic representation.
Definition basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
const std::string & get_name() const
Definition functions.h:654
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
vec_basic get_args() const override
Returns the list of arguments.
Definition mul.cpp:507
virtual bool is_negative() const =0
RCP< const Basic > get_base() const
Definition pow.h:37
vec_basic get_args() const override
Returns the list of arguments.
Definition pow.cpp:266
RCP< const Basic > get_exp() const
Definition pow.h:42
T clear(T... args)
T end(T... args)
T erase(T... args)
T find(T... args)
T insert(T... args)
T inserter(T... args)
T make_pair(T... args)
Main namespace for SymEngine package.
Definition add.cpp:19
void hash_combine(hash_t &seed, const T &v)
Definition basic-inl.h:95
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition mul.cpp:352
bool is_a_Atom(const Basic &b)
Returns true if b is an atom, i.e. b.get_args() returns an empty vector.
Definition basic.cpp:95
RCP< const Symbol > symbol(const std::string &name)
inline version to return Symbol
Definition symbol.h:82
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition add.cpp:425
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition integer.h:197
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
Definition mul.cpp:443
T push_back(T... args)
T resize(T... args)
T set_difference(T... args)
T set_intersection(T... args)
T size(T... args)
T sort(T... args)
Our comparison (==)
Definition basic.h:219
T upper_bound(T... args)