11 umap_basic_basic opt_cse(
const vec_basic &exprs);
12 void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
13 const vec_basic &exprs, umap_basic_basic &opt_subs);
21 vec_basic value_number_to_value;
22 std::vector<std::set<unsigned>> arg_to_funcset;
23 std::vector<std::set<unsigned>> func_to_argset;
27 const std::vector<std::pair<RCP<const Basic>, vec_basic>> &funcs)
29 arg_to_funcset.resize(funcs.size());
30 for (
unsigned func_i = 0; func_i < funcs.size(); func_i++) {
31 std::set<unsigned> func_argset;
32 for (
auto &func_arg : funcs[func_i].second) {
33 unsigned arg_number = get_or_add_value_number(func_arg);
34 func_argset.insert(arg_number);
35 arg_to_funcset[arg_number].insert(func_i);
37 func_to_argset.push_back(func_argset);
41 template <
typename Container>
42 vec_basic get_args_in_value_order(Container &argset)
45 for (
unsigned i : argset) {
46 v.push_back(value_number_to_value[i]);
51 unsigned get_or_add_value_number(RCP<const Basic> value)
53 unsigned nvalues = numeric_cast<unsigned>(value_numbers.size());
54 auto ret = value_numbers.insert(std::make_pair(value, nvalues));
55 bool inserted = ret.second;
57 value_number_to_value.push_back(value);
58 arg_to_funcset.push_back(std::set<unsigned>());
61 return ret.first->second;
65 void stop_arg_tracking(
unsigned func_i)
67 for (
unsigned arg : func_to_argset[func_i]) {
68 arg_to_funcset[arg].erase(func_i);
77 std::map<unsigned, unsigned>
78 get_common_arg_candidates(std::set<unsigned> &argset,
unsigned min_func_i)
80 std::map<unsigned, unsigned> count_map;
81 std::vector<std::set<unsigned>> funcsets;
82 for (
unsigned arg : argset) {
83 funcsets.push_back(arg_to_funcset[arg]);
86 std::sort(funcsets.begin(), funcsets.end(),
87 [](
const std::set<unsigned> &a,
const std::set<unsigned> &b) {
88 return a.size() < b.size();
91 for (
unsigned i = 0; i < funcsets.size(); i++) {
92 auto &funcset = funcsets[i];
93 for (
unsigned func_i : funcset) {
94 if (func_i >= min_func_i) {
95 count_map[func_i] += 1;
125 auto iter = count_map.begin();
126 for (; iter != count_map.end();) {
127 if (iter->second >= 2) {
130 count_map.erase(iter++);
136 template <
typename Container1,
typename Container2>
137 std::vector<unsigned>
138 get_subset_candidates(
const Container1 &argset,
139 const Container2 &restrict_to_funcset)
141 std::vector<unsigned> indices;
142 for (
auto f : restrict_to_funcset) {
143 indices.push_back(f);
145 std::sort(std::begin(indices), std::end(indices));
146 std::vector<unsigned> intersect_result;
147 for (
const auto &arg : argset) {
148 std::set_intersection(indices.begin(), indices.end(),
149 arg_to_funcset[arg].begin(),
150 arg_to_funcset[arg].end(),
151 std::back_inserter(intersect_result));
152 intersect_result.swap(indices);
153 intersect_result.clear();
158 void update_func_argset(
unsigned func_i,
159 const std::vector<unsigned> &new_args)
162 auto &old_args = func_to_argset[func_i];
164 std::set<unsigned> diff;
165 std::set_difference(old_args.begin(), old_args.end(), new_args.begin(),
166 new_args.end(), std::inserter(diff, diff.begin()));
168 for (
auto &deleted_arg : diff) {
169 arg_to_funcset[deleted_arg].erase(func_i);
173 std::set_difference(new_args.begin(), new_args.end(), old_args.begin(),
174 old_args.end(), std::inserter(diff, diff.begin()));
176 for (
auto &added_arg : diff) {
177 arg_to_funcset[added_arg].insert(func_i);
180 func_to_argset[func_i].clear();
181 func_to_argset[func_i].insert(new_args.begin(), new_args.end());
185 std::vector<unsigned> set_diff(
const std::set<unsigned> &a,
186 const std::vector<unsigned> &b)
188 std::vector<unsigned> diff;
189 std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
190 std::inserter(diff, diff.begin()));
194 void add_to_sorted_vec(std::vector<unsigned> &vec,
unsigned number)
196 if (std::find(vec.begin(), vec.end(), number) == vec.end()) {
198 vec.insert(std::upper_bound(vec.begin(), vec.end(), number), number);
202 void match_common_args(
const std::string &func_class,
const vec_basic &funcs_,
203 umap_basic_basic &opt_subs)
205 std::vector<std::pair<RCP<const Basic>, vec_basic>> funcs;
206 for (
auto &b : funcs_) {
207 funcs.push_back(std::make_pair(b, b->get_args()));
209 std::sort(funcs.begin(), funcs.end(),
210 [](
const std::pair<RCP<const Basic>, vec_basic> &a,
211 const std::pair<RCP<const Basic>, vec_basic> &b) {
212 return a.second.size() < b.second.size();
215 auto arg_tracker = FuncArgTracker(funcs);
217 std::set<unsigned> changed;
218 std::map<unsigned, unsigned> common_arg_candidates_counts;
220 for (
unsigned i = 0; i < funcs.size(); i++) {
221 common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
222 arg_tracker.func_to_argset[i], i + 1);
224 std::deque<unsigned> common_arg_candidates;
225 for (
auto it = common_arg_candidates_counts.begin();
226 it != common_arg_candidates_counts.end(); ++it) {
227 common_arg_candidates.push_back(it->first);
232 std::sort(common_arg_candidates.begin(), common_arg_candidates.end(),
233 [&](
unsigned a,
unsigned b) {
234 if (common_arg_candidates_counts[a]
235 == common_arg_candidates_counts[b]) {
238 return common_arg_candidates_counts[a]
239 < common_arg_candidates_counts[b];
242 while (common_arg_candidates.size() > 0) {
243 unsigned j = common_arg_candidates.front();
244 common_arg_candidates.pop_front();
245 std::vector<unsigned> com_args;
247 std::set_intersection(arg_tracker.func_to_argset[i].begin(),
248 arg_tracker.func_to_argset[i].end(),
249 arg_tracker.func_to_argset[j].begin(),
250 arg_tracker.func_to_argset[j].end(),
251 std::back_inserter(com_args));
253 if (com_args.size() <= 1) {
259 std::vector<unsigned> diff_i
260 = set_diff(arg_tracker.func_to_argset[i], com_args);
262 unsigned com_func_number;
264 if (diff_i.size() > 0) {
267 auto com_func = function_symbol(
268 func_class, arg_tracker.get_args_in_value_order(com_args));
269 com_func_number = arg_tracker.get_or_add_value_number(com_func);
270 add_to_sorted_vec(diff_i, com_func_number);
271 arg_tracker.update_func_argset(i, diff_i);
284 = arg_tracker.get_or_add_value_number(funcs[i].first);
287 std::vector<unsigned> diff_j
288 = set_diff(arg_tracker.func_to_argset[j], com_args);
289 add_to_sorted_vec(diff_j, com_func_number);
290 arg_tracker.update_func_argset(j, diff_j);
293 for (
unsigned k : arg_tracker.get_subset_candidates(
294 com_args, common_arg_candidates)) {
295 std::vector<unsigned> diff_k
296 = set_diff(arg_tracker.func_to_argset[k], com_args);
297 add_to_sorted_vec(diff_k, com_func_number);
298 arg_tracker.update_func_argset(k, diff_k);
302 if (std::find(changed.begin(), changed.end(), i) != changed.end()) {
303 opt_subs[funcs[i].first] = function_symbol(
304 func_class, arg_tracker.get_args_in_value_order(
305 arg_tracker.func_to_argset[i]));
307 arg_tracker.stop_arg_tracking(i);
314 umap_basic_basic &opt_subs;
317 set_basic seen_subexp;
318 OptsCSEVisitor(umap_basic_basic &opt_subs_) : opt_subs(opt_subs_) {}
319 bool is_seen(
const Basic &expr)
321 return (seen_subexp.find(expr.
rcp_from_this()) != seen_subexp.end());
327 void bvisit(
const Subs &x)
331 void bvisit(
const Add &x)
333 if (not is_seen(x)) {
335 for (
const auto &p : x.
get_args()) {
341 void bvisit(
const Pow &x)
343 if (not is_seen(x)) {
345 seen_subexp.insert(expr);
346 for (
const auto &p : x.
get_args()) {
350 if (is_a<Mul>(*ex)) {
351 ex =
static_cast<const Mul &
>(*ex).get_coef();
356 opt_subs[expr] = function_symbol(
"pow", v);
360 void bvisit(
const Mul &x)
362 if (not is_seen(x)) {
364 seen_subexp.insert(expr);
365 for (
const auto &p : x.
get_args()) {
368 if (x.get_coef()->is_negative()) {
370 if (not is_a<Symbol>(*neg_expr)) {
372 = function_symbol(
"mul", {
integer(-1), neg_expr});
373 seen_subexp.insert(neg_expr);
377 if (is_a<Mul>(*expr)) {
382 void bvisit(
const Basic &x)
385 if (v.size() > 0 and not is_seen(x)) {
387 seen_subexp.insert(expr);
388 for (
const auto &p : v) {
395 vec_basic set_as_vec(
const set_basic &s)
404 umap_basic_basic opt_cse(
const vec_basic &exprs)
408 umap_basic_basic opt_subs;
409 OptsCSEVisitor visitor(opt_subs);
410 for (
auto &e : exprs) {
414 match_common_args(
"add", set_as_vec(visitor.adds), opt_subs);
415 match_common_args(
"mul", set_as_vec(visitor.muls), opt_subs);
423 umap_basic_basic &subs;
424 umap_basic_basic &opt_subs;
425 set_basic &to_eliminate;
426 set_basic &excluded_symbols;
427 vec_pair &replacements;
428 unsigned next_symbol_index = 0;
431 using TransformVisitor::bvisit;
432 using TransformVisitor::result_;
433 RebuildVisitor(umap_basic_basic &subs_, umap_basic_basic &opt_subs_,
434 set_basic &to_eliminate_, set_basic &excluded_symbols_,
435 vec_pair &replacements_)
436 : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
437 excluded_symbols(excluded_symbols_), replacements(replacements_)
440 RCP<const Basic> apply(
const RCP<const Basic> &orig_expr)
override
442 RCP<const Basic> expr = orig_expr;
447 auto iter = subs.find(expr);
448 if (iter != subs.end()) {
451 auto iter2 = opt_subs.find(expr);
452 if (iter2 != opt_subs.end()) {
453 expr = iter2->second;
456 auto new_expr = result_;
457 if (to_eliminate.find(orig_expr) != to_eliminate.end()) {
458 auto sym = next_symbol();
459 subs[orig_expr] = sym;
460 replacements.push_back(
461 std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
466 RCP<const Basic> next_symbol()
470 if (excluded_symbols.find(sym) == excluded_symbols.end()) {
473 return next_symbol();
478 auto &fargs = x.get_vec();
480 for (
const auto &a : fargs) {
481 newargs.push_back(apply(a));
484 result_ =
add(newargs);
486 result_ =
mul(newargs);
488 result_ = pow(newargs[0], newargs[1]);
490 result_ = x.
create(newargs);
495 void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
496 const vec_basic &exprs, umap_basic_basic &opt_subs)
498 set_basic to_eliminate;
499 set_basic seen_subexp;
500 set_basic excluded_symbols;
502 std::function<void(RCP<const Basic> & expr)> find_repeated;
503 find_repeated = [&](RCP<const Basic> expr) ->
void {
505 if (
is_a_Number(*expr) or is_a<BooleanAtom>(*expr)) {
509 if (is_a<Symbol>(*expr)) {
510 excluded_symbols.insert(expr);
513 if (seen_subexp.find(expr) != seen_subexp.end()) {
514 to_eliminate.insert(expr);
518 seen_subexp.insert(expr);
520 auto iter = opt_subs.find(expr);
521 if (iter != opt_subs.end()) {
525 vec_basic args = expr->get_args();
527 for (
auto &arg : args) {
532 for (
auto e : exprs) {
536 umap_basic_basic subs;
538 RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
539 excluded_symbols, replacements);
541 for (
auto &e : exprs) {
542 auto reduced_e = rebuild_visitor.apply(e);
543 reduced_exprs.push_back(reduced_e);
547 void cse(vec_pair &replacements, vec_basic &reduced_exprs,
548 const vec_basic &exprs)
551 umap_basic_basic opt_subs = opt_cse(exprs);
554 tree_cse(replacements, reduced_exprs, exprs, opt_subs);
Classes and functions relating to the binary operation of addition.
The base class for SymEngine.
The base class for representing addition in symbolic expressions.
vec_basic get_args() const override
Returns the arguments of the Add.
The lowest unit of symbolic representation.
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Gets an RCP<T> pointer to self (the pointer is cast to T).
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
const std::string & get_name() const
vec_basic get_args() const override
Returns the list of arguments.
virtual bool is_negative() const =0
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
RCP< const Basic > get_base() const
Main namespace for SymEngine package.
bool is_a_Number(const Basic &b)
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
RCP< const Symbol > symbol(const std::string &name)
Inline version that returns a Symbol.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
bool is_a_Atom(const Basic &b)
Returns true if b is an atom, i.e. b.get_args() returns an empty vector.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
std::string to_string(const T &value)
Workaround for a MinGW bug.
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.