11umap_basic_basic opt_cse(
const vec_basic &
exprs);
12void tree_cse(vec_pair &replacements, vec_basic &
reduced_exprs,
13 const vec_basic &
exprs, umap_basic_basic &opt_subs);
41 template <
typename Container>
45 for (
unsigned i :
argset) {
51 unsigned get_or_add_value_number(RCP<const Basic> value)
61 return ret.first->second;
65 void stop_arg_tracking(
unsigned func_i)
67 for (
unsigned arg : func_to_argset[
func_i]) {
88 return a.size() < b.size();
91 for (
unsigned i = 0; i <
funcsets.size(); i++) {
127 if (
iter->second >= 2) {
136 template <
typename Container1,
typename Container2>
150 arg_to_funcset[
arg].
end(),
158 void update_func_argset(
unsigned func_i,
203 umap_basic_basic &opt_subs)
210 [](
const std::pair<RCP<const Basic>, vec_basic> &a,
211 const std::pair<RCP<const Basic>, vec_basic> &b) {
212 return a.second.size() < b.second.size();
220 for (
unsigned i = 0; i <
funcs.size(); i++) {
233 [&](
unsigned a,
unsigned b) {
234 if (common_arg_candidates_counts[a]
235 == common_arg_candidates_counts[b]) {
293 for (
unsigned k :
arg_tracker.get_subset_candidates(
303 opt_subs[
funcs[i].first] = function_symbol(
319 bool is_seen(
const Basic &expr)
327 void bvisit(
const Subs &x)
331 void bvisit(
const Add &x)
333 if (not is_seen(x)) {
335 for (
const auto &p : x.
get_args()) {
341 void bvisit(
const Pow &x)
343 if (not is_seen(x)) {
346 for (
const auto &p : x.
get_args()) {
350 if (is_a<Mul>(*ex)) {
351 ex =
static_cast<const Mul &
>(*ex).get_coef();
356 opt_subs[expr] = function_symbol(
"pow", v);
360 void bvisit(
const Mul &x)
362 if (not is_seen(x)) {
365 for (
const auto &p : x.
get_args()) {
368 if (x.get_coef()->is_negative()) {
370 if (not is_a<Symbol>(*neg_expr)) {
372 = function_symbol(
"mul", {
integer(-1), neg_expr});
373 seen_subexp.
insert(neg_expr);
377 if (is_a<Mul>(*expr)) {
382 void bvisit(
const Basic &x)
385 if (v.size() > 0 and not is_seen(x)) {
388 for (
const auto &p : v) {
404umap_basic_basic opt_cse(
const vec_basic &exprs)
408 umap_basic_basic opt_subs;
409 OptsCSEVisitor visitor(opt_subs);
410 for (
auto &e : exprs) {
414 match_common_args(
"add", set_as_vec(visitor.adds), opt_subs);
415 match_common_args(
"mul", set_as_vec(visitor.muls), opt_subs);
428 unsigned next_symbol_index = 0;
431 using TransformVisitor::bvisit;
432 using TransformVisitor::result_;
436 : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
437 excluded_symbols(excluded_symbols_), replacements(replacements_)
440 RCP<const Basic> apply(
const RCP<const Basic> &orig_expr)
override
442 RCP<const Basic> expr = orig_expr;
447 auto iter = subs.find(expr);
448 if (iter != subs.end()) {
451 auto iter2 = opt_subs.
find(expr);
452 if (iter2 != opt_subs.
end()) {
453 expr = iter2->second;
456 auto new_expr = result_;
457 if (to_eliminate.
find(orig_expr) != to_eliminate.
end()) {
458 auto sym = next_symbol();
459 subs[orig_expr] = sym;
461 std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
466 RCP<const Basic> next_symbol()
468 RCP<const Basic> sym =
symbol(
"x" + to_string(next_symbol_index));
470 if (excluded_symbols.
find(sym) == excluded_symbols.
end()) {
473 return next_symbol();
478 auto &fargs = x.get_vec();
480 for (
const auto &a : fargs) {
484 result_ =
add(newargs);
486 result_ =
mul(newargs);
488 result_ = pow(newargs[0], newargs[1]);
490 result_ = x.
create(newargs);
503 find_repeated = [&](RCP<const Basic> expr) ->
void {
505 if (is_a_Number(*expr) or is_a<BooleanAtom>(*expr)) {
509 if (is_a<Symbol>(*expr)) {
510 excluded_symbols.
insert(expr);
513 if (seen_subexp.
find(expr) != seen_subexp.
end()) {
514 to_eliminate.
insert(expr);
520 auto iter = opt_subs.
find(expr);
521 if (iter != opt_subs.
end()) {
525 vec_basic args = expr->get_args();
527 for (
auto &arg : args) {
532 for (
auto e : exprs) {
536 umap_basic_basic subs;
538 RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
539 excluded_symbols, replacements);
541 for (
auto &e : exprs) {
542 auto reduced_e = rebuild_visitor.apply(e);
547void cse(vec_pair &replacements, vec_basic &reduced_exprs,
548 const vec_basic &exprs)
551 umap_basic_basic opt_subs = opt_cse(exprs);
554 tree_cse(replacements, reduced_exprs, exprs, opt_subs);
Classes and functions relating to the binary operation of addition.
T back_inserter(T... args)
The base class for SymEngine.
The base class for representing addition in symbolic expressions.
vec_basic get_args() const override
Returns the arguments of the Add.
The lowest unit of symbolic representation.
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
const std::string & get_name() const
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
vec_basic get_args() const override
Returns the list of arguments.
virtual bool is_negative() const =0
RCP< const Basic > get_base() const
vec_basic get_args() const override
Returns the list of arguments.
RCP< const Basic > get_exp() const
Main namespace for SymEngine package.
void hash_combine(hash_t &seed, const T &v)
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
bool is_a_Atom(const Basic &b)
Returns true if b is an atom. i.e. b.get_args returns an empty vector.
RCP< const Symbol > symbol(const std::string &name)
inline version to return Symbol
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
T set_difference(T... args)
T set_intersection(T... args)