// Forward declaration of opt_cse(). The definition later in this listing
// declares a umap_basic_basic named opt_subs and an OptsCSEVisitor bound to
// it, loops over `exprs`, and calls match_common_args() for "add" and "mul".
// cse() stores the returned map and hands it to tree_cse().
11 umap_basic_basic opt_cse(
const vec_basic &exprs);
// Forward declaration of tree_cse(). cse() calls it with its own
// replacements / reduced_exprs / exprs arguments plus the map returned by
// opt_cse(exprs). `replacements` and `reduced_exprs` are non-const
// references -- presumably the outputs; confirm against the definition.
12 void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
13 const vec_basic &exprs, umap_basic_basic &opt_subs);
29 arg_to_funcset.
resize(funcs.size());
30 for (
unsigned func_i = 0; func_i < funcs.size(); func_i++) {
32 for (
auto &func_arg : funcs[func_i].second) {
33 unsigned arg_number = get_or_add_value_number(func_arg);
34 func_argset.
insert(arg_number);
35 arg_to_funcset[arg_number].
insert(func_i);
41 template <
typename Container>
42 vec_basic get_args_in_value_order(Container &argset)
45 for (
unsigned i : argset) {
51 unsigned get_or_add_value_number(RCP<const Basic> value)
53 unsigned nvalues = numeric_cast<unsigned>(value_numbers.
size());
55 bool inserted = ret.second;
61 return ret.first->second;
65 void stop_arg_tracking(
unsigned func_i)
67 for (
unsigned arg : func_to_argset[func_i]) {
68 arg_to_funcset[arg].
erase(func_i);
82 for (
unsigned arg : argset) {
88 return a.size() < b.size();
91 for (
unsigned i = 0; i < funcsets.
size(); i++) {
92 auto &funcset = funcsets[i];
93 for (
unsigned func_i : funcset) {
94 if (func_i >= min_func_i) {
95 count_map[func_i] += 1;
125 auto iter = count_map.
begin();
126 for (; iter != count_map.
end();) {
127 if (iter->second >= 2) {
130 count_map.
erase(iter++);
136 template <
typename Container1,
typename Container2>
138 get_subset_candidates(
const Container1 &argset,
139 const Container2 &restrict_to_funcset)
142 for (
auto f : restrict_to_funcset) {
147 for (
const auto &arg : argset) {
149 arg_to_funcset[arg].
begin(),
150 arg_to_funcset[arg].
end(),
152 intersect_result.
swap(indices);
153 intersect_result.
clear();
158 void update_func_argset(
unsigned func_i,
162 auto &old_args = func_to_argset[func_i];
168 for (
auto &deleted_arg : diff) {
169 arg_to_funcset[deleted_arg].
erase(func_i);
176 for (
auto &added_arg : diff) {
177 arg_to_funcset[added_arg].
insert(func_i);
180 func_to_argset[func_i].
clear();
202 void match_common_args(
const std::string &func_class,
const vec_basic &funcs_,
203 umap_basic_basic &opt_subs)
206 for (
auto &b : funcs_) {
210 [](
const std::pair<RCP<const Basic>, vec_basic> &a,
211 const std::pair<RCP<const Basic>, vec_basic> &b) {
212 return a.second.size() < b.second.size();
215 auto arg_tracker = FuncArgTracker(funcs);
220 for (
unsigned i = 0; i < funcs.
size(); i++) {
221 common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
222 arg_tracker.func_to_argset[i], i + 1);
225 for (
auto it = common_arg_candidates_counts.
begin();
226 it != common_arg_candidates_counts.
end(); ++it) {
227 common_arg_candidates.
push_back(it->first);
233 [&](
unsigned a,
unsigned b) {
234 if (common_arg_candidates_counts[a]
235 == common_arg_candidates_counts[b]) {
238 return common_arg_candidates_counts[a]
239 < common_arg_candidates_counts[b];
242 while (common_arg_candidates.
size() > 0) {
243 unsigned j = common_arg_candidates.
front();
248 arg_tracker.func_to_argset[i].end(),
249 arg_tracker.func_to_argset[j].begin(),
250 arg_tracker.func_to_argset[j].end(),
253 if (com_args.
size() <= 1) {
260 = set_diff(arg_tracker.func_to_argset[i], com_args);
262 unsigned com_func_number;
264 if (diff_i.
size() > 0) {
267 auto com_func = function_symbol(
268 func_class, arg_tracker.get_args_in_value_order(com_args));
269 com_func_number = arg_tracker.get_or_add_value_number(com_func);
270 add_to_sorted_vec(diff_i, com_func_number);
271 arg_tracker.update_func_argset(i, diff_i);
284 = arg_tracker.get_or_add_value_number(funcs[i].first);
288 = set_diff(arg_tracker.func_to_argset[j], com_args);
289 add_to_sorted_vec(diff_j, com_func_number);
290 arg_tracker.update_func_argset(j, diff_j);
293 for (
unsigned k : arg_tracker.get_subset_candidates(
294 com_args, common_arg_candidates)) {
296 = set_diff(arg_tracker.func_to_argset[k], com_args);
297 add_to_sorted_vec(diff_k, com_func_number);
298 arg_tracker.update_func_argset(k, diff_k);
303 opt_subs[funcs[i].first] = function_symbol(
304 func_class, arg_tracker.get_args_in_value_order(
305 arg_tracker.func_to_argset[i]));
307 arg_tracker.stop_arg_tracking(i);
319 bool is_seen(
const Basic &expr)
327 void bvisit(
const Subs &x)
331 void bvisit(
const Add &x)
333 if (not is_seen(x)) {
335 for (
const auto &p : x.
get_args()) {
341 void bvisit(
const Pow &x)
343 if (not is_seen(x)) {
346 for (
const auto &p : x.
get_args()) {
350 if (is_a<Mul>(*ex)) {
351 ex =
static_cast<const Mul &
>(*ex).get_coef();
356 opt_subs[expr] = function_symbol(
"pow", v);
360 void bvisit(
const Mul &x)
362 if (not is_seen(x)) {
365 for (
const auto &p : x.
get_args()) {
368 if (x.get_coef()->is_negative()) {
370 if (not is_a<Symbol>(*neg_expr)) {
372 = function_symbol(
"mul", {
integer(-1), neg_expr});
373 seen_subexp.
insert(neg_expr);
377 if (is_a<Mul>(*expr)) {
382 void bvisit(
const Basic &x)
385 if (v.size() > 0 and not is_seen(x)) {
388 for (
const auto &p : v) {
404 umap_basic_basic opt_cse(
const vec_basic &exprs)
408 umap_basic_basic opt_subs;
409 OptsCSEVisitor visitor(opt_subs);
410 for (
auto &e : exprs) {
414 match_common_args(
"add", set_as_vec(visitor.adds), opt_subs);
415 match_common_args(
"mul", set_as_vec(visitor.muls), opt_subs);
428 unsigned next_symbol_index = 0;
431 using TransformVisitor::bvisit;
432 using TransformVisitor::result_;
436 : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
437 excluded_symbols(excluded_symbols_), replacements(replacements_)
440 RCP<const Basic> apply(
const RCP<const Basic> &orig_expr)
override
442 RCP<const Basic> expr = orig_expr;
447 auto iter = subs.find(expr);
448 if (iter != subs.end()) {
451 auto iter2 = opt_subs.
find(expr);
452 if (iter2 != opt_subs.
end()) {
453 expr = iter2->second;
456 auto new_expr = result_;
457 if (to_eliminate.
find(orig_expr) != to_eliminate.
end()) {
458 auto sym = next_symbol();
459 subs[orig_expr] = sym;
461 std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
466 RCP<const Basic> next_symbol()
468 RCP<const Basic> sym =
symbol(
"x" + to_string(next_symbol_index));
470 if (excluded_symbols.
find(sym) == excluded_symbols.
end()) {
473 return next_symbol();
478 auto &fargs = x.get_vec();
480 for (
const auto &a : fargs) {
484 result_ =
add(newargs);
486 result_ =
mul(newargs);
488 result_ = pow(newargs[0], newargs[1]);
490 result_ = x.
create(newargs);
503 find_repeated = [&](RCP<const Basic> expr) ->
void {
505 if (
is_a_Number(*expr) or is_a<BooleanAtom>(*expr)) {
509 if (is_a<Symbol>(*expr)) {
510 excluded_symbols.
insert(expr);
513 if (seen_subexp.
find(expr) != seen_subexp.
end()) {
514 to_eliminate.
insert(expr);
520 auto iter = opt_subs.
find(expr);
521 if (iter != opt_subs.
end()) {
525 vec_basic args = expr->get_args();
527 for (
auto &arg : args) {
532 for (
auto e : exprs) {
536 umap_basic_basic subs;
538 RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
539 excluded_symbols, replacements);
541 for (
auto &e : exprs) {
542 auto reduced_e = rebuild_visitor.apply(e);
547 void cse(vec_pair &replacements, vec_basic &reduced_exprs,
548 const vec_basic &exprs)
551 umap_basic_basic opt_subs = opt_cse(exprs);
554 tree_cse(replacements, reduced_exprs, exprs, opt_subs);
Classes and functions relating to the binary operation of addition.
T back_inserter(T... args)
The base class for SymEngine.
The base class for representing addition in symbolic expressions.
vec_basic get_args() const override
Returns the arguments of the Add.
The lowest unit of symbolic representation.
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Returns an RCP<T> pointer to this object (the pointer is cast to T).
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
const std::string & get_name() const
vec_basic get_args() const override
Returns the list of arguments.
virtual bool is_negative() const =0
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
RCP< const Basic > get_base() const
Main namespace for SymEngine package.
bool is_a_Number(const Basic &b)
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
RCP< const Symbol > symbol(const std::string &name)
Inline version that returns a Symbol.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
bool is_a_Atom(const Basic &b)
Returns true if b is an atom. i.e. b.get_args returns an empty vector.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
T set_difference(T... args)
T set_intersection(T... args)