Program Listing for File cse.cpp
Return to documentation for file (symengine/symengine/cse.cpp)
#include <symengine/basic.h>
#include <symengine/add.h>
#include <symengine/mul.h>
#include <symengine/functions.h>
#include <symengine/visitor.h>
#include <queue>
namespace SymEngine
{
umap_basic_basic opt_cse(const vec_basic &exprs);
void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
const vec_basic &exprs, umap_basic_basic &opt_subs);
class FuncArgTracker
{
public:
std::unordered_map<RCP<const Basic>, unsigned, RCPBasicHash, RCPBasicKeyEq>
value_numbers;
vec_basic value_number_to_value;
std::vector<std::set<unsigned>> arg_to_funcset;
std::vector<std::set<unsigned>> func_to_argset;
public:
FuncArgTracker(
const std::vector<std::pair<RCP<const Basic>, vec_basic>> &funcs)
{
arg_to_funcset.resize(funcs.size());
for (unsigned func_i = 0; func_i < funcs.size(); func_i++) {
std::set<unsigned> func_argset;
for (auto &func_arg : funcs[func_i].second) {
unsigned arg_number = get_or_add_value_number(func_arg);
func_argset.insert(arg_number);
arg_to_funcset[arg_number].insert(func_i);
}
func_to_argset.push_back(func_argset);
}
}
template <typename Container>
vec_basic get_args_in_value_order(Container &argset)
{
vec_basic v;
for (unsigned i : argset) {
v.push_back(value_number_to_value[i]);
}
return v;
}
unsigned get_or_add_value_number(RCP<const Basic> value)
{
unsigned nvalues = numeric_cast<unsigned>(value_numbers.size());
auto ret = value_numbers.insert(std::make_pair(value, nvalues));
bool inserted = ret.second;
if (inserted) {
value_number_to_value.push_back(value);
arg_to_funcset.push_back(std::set<unsigned>());
return nvalues;
} else {
return ret.first->second;
}
}
void stop_arg_tracking(unsigned func_i)
{
for (unsigned arg : func_to_argset[func_i]) {
arg_to_funcset[arg].erase(func_i);
}
}
/*
Return a dict whose keys are function numbers. The entries of the dict
are the number of arguments said function has in common with `argset`.
Entries have at least 2 items in common.
*/
std::map<unsigned, unsigned>
get_common_arg_candidates(std::set<unsigned> &argset, unsigned min_func_i)
{
std::map<unsigned, unsigned> count_map;
std::vector<std::set<unsigned>> funcsets;
for (unsigned arg : argset) {
funcsets.push_back(arg_to_funcset[arg]);
}
// Sorted by size to make best use of the performance hack below.
std::sort(funcsets.begin(), funcsets.end(),
[](const std::set<unsigned> &a, const std::set<unsigned> &b) {
return a.size() < b.size();
});
for (unsigned i = 0; i < funcsets.size(); i++) {
auto &funcset = funcsets[i];
for (unsigned func_i : funcset) {
if (func_i >= min_func_i) {
count_map[func_i] += 1;
}
}
}
/*auto &largest_funcset = funcsets[funcsets.size() - 1];
// We pick the smaller of the two containers to iterate over to
// reduce the number of items we have to look at.
if (largest_funcset.size() < count_map.size()) {
for (unsigned func_i : largest_funcset) {
if (count_map[func_i] < 1) {
continue;
}
if (count_map.find(func_i) != count_map.end()) {
count_map[func_i] += 1;
}
}
} else {
for (auto &count_map_pair : count_map) {
unsigned func_i = count_map_pair.first;
if (count_map[func_i] < 1) {
continue;
}
if (largest_funcset.find(func_i) != largest_funcset.end()) {
count_map[func_i] += 1;
}
}
}*/
auto iter = count_map.begin();
for (; iter != count_map.end();) {
if (iter->second >= 2) {
++iter;
} else {
count_map.erase(iter++);
}
}
return count_map;
}
template <typename Container1, typename Container2>
std::vector<unsigned>
get_subset_candidates(const Container1 &argset,
const Container2 &restrict_to_funcset)
{
std::vector<unsigned> indices;
for (auto f : restrict_to_funcset) {
indices.push_back(f);
}
std::sort(std::begin(indices), std::end(indices));
std::vector<unsigned> intersect_result;
for (const auto &arg : argset) {
std::set_intersection(indices.begin(), indices.end(),
arg_to_funcset[arg].begin(),
arg_to_funcset[arg].end(),
std::back_inserter(intersect_result));
intersect_result.swap(indices);
intersect_result.clear();
}
return indices;
}
void update_func_argset(unsigned func_i,
const std::vector<unsigned> &new_args)
{
// Update a function with a new set of arguments.
auto &old_args = func_to_argset[func_i];
std::set<unsigned> diff;
std::set_difference(old_args.begin(), old_args.end(), new_args.begin(),
new_args.end(), std::inserter(diff, diff.begin()));
for (auto &deleted_arg : diff) {
arg_to_funcset[deleted_arg].erase(func_i);
}
diff.clear();
std::set_difference(new_args.begin(), new_args.end(), old_args.begin(),
old_args.end(), std::inserter(diff, diff.begin()));
for (auto &added_arg : diff) {
arg_to_funcset[added_arg].insert(func_i);
}
func_to_argset[func_i].clear();
func_to_argset[func_i].insert(new_args.begin(), new_args.end());
}
};
// Return the elements of `a` that are not in `b`, in ascending order.
// Precondition: `b` is sorted ascending (std::set_difference requires
// sorted input ranges; `a` is sorted by construction).
std::vector<unsigned> set_diff(const std::set<unsigned> &a,
                               const std::vector<unsigned> &b)
{
    std::vector<unsigned> diff;
    // The result can never be larger than `a`.
    diff.reserve(a.size());
    std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                        std::back_inserter(diff));
    return diff;
}
// Insert `number` into the ascending, duplicate-free vector `vec`, keeping
// it sorted. Does nothing if `number` is already present.
void add_to_sorted_vec(std::vector<unsigned> &vec, unsigned number)
{
    // A single binary search gives both the membership test and the
    // insertion point, replacing the former linear std::find.
    auto pos = std::lower_bound(vec.begin(), vec.end(), number);
    if (pos == vec.end() or *pos != number) {
        vec.insert(pos, number);
    }
}
void match_common_args(const std::string &func_class, const vec_basic &funcs_,
umap_basic_basic &opt_subs)
{
std::vector<std::pair<RCP<const Basic>, vec_basic>> funcs;
for (auto &b : funcs_) {
funcs.push_back(std::make_pair(b, b->get_args()));
}
std::sort(funcs.begin(), funcs.end(),
[](const std::pair<RCP<const Basic>, vec_basic> &a,
const std::pair<RCP<const Basic>, vec_basic> &b) {
return a.second.size() < b.second.size();
});
auto arg_tracker = FuncArgTracker(funcs);
std::set<unsigned> changed;
std::map<unsigned, unsigned> common_arg_candidates_counts;
for (unsigned i = 0; i < funcs.size(); i++) {
common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
arg_tracker.func_to_argset[i], i + 1);
std::deque<unsigned> common_arg_candidates;
for (auto it = common_arg_candidates_counts.begin();
it != common_arg_candidates_counts.end(); ++it) {
common_arg_candidates.push_back(it->first);
}
// Sort the candidates in order of match size.
// This makes us try combining smaller matches first.
std::sort(common_arg_candidates.begin(), common_arg_candidates.end(),
[&](unsigned a, unsigned b) {
if (common_arg_candidates_counts[a]
== common_arg_candidates_counts[b]) {
return a < b;
}
return common_arg_candidates_counts[a]
< common_arg_candidates_counts[b];
});
while (common_arg_candidates.size() > 0) {
unsigned j = common_arg_candidates.front();
common_arg_candidates.pop_front();
std::vector<unsigned> com_args;
std::set_intersection(arg_tracker.func_to_argset[i].begin(),
arg_tracker.func_to_argset[i].end(),
arg_tracker.func_to_argset[j].begin(),
arg_tracker.func_to_argset[j].end(),
std::back_inserter(com_args));
if (com_args.size() <= 1) {
// This may happen if a set of common arguments was already
// combined in a previous iteration.
continue;
}
std::vector<unsigned> diff_i
= set_diff(arg_tracker.func_to_argset[i], com_args);
unsigned com_func_number;
if (diff_i.size() > 0) {
// com_func needs to be unevaluated to allow for recursive
// matches.
auto com_func = function_symbol(
func_class, arg_tracker.get_args_in_value_order(com_args));
com_func_number = arg_tracker.get_or_add_value_number(com_func);
add_to_sorted_vec(diff_i, com_func_number);
arg_tracker.update_func_argset(i, diff_i);
changed.insert(i);
} else {
// Treat the whole expression as a CSE.
//
// The reason this needs to be done is somewhat subtle. Within
// tree_cse(), to_eliminate only contains expressions that are
// seen more than once. The problem is unevaluated expressions
// do not compare equal to the evaluated equivalent. So
// tree_cse() won't mark funcs[i] as a CSE if we use an
// unevaluated version.
com_func_number
= arg_tracker.get_or_add_value_number(funcs[i].first);
}
std::vector<unsigned> diff_j
= set_diff(arg_tracker.func_to_argset[j], com_args);
add_to_sorted_vec(diff_j, com_func_number);
arg_tracker.update_func_argset(j, diff_j);
changed.insert(j);
for (unsigned k : arg_tracker.get_subset_candidates(
com_args, common_arg_candidates)) {
std::vector<unsigned> diff_k
= set_diff(arg_tracker.func_to_argset[k], com_args);
add_to_sorted_vec(diff_k, com_func_number);
arg_tracker.update_func_argset(k, diff_k);
changed.insert(k);
}
}
if (std::find(changed.begin(), changed.end(), i) != changed.end()) {
opt_subs[funcs[i].first] = function_symbol(
func_class, arg_tracker.get_args_in_value_order(
arg_tracker.func_to_argset[i]));
}
arg_tracker.stop_arg_tracking(i);
}
}
// Pre-pass for CSE: walks the input expressions once, visiting each distinct
// subexpression at most once (tracked in `seen_subexp`).
// - Add nodes are collected into `adds`.
// - Mul nodes are collected into `muls`, after sign normalisation.
// - Rewrite hints are recorded in `opt_subs` for Pow nodes whose exponent
//   has a negative numeric (coefficient) part, and for Mul nodes with a
//   negative coefficient.
// Derivative and Subs nodes are not descended into.
class OptsCSEVisitor : public BaseVisitor<OptsCSEVisitor>
{
public:
    umap_basic_basic &opt_subs;
    set_basic adds;
    set_basic muls;
    set_basic seen_subexp;

    OptsCSEVisitor(umap_basic_basic &opt_subs_) : opt_subs(opt_subs_) {}

    bool is_seen(const Basic &expr)
    {
        return seen_subexp.count(expr.rcp_from_this()) > 0;
    }

    // Intentionally opaque: neither node type is traversed.
    void bvisit(const Derivative &) {}
    void bvisit(const Subs &) {}

    void bvisit(const Add &x)
    {
        if (is_seen(x)) {
            return;
        }
        RCP<const Basic> self = x.rcp_from_this();
        seen_subexp.insert(self);
        for (const auto &child : x.get_args()) {
            child->accept(*this);
        }
        adds.insert(self);
    }

    void bvisit(const Pow &x)
    {
        if (is_seen(x)) {
            return;
        }
        RCP<const Basic> self = x.rcp_from_this();
        seen_subexp.insert(self);
        for (const auto &child : x.get_args()) {
            child->accept(*this);
        }
        // The sign of the exponent is read from its numeric coefficient when
        // the exponent is a product, otherwise from the exponent itself.
        RCP<const Basic> sign_carrier = x.get_exp();
        if (is_a<Mul>(*sign_carrier)) {
            sign_carrier = static_cast<const Mul &>(*sign_carrier).get_coef();
        }
        const bool has_negative_exp
            = is_a_Number(*sign_carrier)
              and static_cast<const Number &>(*sign_carrier).is_negative();
        if (has_negative_exp) {
            // Record b**e as the unevaluated pow(b**(-e), -1).
            vec_basic pow_args{pow(x.get_base(), neg(x.get_exp())),
                               integer(-1)};
            opt_subs[self] = function_symbol("pow", pow_args);
        }
    }

    void bvisit(const Mul &x)
    {
        if (is_seen(x)) {
            return;
        }
        RCP<const Basic> candidate = x.rcp_from_this();
        seen_subexp.insert(candidate);
        for (const auto &child : x.get_args()) {
            child->accept(*this);
        }
        if (x.get_coef()->is_negative()) {
            RCP<const Basic> flipped = neg(x.rcp_from_this());
            if (not is_a<Symbol>(*flipped)) {
                // Record -m as the unevaluated mul(-1, m) and continue with
                // the positive form m.
                vec_basic mul_args{integer(-1), flipped};
                opt_subs[candidate] = function_symbol("mul", mul_args);
                seen_subexp.insert(flipped);
                candidate = flipped;
            }
        }
        if (is_a<Mul>(*candidate)) {
            muls.insert(candidate);
        }
    }

    void bvisit(const Basic &x)
    {
        const vec_basic children = x.get_args();
        if (children.empty() or is_seen(x)) {
            return;
        }
        seen_subexp.insert(x.rcp_from_this());
        for (const auto &child : children) {
            child->accept(*this);
        }
    }
};
// Copy the elements of `s` into a vector, preserving the set's ordering.
vec_basic set_as_vec(const set_basic &s)
{
    return vec_basic(s.begin(), s.end());
}
// Collect rewrite hints for the main CSE pass. These cover Pow nodes with
// negative exponents, Mul nodes with negative coefficients, and argument
// subsets shared between Adds and between Muls. Returns a map from an
// original subexpression to its unevaluated replacement.
umap_basic_basic opt_cse(const vec_basic &exprs)
{
    umap_basic_basic opt_subs;
    OptsCSEVisitor finder(opt_subs);
    for (const auto &expr : exprs) {
        expr->accept(finder);
    }
    // Factor shared arguments out of the Adds first, then the Muls.
    const vec_basic add_nodes = set_as_vec(finder.adds);
    match_common_args("add", add_nodes, opt_subs);
    const vec_basic mul_nodes = set_as_vec(finder.muls);
    match_common_args("mul", mul_nodes, opt_subs);
    return opt_subs;
}
// Second pass of tree_cse(): rebuilds each expression bottom-up.
// - Hints from `opt_subs` are applied.
// - Every subexpression listed in `to_eliminate` is replaced by a fresh
//   symbol. A (symbol, rebuilt expression) pair is appended to
//   `replacements` and the mapping is cached in `subs`.
// Fresh symbols are named x0, x1, ... and skip any name in
// `excluded_symbols`.
class RebuildVisitor : public BaseVisitor<RebuildVisitor, TransformVisitor>
{
private:
    umap_basic_basic &subs;
    umap_basic_basic &opt_subs;
    set_basic &to_eliminate;
    set_basic &excluded_symbols;
    vec_pair &replacements;
    unsigned next_symbol_index = 0;

public:
    using TransformVisitor::result_;
    using TransformVisitor::bvisit;

    RebuildVisitor(umap_basic_basic &subs_, umap_basic_basic &opt_subs_,
                   set_basic &to_eliminate_, set_basic &excluded_symbols_,
                   vec_pair &replacements_)
        : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
          excluded_symbols(excluded_symbols_), replacements(replacements_)
    {
    }

    // Rebuild `orig_expr`. Atoms are returned unchanged. Already-eliminated
    // subexpressions return their symbol from `subs`.
    virtual RCP<const Basic> apply(const RCP<const Basic> &orig_expr)
    {
        if (is_a_Atom(*orig_expr)) {
            return orig_expr;
        }
        auto iter = subs.find(orig_expr);
        if (iter != subs.end()) {
            return iter->second;
        }
        RCP<const Basic> expr = orig_expr;
        auto iter2 = opt_subs.find(expr);
        if (iter2 != opt_subs.end()) {
            expr = iter2->second;
        }
        expr->accept(*this);
        RCP<const Basic> new_expr = result_;
        // Elimination is decided on the original expression, not on its
        // opt_subs replacement.
        if (to_eliminate.find(orig_expr) != to_eliminate.end()) {
            RCP<const Basic> sym = next_symbol();
            subs[orig_expr] = sym;
            replacements.push_back(
                std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
            return sym;
        }
        return new_expr;
    }

    // Return the next symbol "x<N>" that is not in excluded_symbols.
    // Iterative, so a long run of excluded names cannot exhaust the stack.
    RCP<const Basic> next_symbol()
    {
        for (;;) {
            RCP<const Basic> sym = symbol("x" + to_string(next_symbol_index));
            next_symbol_index++;
            if (excluded_symbols.find(sym) == excluded_symbols.end()) {
                return sym;
            }
        }
    }

    // The unevaluated "add"/"mul"/"pow" markers produced by opt_cse() are
    // evaluated here. Any other FunctionSymbol is recreated with rebuilt
    // arguments.
    void bvisit(const FunctionSymbol &x)
    {
        const auto &fargs = x.get_vec();
        vec_basic newargs;
        newargs.reserve(fargs.size());
        for (const auto &a : fargs) {
            newargs.push_back(apply(a));
        }
        const auto &name = x.get_name();
        if (name == "add") {
            result_ = add(newargs);
        } else if (name == "mul") {
            result_ = mul(newargs);
        } else if (name == "pow" and newargs.size() == 2) {
            // The arity check guards against a user-defined "pow" function
            // with a different number of arguments, which previously caused
            // out-of-bounds indexing.
            result_ = pow(newargs[0], newargs[1]);
        } else {
            result_ = x.create(newargs);
        }
    }
};
// Main CSE pass.
// 1. Walk every expression, following `opt_subs` rewrites. Mark each
//    non-Number subexpression that is reached more than once for
//    elimination. Record every Symbol seen, so fresh names avoid them.
// 2. Rebuild each expression with RebuildVisitor. Replacements are appended
//    to `replacements`; rebuilt expressions are appended to `reduced_exprs`.
void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
              const vec_basic &exprs, umap_basic_basic &opt_subs)
{
    set_basic to_eliminate;
    set_basic seen_subexp;
    set_basic excluded_symbols;

    // The declared signature now matches the lambda (const reference). The
    // previous mutable-reference declaration suggested the callee could
    // modify its argument, while the lambda silently copied it.
    std::function<void(const RCP<const Basic> &)> find_repeated;
    find_repeated = [&](const RCP<const Basic> &expr) -> void {
        if (is_a_Number(*expr)) {
            return;
        }
        if (is_a<Symbol>(*expr)) {
            excluded_symbols.insert(expr);
        }
        if (seen_subexp.find(expr) != seen_subexp.end()) {
            to_eliminate.insert(expr);
            return;
        }
        seen_subexp.insert(expr);
        // Descend into the opt_subs replacement when there is one, otherwise
        // into the expression itself.
        const auto iter = opt_subs.find(expr);
        const vec_basic args = (iter != opt_subs.end())
                                   ? iter->second->get_args()
                                   : expr->get_args();
        for (const auto &arg : args) {
            find_repeated(arg);
        }
    };
    for (const auto &e : exprs) {
        find_repeated(e);
    }

    umap_basic_basic subs;
    RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
                                   excluded_symbols, replacements);
    for (const auto &e : exprs) {
        reduced_exprs.push_back(rebuild_visitor.apply(e));
    }
}
// Common subexpression elimination over `exprs`.
// On return, `replacements` holds the (symbol, subexpression) definitions
// and `reduced_exprs` holds the inputs rewritten in terms of those symbols.
void cse(vec_pair &replacements, vec_basic &reduced_exprs,
         const vec_basic &exprs)
{
    // Phase 1: gather rewrite hints; phase 2: eliminate repeats.
    umap_basic_basic hints = opt_cse(exprs);
    tree_cse(replacements, reduced_exprs, exprs, hints);
}
}