refine.cpp
1 #include <symengine/refine.h>
2 
3 namespace SymEngine
4 {
5 
6 void RefineVisitor::bvisit(const Abs &x)
7 {
8  auto farg = x.get_arg();
9  auto newarg = apply(farg);
10  if (is_true(is_nonnegative(*newarg, assumptions_))) {
11  result_ = newarg;
12  } else if (is_true(is_nonpositive(*newarg, assumptions_))) {
13  result_ = neg(newarg);
14  } else if (is_a<Conjugate>(*newarg)) {
15  result_ = abs(down_cast<const Conjugate &>(*newarg).get_arg());
16  } else {
17  result_ = abs(newarg);
18  }
19 }
20 
21 void RefineVisitor::bvisit(const Sign &x)
22 {
23  auto farg = x.get_arg();
24  auto newarg = apply(farg);
25  if (is_true(is_positive(*newarg, assumptions_))) {
26  result_ = integer(1);
27  } else if (is_true(is_negative(*newarg, assumptions_))) {
28  result_ = integer(-1);
29  } else if (is_true(is_zero(*newarg, assumptions_))) {
30  result_ = integer(0);
31  } else {
32  result_ = sign(newarg);
33  }
34 }
35 
36 void RefineVisitor::bvisit(const Floor &x)
37 {
38  auto farg = x.get_arg();
39  auto newarg = apply(farg);
40  if (is_true(is_integer(*newarg, assumptions_))) {
41  result_ = newarg;
42  } else if (could_extract_minus(*newarg)) {
43  result_ = neg(ceiling(neg(newarg)));
44  } else {
45  result_ = floor(newarg);
46  }
47 }
48 
49 void RefineVisitor::bvisit(const Ceiling &x)
50 {
51  auto farg = x.get_arg();
52  auto newarg = apply(farg);
53  if (is_true(is_integer(*newarg, assumptions_))) {
54  result_ = newarg;
55  } else if (could_extract_minus(*newarg)) {
56  result_ = neg(floor(neg(newarg)));
57  } else {
58  result_ = ceiling(newarg);
59  }
60 }
61 
62 void RefineVisitor::bvisit(const Conjugate &x)
63 {
64  auto farg = x.get_arg();
65  auto newarg = apply(farg);
66  if (is_true(is_real(*newarg, assumptions_))) {
67  result_ = newarg;
68  } else {
69  result_ = conjugate(newarg);
70  }
71 }
72 
73 void RefineVisitor::bvisit(const Max &x)
74 {
75  // positive > nonpositive
76  // nonnegative and positive > negative
77  vec_basic nonpositive;
78  vec_basic negative;
79  vec_basic keep;
80  bool have_positive = false;
81  bool have_nonnegative = false;
82  for (auto arg : x.get_args()) {
83  auto newarg = apply(arg);
84  if (is_true(is_positive(*newarg, assumptions_))) {
85  keep.push_back(newarg);
86  have_positive = true;
87  } else if (is_true(is_nonnegative(*newarg, assumptions_))) {
88  keep.push_back(newarg);
89  have_nonnegative = true;
90  } else if (is_true(is_negative(*newarg, assumptions_))) {
91  negative.push_back(newarg);
92  } else if (is_true(is_nonpositive(*newarg, assumptions_))) {
93  nonpositive.push_back(newarg);
94  } else {
95  keep.push_back(newarg);
96  }
97  }
98  if (not have_positive and not nonpositive.empty()) {
99  std::copy(nonpositive.begin(), nonpositive.end(),
100  std::back_inserter(keep));
101  }
102  if (not have_nonnegative and not have_positive and not negative.empty()) {
103  std::copy(negative.begin(), negative.end(), std::back_inserter(keep));
104  }
105 
106  result_ = max(keep);
107 }
108 
109 void RefineVisitor::bvisit(const Min &x)
110 {
111  // negative < nonnegative
112  // nonpositive and negative < positive
113  vec_basic nonnegative;
114  vec_basic positive;
115  vec_basic keep;
116  bool have_negative = false;
117  bool have_nonpositive = false;
118  for (auto arg : x.get_args()) {
119  auto newarg = apply(arg);
120  if (is_true(is_negative(*newarg, assumptions_))) {
121  keep.push_back(newarg);
122  have_negative = true;
123  } else if (is_true(is_nonpositive(*newarg, assumptions_))) {
124  keep.push_back(newarg);
125  have_nonpositive = true;
126  } else if (is_true(is_positive(*newarg, assumptions_))) {
127  positive.push_back(newarg);
128  } else if (is_true(is_nonnegative(*newarg, assumptions_))) {
129  nonnegative.push_back(newarg);
130  } else {
131  keep.push_back(newarg);
132  }
133  }
134  if (not have_negative and not nonnegative.empty()) {
135  std::copy(nonnegative.begin(), nonnegative.end(),
136  std::back_inserter(keep));
137  }
138  if (not have_nonpositive and not have_negative and not positive.empty()) {
139  std::copy(positive.begin(), positive.end(), std::back_inserter(keep));
140  }
141 
142  result_ = min(keep);
143 }
144 
145 void RefineVisitor::bvisit(const Pow &x)
146 {
147  auto exp = x.get_exp();
148  auto newexp = apply(exp);
149  auto base = x.get_base();
150  auto newbase = apply(base);
151  // Handle cases when (x**k)**n = x**(k*n) or = abs(x)**(k*n)
152  if (is_a<Pow>(*newbase) and is_a_Number(*newexp)) {
153  const Pow &inner_pow = down_cast<const Pow &>(*newbase);
154  auto inner_exp = inner_pow.get_exp();
155  auto inner_base = inner_pow.get_base();
156  if (is_true(is_real(*inner_base, assumptions_))
157  and is_a_Number(*inner_exp)
158  and not down_cast<const Number &>(*inner_exp).is_complex()
159  and not down_cast<const Number &>(*newexp).is_complex()) {
160  if (is_true(is_positive(*inner_base, assumptions_))) {
161  result_ = pow(inner_base, mul(newexp, inner_exp));
162  } else {
163  result_ = pow(abs(inner_base), mul(newexp, inner_exp));
164  }
165  return;
166  }
167  }
168  result_ = pow(newbase, newexp);
169 }
170 
171 void RefineVisitor::bvisit(const Log &x)
172 {
173  auto farg = x.get_arg();
174  auto newarg = apply(farg);
175  if (is_a<Pow>(*newarg)) {
176  auto base = down_cast<const Pow &>(*newarg).get_base();
177  if (is_true(is_positive(*base, assumptions_))) {
178  auto exp = down_cast<const Pow &>(*newarg).get_exp();
179  if (is_true(is_real(*exp, assumptions_))) {
180  result_ = mul(exp, log(base));
181  return;
182  }
183  }
184  } else if (is_a<Integer>(*newarg)) {
185  auto base_exp = mp_perfect_power_decomposition(
186  down_cast<const Integer &>(*newarg).as_integer_class());
187  if (base_exp.second != 1) {
188  result_ = mul(make_rcp<const Integer>(base_exp.second),
189  log(make_rcp<const Integer>(base_exp.first)));
190  return;
191  }
192  }
193  result_ = log(newarg);
194 }
195 
196 void RefineVisitor::bvisit(const Interval &x)
197 {
198  if (eq(*x.get_start(), *SymEngine::infty(-1))
199  and eq(*x.get_end(), *SymEngine::infty(1))) {
200  result_ = reals();
201  return;
202  }
203  result_ = x.rcp_from_this();
204 }
205 
206 RCP<const Basic> refine(const RCP<const Basic> &x,
207  const Assumptions *assumptions)
208 {
209  RefineVisitor b(assumptions);
210  return b.apply(x);
211 }
212 
213 } // namespace SymEngine
T back_inserter(T... args)
T copy(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
bool is_a_Number(const Basic &b)
Definition: number.h:130
std::enable_if< std::is_integral< T >::value, RCP< const Integer > >::type integer(T i)
Definition: integer.h:197
RCP< const Reals > reals()
Definition: sets.h:560
RCP< const Basic > max(const vec_basic &arg)
Canonicalize Max:
Definition: functions.cpp:3555
bool eq(const Basic &a, const Basic &b)
Checks equality for a and b
Definition: basic-inl.h:21
RCP< const Basic > sign(const RCP< const Basic > &arg)
Canonicalize Sign.
Definition: functions.cpp:527
RCP< const Basic > ceiling(const RCP< const Basic > &arg)
Canonicalize Ceiling:
Definition: functions.cpp:705
RCP< const Basic > abs(const RCP< const Basic > &arg)
Canonicalize Abs:
Definition: functions.cpp:3492
RCP< const Basic > exp(const RCP< const Basic > &x)
Returns the natural exponential function E**x = pow(E, x)
Definition: pow.cpp:271
bool could_extract_minus(const Basic &arg)
Definition: functions.cpp:325
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
RCP< const Basic > floor(const RCP< const Basic > &arg)
Canonicalize Floor:
Definition: functions.cpp:611
std::pair< integer_class, integer_class > mp_perfect_power_decomposition(const integer_class &n, bool lowest_exponent)
Decompose a positive integer into perfect powers.
Definition: ntheory.cpp:1677
RCP< const Basic > log(const RCP< const Basic > &arg)
Returns the Natural Logarithm from argument arg
Definition: functions.cpp:1774
RCP< const Basic > min(const vec_basic &arg)
Canonicalize Min:
Definition: functions.cpp:3659
RCP< const Basic > neg(const RCP< const Basic > &a)
Negation.
Definition: mul.cpp:443
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.
Definition: functions.cpp:149
T pow(T... args)