cse.cpp
1 #include <symengine/basic.h>
2 #include <symengine/add.h>
3 #include <symengine/mul.h>
4 #include <symengine/functions.h>
5 #include <symengine/visitor.h>
6 
7 #include <queue>
8 
9 namespace SymEngine
10 {
11 umap_basic_basic opt_cse(const vec_basic &exprs);
12 void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
13  const vec_basic &exprs, umap_basic_basic &opt_subs);
14 
16 {
17 
18 public:
20  value_numbers;
21  vec_basic value_number_to_value;
22  std::vector<std::set<unsigned>> arg_to_funcset;
23  std::vector<std::set<unsigned>> func_to_argset;
24 
25 public:
27  const std::vector<std::pair<RCP<const Basic>, vec_basic>> &funcs)
28  {
29  arg_to_funcset.resize(funcs.size());
30  for (unsigned func_i = 0; func_i < funcs.size(); func_i++) {
31  std::set<unsigned> func_argset;
32  for (auto &func_arg : funcs[func_i].second) {
33  unsigned arg_number = get_or_add_value_number(func_arg);
34  func_argset.insert(arg_number);
35  arg_to_funcset[arg_number].insert(func_i);
36  }
37  func_to_argset.push_back(func_argset);
38  }
39  }
40 
41  template <typename Container>
42  vec_basic get_args_in_value_order(Container &argset)
43  {
44  vec_basic v;
45  for (unsigned i : argset) {
46  v.push_back(value_number_to_value[i]);
47  }
48  return v;
49  }
50 
51  unsigned get_or_add_value_number(RCP<const Basic> value)
52  {
53  unsigned nvalues = numeric_cast<unsigned>(value_numbers.size());
54  auto ret = value_numbers.insert(std::make_pair(value, nvalues));
55  bool inserted = ret.second;
56  if (inserted) {
57  value_number_to_value.push_back(value);
58  arg_to_funcset.push_back(std::set<unsigned>());
59  return nvalues;
60  } else {
61  return ret.first->second;
62  }
63  }
64 
65  void stop_arg_tracking(unsigned func_i)
66  {
67  for (unsigned arg : func_to_argset[func_i]) {
68  arg_to_funcset[arg].erase(func_i);
69  }
70  }
71 
72  /*
73  Return a dict whose keys are function numbers. The entries of the dict
74  are the number of arguments said function has in common with `argset`.
75  Entries have at least 2 items in common.
76  */
78  get_common_arg_candidates(std::set<unsigned> &argset, unsigned min_func_i)
79  {
82  for (unsigned arg : argset) {
83  funcsets.push_back(arg_to_funcset[arg]);
84  }
85  // Sorted by size to make best use of the performance hack below.
86  std::sort(funcsets.begin(), funcsets.end(),
87  [](const std::set<unsigned> &a, const std::set<unsigned> &b) {
88  return a.size() < b.size();
89  });
90 
91  for (unsigned i = 0; i < funcsets.size(); i++) {
92  auto &funcset = funcsets[i];
93  for (unsigned func_i : funcset) {
94  if (func_i >= min_func_i) {
95  count_map[func_i] += 1;
96  }
97  }
98  }
99 
100  /*auto &largest_funcset = funcsets[funcsets.size() - 1];
101 
102  // We pick the smaller of the two containers to iterate over to
103  // reduce the number of items we have to look at.
104 
105  if (largest_funcset.size() < count_map.size()) {
106  for (unsigned func_i : largest_funcset) {
107  if (count_map[func_i] < 1) {
108  continue;
109  }
110  if (count_map.find(func_i) != count_map.end()) {
111  count_map[func_i] += 1;
112  }
113  }
114  } else {
115  for (auto &count_map_pair : count_map) {
116  unsigned func_i = count_map_pair.first;
117  if (count_map[func_i] < 1) {
118  continue;
119  }
120  if (largest_funcset.find(func_i) != largest_funcset.end()) {
121  count_map[func_i] += 1;
122  }
123  }
124  }*/
125  auto iter = count_map.begin();
126  for (; iter != count_map.end();) {
127  if (iter->second >= 2) {
128  ++iter;
129  } else {
130  count_map.erase(iter++);
131  }
132  }
133  return count_map;
134  }
135 
136  template <typename Container1, typename Container2>
138  get_subset_candidates(const Container1 &argset,
139  const Container2 &restrict_to_funcset)
140  {
141  std::vector<unsigned> indices;
142  for (auto f : restrict_to_funcset) {
143  indices.push_back(f);
144  }
145  std::sort(std::begin(indices), std::end(indices));
146  std::vector<unsigned> intersect_result;
147  for (const auto &arg : argset) {
148  std::set_intersection(indices.begin(), indices.end(),
149  arg_to_funcset[arg].begin(),
150  arg_to_funcset[arg].end(),
151  std::back_inserter(intersect_result));
152  intersect_result.swap(indices);
153  intersect_result.clear();
154  }
155  return indices;
156  }
157 
158  void update_func_argset(unsigned func_i,
159  const std::vector<unsigned> &new_args)
160  {
161  // Update a function with a new set of arguments.
162  auto &old_args = func_to_argset[func_i];
163 
164  std::set<unsigned> diff;
165  std::set_difference(old_args.begin(), old_args.end(), new_args.begin(),
166  new_args.end(), std::inserter(diff, diff.begin()));
167 
168  for (auto &deleted_arg : diff) {
169  arg_to_funcset[deleted_arg].erase(func_i);
170  }
171 
172  diff.clear();
173  std::set_difference(new_args.begin(), new_args.end(), old_args.begin(),
174  old_args.end(), std::inserter(diff, diff.begin()));
175 
176  for (auto &added_arg : diff) {
177  arg_to_funcset[added_arg].insert(func_i);
178  }
179 
180  func_to_argset[func_i].clear();
181  func_to_argset[func_i].insert(new_args.begin(), new_args.end());
182  }
183 };
184 
/// Returns the elements of `a` that do not occur in `b`, in ascending order.
///
/// `b` must be sorted in ascending order, because both ranges are handed to
/// std::set_difference.
std::vector<unsigned> set_diff(const std::set<unsigned> &a,
                               const std::vector<unsigned> &b)
{
    std::vector<unsigned> diff;
    // The difference can never hold more elements than `a`.
    diff.reserve(a.size());
    std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                        std::back_inserter(diff));
    return diff;
}
193 
/// Inserts `number` into `vec` unless it is already present, keeping `vec`
/// sorted. `vec` must be sorted in ascending order on entry.
void add_to_sorted_vec(std::vector<unsigned> &vec, unsigned number)
{
    // One binary search yields both the membership answer and the
    // insertion point.
    const auto pos = std::lower_bound(vec.begin(), vec.end(), number);
    if (pos == vec.end() or *pos != number) {
        vec.insert(pos, number);
    }
}
201 
202 void match_common_args(const std::string &func_class, const vec_basic &funcs_,
203  umap_basic_basic &opt_subs)
204 {
205  std::vector<std::pair<RCP<const Basic>, vec_basic>> funcs;
206  for (auto &b : funcs_) {
207  funcs.push_back(std::make_pair(b, b->get_args()));
208  }
209  std::sort(funcs.begin(), funcs.end(),
210  [](const std::pair<RCP<const Basic>, vec_basic> &a,
211  const std::pair<RCP<const Basic>, vec_basic> &b) {
212  return a.second.size() < b.second.size();
213  });
214 
215  auto arg_tracker = FuncArgTracker(funcs);
216 
217  std::set<unsigned> changed;
218  std::map<unsigned, unsigned> common_arg_candidates_counts;
219 
220  for (unsigned i = 0; i < funcs.size(); i++) {
221  common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
222  arg_tracker.func_to_argset[i], i + 1);
223 
224  std::deque<unsigned> common_arg_candidates;
225  for (auto it = common_arg_candidates_counts.begin();
226  it != common_arg_candidates_counts.end(); ++it) {
227  common_arg_candidates.push_back(it->first);
228  }
229 
230  // Sort the candidates in order of match size.
231  // This makes us try combining smaller matches first.
232  std::sort(common_arg_candidates.begin(), common_arg_candidates.end(),
233  [&](unsigned a, unsigned b) {
234  if (common_arg_candidates_counts[a]
235  == common_arg_candidates_counts[b]) {
236  return a < b;
237  }
238  return common_arg_candidates_counts[a]
239  < common_arg_candidates_counts[b];
240  });
241 
242  while (common_arg_candidates.size() > 0) {
243  unsigned j = common_arg_candidates.front();
244  common_arg_candidates.pop_front();
245  std::vector<unsigned> com_args;
246 
247  std::set_intersection(arg_tracker.func_to_argset[i].begin(),
248  arg_tracker.func_to_argset[i].end(),
249  arg_tracker.func_to_argset[j].begin(),
250  arg_tracker.func_to_argset[j].end(),
251  std::back_inserter(com_args));
252 
253  if (com_args.size() <= 1) {
254  // This may happen if a set of common arguments was already
255  // combined in a previous iteration.
256  continue;
257  }
258 
259  std::vector<unsigned> diff_i
260  = set_diff(arg_tracker.func_to_argset[i], com_args);
261 
262  unsigned com_func_number;
263 
264  if (diff_i.size() > 0) {
265  // com_func needs to be unevaluated to allow for recursive
266  // matches.
267  auto com_func = function_symbol(
268  func_class, arg_tracker.get_args_in_value_order(com_args));
269  com_func_number = arg_tracker.get_or_add_value_number(com_func);
270  add_to_sorted_vec(diff_i, com_func_number);
271  arg_tracker.update_func_argset(i, diff_i);
272  changed.insert(i);
273 
274  } else {
275  // Treat the whole expression as a CSE.
276  //
277  // The reason this needs to be done is somewhat subtle. Within
278  // tree_cse(), to_eliminate only contains expressions that are
279  // seen more than once. The problem is unevaluated expressions
280  // do not compare equal to the evaluated equivalent. So
281  // tree_cse() won't mark funcs[i] as a CSE if we use an
282  // unevaluated version.
283  com_func_number
284  = arg_tracker.get_or_add_value_number(funcs[i].first);
285  }
286 
287  std::vector<unsigned> diff_j
288  = set_diff(arg_tracker.func_to_argset[j], com_args);
289  add_to_sorted_vec(diff_j, com_func_number);
290  arg_tracker.update_func_argset(j, diff_j);
291  changed.insert(j);
292 
293  for (unsigned k : arg_tracker.get_subset_candidates(
294  com_args, common_arg_candidates)) {
295  std::vector<unsigned> diff_k
296  = set_diff(arg_tracker.func_to_argset[k], com_args);
297  add_to_sorted_vec(diff_k, com_func_number);
298  arg_tracker.update_func_argset(k, diff_k);
299  changed.insert(k);
300  }
301  }
302  if (std::find(changed.begin(), changed.end(), i) != changed.end()) {
303  opt_subs[funcs[i].first] = function_symbol(
304  func_class, arg_tracker.get_args_in_value_order(
305  arg_tracker.func_to_argset[i]));
306  }
307  arg_tracker.stop_arg_tracking(i);
308  }
309 }
310 
311 class OptsCSEVisitor : public BaseVisitor<OptsCSEVisitor>
312 {
313 public:
314  umap_basic_basic &opt_subs;
315  set_basic adds;
316  set_basic muls;
317  set_basic seen_subexp;
318  OptsCSEVisitor(umap_basic_basic &opt_subs_) : opt_subs(opt_subs_) {}
319  bool is_seen(const Basic &expr)
320  {
321  return (seen_subexp.find(expr.rcp_from_this()) != seen_subexp.end());
322  }
323  void bvisit(const Derivative &x)
324  {
325  return;
326  }
327  void bvisit(const Subs &x)
328  {
329  return;
330  }
331  void bvisit(const Add &x)
332  {
333  if (not is_seen(x)) {
334  seen_subexp.insert(x.rcp_from_this());
335  for (const auto &p : x.get_args()) {
336  p->accept(*this);
337  }
338  adds.insert(x.rcp_from_this());
339  }
340  }
341  void bvisit(const Pow &x)
342  {
343  if (not is_seen(x)) {
344  auto expr = x.rcp_from_this();
345  seen_subexp.insert(expr);
346  for (const auto &p : x.get_args()) {
347  p->accept(*this);
348  }
349  auto ex = x.get_exp();
350  if (is_a<Mul>(*ex)) {
351  ex = static_cast<const Mul &>(*ex).get_coef();
352  }
353  if (is_a_Number(*ex)
354  and static_cast<const Number &>(*ex).is_negative()) {
355  vec_basic v({pow(x.get_base(), neg(x.get_exp())), integer(-1)});
356  opt_subs[expr] = function_symbol("pow", v);
357  }
358  }
359  }
360  void bvisit(const Mul &x)
361  {
362  if (not is_seen(x)) {
363  auto expr = x.rcp_from_this();
364  seen_subexp.insert(expr);
365  for (const auto &p : x.get_args()) {
366  p->accept(*this);
367  }
368  if (x.get_coef()->is_negative()) {
369  auto neg_expr = neg(x.rcp_from_this());
370  if (not is_a<Symbol>(*neg_expr)) {
371  opt_subs[expr]
372  = function_symbol("mul", {integer(-1), neg_expr});
373  seen_subexp.insert(neg_expr);
374  expr = neg_expr;
375  }
376  }
377  if (is_a<Mul>(*expr)) {
378  muls.insert(expr);
379  }
380  }
381  }
382  void bvisit(const Basic &x)
383  {
384  auto v = x.get_args();
385  if (v.size() > 0 and not is_seen(x)) {
386  auto expr = x.rcp_from_this();
387  seen_subexp.insert(expr);
388  for (const auto &p : v) {
389  p->accept(*this);
390  }
391  }
392  }
393 };
394 
395 vec_basic set_as_vec(const set_basic &s)
396 {
397  vec_basic result;
398  for (auto &u : s) {
399  result.push_back(u);
400  }
401  return result;
402 }
403 
404 umap_basic_basic opt_cse(const vec_basic &exprs)
405 {
406  // Find optimization opportunities in Adds, Muls, Pows and negative
407  // coefficient Muls
408  umap_basic_basic opt_subs;
409  OptsCSEVisitor visitor(opt_subs);
410  for (auto &e : exprs) {
411  e->accept(visitor);
412  }
413 
414  match_common_args("add", set_as_vec(visitor.adds), opt_subs);
415  match_common_args("mul", set_as_vec(visitor.muls), opt_subs);
416 
417  return opt_subs;
418 }
419 
420 class RebuildVisitor : public BaseVisitor<RebuildVisitor, TransformVisitor>
421 {
422 private:
423  umap_basic_basic &subs;
424  umap_basic_basic &opt_subs;
425  set_basic &to_eliminate;
426  set_basic &excluded_symbols;
427  vec_pair &replacements;
428  unsigned next_symbol_index = 0;
429 
430 public:
431  using TransformVisitor::bvisit;
432  using TransformVisitor::result_;
434  set_basic &to_eliminate_, set_basic &excluded_symbols_,
435  vec_pair &replacements_)
436  : subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
437  excluded_symbols(excluded_symbols_), replacements(replacements_)
438  {
439  }
440  RCP<const Basic> apply(const RCP<const Basic> &orig_expr) override
441  {
442  RCP<const Basic> expr = orig_expr;
443  if (is_a_Atom(*expr)) {
444  return expr;
445  }
446 
447  auto iter = subs.find(expr);
448  if (iter != subs.end()) {
449  return iter->second;
450  }
451  auto iter2 = opt_subs.find(expr);
452  if (iter2 != opt_subs.end()) {
453  expr = iter2->second;
454  }
455  expr->accept(*this);
456  auto new_expr = result_;
457  if (to_eliminate.find(orig_expr) != to_eliminate.end()) {
458  auto sym = next_symbol();
459  subs[orig_expr] = sym;
460  replacements.push_back(
461  std::pair<RCP<const Basic>, RCP<const Basic>>(sym, new_expr));
462  return sym;
463  }
464  return new_expr;
465  }
466  RCP<const Basic> next_symbol()
467  {
468  RCP<const Basic> sym = symbol("x" + to_string(next_symbol_index));
469  next_symbol_index++;
470  if (excluded_symbols.find(sym) == excluded_symbols.end()) {
471  return sym;
472  } else {
473  return next_symbol();
474  }
475  };
476  void bvisit(const FunctionSymbol &x)
477  {
478  auto &fargs = x.get_vec();
479  vec_basic newargs;
480  for (const auto &a : fargs) {
481  newargs.push_back(apply(a));
482  }
483  if (x.get_name() == "add") {
484  result_ = add(newargs);
485  } else if (x.get_name() == "mul") {
486  result_ = mul(newargs);
487  } else if (x.get_name() == "pow") {
488  result_ = pow(newargs[0], newargs[1]);
489  } else {
490  result_ = x.create(newargs);
491  }
492  }
493 };
494 
495 void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
496  const vec_basic &exprs, umap_basic_basic &opt_subs)
497 {
498  set_basic to_eliminate;
499  set_basic seen_subexp;
500  set_basic excluded_symbols;
501 
502  std::function<void(RCP<const Basic> & expr)> find_repeated;
503  find_repeated = [&](RCP<const Basic> expr) -> void {
504  // Do not replace atoms
505  if (is_a_Number(*expr) or is_a<BooleanAtom>(*expr)) {
506  return;
507  }
508 
509  if (is_a<Symbol>(*expr)) {
510  excluded_symbols.insert(expr);
511  }
512 
513  if (seen_subexp.find(expr) != seen_subexp.end()) {
514  to_eliminate.insert(expr);
515  return;
516  }
517 
518  seen_subexp.insert(expr);
519 
520  auto iter = opt_subs.find(expr);
521  if (iter != opt_subs.end()) {
522  expr = iter->second;
523  }
524 
525  vec_basic args = expr->get_args();
526 
527  for (auto &arg : args) {
528  find_repeated(arg);
529  }
530  };
531 
532  for (auto e : exprs) {
533  find_repeated(e);
534  }
535 
536  umap_basic_basic subs;
537 
538  RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
539  excluded_symbols, replacements);
540 
541  for (auto &e : exprs) {
542  auto reduced_e = rebuild_visitor.apply(e);
543  reduced_exprs.push_back(reduced_e);
544  }
545 }
546 
547 void cse(vec_pair &replacements, vec_basic &reduced_exprs,
548  const vec_basic &exprs)
549 {
550  // Find other optimization opportunities.
551  umap_basic_basic opt_subs = opt_cse(exprs);
552 
553  // Main CSE algorithm.
554  tree_cse(replacements, reduced_exprs, exprs, opt_subs);
555 }
556 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
T back_inserter(T... args)
The base class for SymEngine.
T begin(T... args)
The base class for representing addition in symbolic expressions.
Definition: add.h:27
vec_basic get_args() const override
Returns the arguments of the Add.
Definition: add.cpp:397
The lowest unit of symbolic representation.
Definition: basic.h:97
virtual vec_basic get_args() const =0
Returns the list of arguments.
RCP< T > rcp_from_this()
Get RCP<T> pointer to self (it will cast the pointer to T)
RCP< const Basic > create(const vec_basic &x) const override
Method to construct classes with canonicalization.
Definition: functions.cpp:1905
const std::string & get_name() const
Definition: functions.h:654
vec_basic get_args() const override
Returns the list of arguments.
Definition: mul.cpp:507
virtual bool is_negative() const =0
vec_basic get_args() const override
Returns the list of arguments.
Definition: pow.cpp:266
RCP< const Basic > get_exp() const
Definition: pow.h:42
RCP< const Basic > get_base() const
Definition: pow.h:37
T clear(T... args)
T end(T... args)
T erase(T... args)
T find(T... args)
T front(T... args)
T insert(T... args)
T inserter(T... args)
T make_pair(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
RCP< const Symbol > symbol(const std::string &name)
inline version to return Symbol
Definition: symbol.h:82
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
bool is_a_Atom(const Basic &b)
Returns true if b is an atom. i.e. b.get_args returns an empty vector.
Definition: basic.cpp:95
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
Definition: mul.cpp:443
T pop_front(T... args)
T push_back(T... args)
T resize(T... args)
T set_difference(T... args)
T set_intersection(T... args)
T size(T... args)
T sort(T... args)
Our comparison (==)
Definition: basic.h:219
T swap(T... args)
T upper_bound(T... args)