2 #include <symengine/matrix.h>
7 #include <symengine/symengine_exception.h>
9 #include <symengine/test_visitors.h>
// Default constructor; empty body and no member initializer list.
14 CSRMatrix::CSRMatrix() {}
// Constructor from dimensions only. is_canonical() is asserted at the end and
// requires p_ to hold row + 1 entries, so the lines between the initializer
// list and the assertion (not shown in this view) must set p_ up accordingly.
16 CSRMatrix::CSRMatrix(
unsigned row,
unsigned col) : row_(row), col_(col)
19 SYMENGINE_ASSERT(is_canonical());
// Constructor from the three CSR arrays plus dimensions: p_, j_ and x_ are
// brace-initialized from p, j and x, then canonical form is asserted.
24 : p_{p}, j_{j}, x_{x}, row_(row), col_(col)
26 SYMENGINE_ASSERT(is_canonical());
// Tail of a further constructor; it also finishes by asserting canonical form.
33 SYMENGINE_ASSERT(is_canonical());
// Move-assignment from another CSRMatrix (rvalue-reference parameter).
36 CSRMatrix &CSRMatrix::operator=(CSRMatrix &&other)
// Const accessor. Its return type is declared on a line not shown here;
// presumably it exposes the CSR storage (p_, j_, x_) — TODO confirm.
47 CSRMatrix::as_vectors()
const
// Equality test against an arbitrary MatrixBase.
// A shape mismatch is checked first. When `other` is also a CSRMatrix the CSR
// arrays are compared directly; any other matrix type falls back to the
// generic MatrixBase::eq.
// NOTE(review): the direct path compares storage, so it relies on both
// operands storing exactly the same sparsity pattern (canonical form, and no
// explicitly stored zeros in only one of them). Confirm all producers of
// CSRMatrix guarantee that.
54 bool CSRMatrix::eq(
const MatrixBase &other)
const
56 unsigned row = this->nrows();
57 if (row != other.nrows() or this->ncols() != other.ncols())
60 if (is_a<CSRMatrix>(other)) {
61 const CSRMatrix &o = down_cast<const CSRMatrix &>(other);
// p_[row] is the stored-entry count (is_canonical() ties it to j_.size()),
// so comparing it first is a cheap reject before the element-wise loops.
63 if (this->p_[row] != o.p_[row])
// All row_ + 1 row pointers must match, so each row stores the same number
// of entries at the same offsets.
66 for (
unsigned i = 0; i <= row; i++)
67 if (this->p_[i] != o.p_[i])
// With identical row pointers, compare each stored column index and value;
// values are compared structurally through neq on the pointees.
70 for (
unsigned i = 0; i < this->p_[row]; i++)
71 if ((this->j_[i] != o.j_[i]) or
neq(*this->x_[i], *(o.x_[i])))
76 return this->MatrixBase::eq(other);
// Checks the CSR invariants: p_ has row_ + 1 entries, j_ and x_ each hold
// exactly p_[row_] entries, and csr_has_canonical_format accepts the row
// pointers and per-row column indices.
80 bool CSRMatrix::is_canonical()
const
// The short-circuiting `or` ensures p_[row_] is only read once p_ is known to
// have row_ + 1 entries.
82 if (p_.size() != row_ + 1 or j_.size() != p_[row_] or x_.size() != p_[row_])
86 return csr_has_canonical_format(p_, j_, row_);
// Returns entry (i, j). Row i's entries live in the half-open slice
// [p_[i], p_[i + 1]) of j_/x_. An empty slice is handled on its own path
// (presumably returning zero — TODO confirm). Otherwise column j is located
// by binary search, which relies on the row's column indices being sorted.
91 RCP<const Basic> CSRMatrix::get(
unsigned i,
unsigned j)
const
93 SYMENGINE_ASSERT(i < row_ and j < col_);
95 unsigned row_start = p_[i];
96 unsigned row_end = p_[i + 1];
99 if (row_start == row_end) {
// Half-open binary search over [row_start, row_end).
103 while (row_start < row_end) {
104 k = (row_start + row_end) / 2;
107 }
else if (j_[k] < j) {
// Sets entry (i, j) to e while keeping the CSR arrays sorted.
// A binary search over row i's slice finds the position k for column j. Then:
//  - if column j is already stored at k, that slot is updated;
//  - otherwise e and j are inserted at k in x_/j_, and the row pointers of
//    all later rows are adjusted (presumably incremented).
// The second group of branches erases an existing (i, j) entry from x_/j_
// and adjusts the later row pointers. Presumably this is the path taken when
// e is zero, so zeros are not stored (TODO confirm).
117 void CSRMatrix::set(
unsigned i,
unsigned j,
const RCP<const Basic> &e)
119 SYMENGINE_ASSERT(i < row_ and j < col_);
122 unsigned row_end = p_[i + 1];
// `end` is the mutable upper bound of the search; row_end keeps the
// original end of row i for the existence checks below.
123 unsigned end = p_[i + 1];
133 }
// NOTE(review): reads j_[mid - 1]. Presumably a preceding branch (not shown)
// guarantees mid is above the row start here — confirm.
else if (j_[mid] >= j and j_[mid - 1] < j) {
136 }
else if (j_[mid - 1] >= j) {
144 if (k < row_end and j_[k] == j) {
147 x_.insert(x_.begin() + k, e);
148 j_.insert(j_.begin() + k, j);
// Every later row now starts one slot further on.
149 for (
unsigned l = i + 1; l <= row_; l++)
153 if (k < row_end and j_[k] == j) {
154 x_.erase(x_.begin() + k);
155 j_.erase(j_.begin() + k);
// Every later row now starts one slot earlier.
156 for (
unsigned l = i + 1; l <= row_; l++)
// Three-valued realness test. RealVisitor's verdict for each visited element
// is folded into `cur` with and_tribool, starting from tritrue.
// NOTE(review): the loop header is not shown. Presumably it iterates only the
// stored values x_; implicit zeros are real, so skipping them would not
// change the result. Confirm.
162 tribool CSRMatrix::is_real(
const Assumptions *assumptions)
const
164 RealVisitor visitor(assumptions);
165 tribool cur = tribool::tritrue;
167 cur = and_tribool(cur, visitor.apply(*e));
// Not implemented for CSRMatrix: always throws NotImplementedError.
175 unsigned CSRMatrix::rank()
const
177 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
180 RCP<const Basic> CSRMatrix::det()
const
182 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
185 void CSRMatrix::inv(MatrixBase &result)
const
187 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
190 void CSRMatrix::add_matrix(
const MatrixBase &other, MatrixBase &result)
const
192 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
195 void CSRMatrix::mul_matrix(
const MatrixBase &other, MatrixBase &result)
const
197 throw NotImplementedError(
"Not Implemented");
// Element-wise product of this and `other`, written into `result`.
// Handled only when `result` is a CSRMatrix. The two sparse operands are then
// merged row by row by csr_binop_csr_canonical using `mul`; that function
// asserts that all three matrices share the same shape.
// NOTE(review): only `result` is type-checked. `other` is down_cast to
// CSRMatrix unconditionally, so passing a non-CSR `other` (e.g. a
// DenseMatrix) is an invalid cast. Consider also testing
// is_a<CSRMatrix>(other), or falling back to a generic path, before the cast.
// No else branch is visible here; confirm what happens when `result` is not
// a CSRMatrix.
200 void CSRMatrix::elementwise_mul_matrix(
const MatrixBase &other,
201 MatrixBase &result)
const
203 if (is_a<CSRMatrix>(result)) {
204 auto &o = down_cast<const CSRMatrix &>(other);
205 auto &r = down_cast<CSRMatrix &>(result);
206 csr_binop_csr_canonical(*
this, o, r, mul);
// Not implemented for CSRMatrix: always throws NotImplementedError.
211 void CSRMatrix::add_scalar(
const RCP<const Basic> &k, MatrixBase &result)
const
213 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
217 void CSRMatrix::mul_scalar(
const RCP<const Basic> &k, MatrixBase &result)
const
219 throw NotImplementedError(
"Not Implemented");
// Entry-wise complex conjugate into `result`, supported for CSRMatrix
// results. A new value vector x, sized like x_, is filled in the loop
// (presumably x[i] = conjugate(x_[i]) with the sparsity pattern kept —
// TODO confirm). Other result types presumably reach the
// NotImplementedError throw.
223 void CSRMatrix::conjugate(MatrixBase &result)
const
225 if (is_a<CSRMatrix>(result)) {
226 auto &r = down_cast<CSRMatrix &>(result);
228 vec_basic x(x_.size());
229 for (
unsigned i = 0; i < x_.size(); ++i) {
234 throw NotImplementedError(
"Not Implemented");
// Transpose into `result`, supported for CSRMatrix results. The value comes
// from the CSRMatrix-returning transpose() overload (called with no argument,
// presumably with conjugate defaulting to false). Other result types
// presumably reach the NotImplementedError throw.
239 void CSRMatrix::transpose(MatrixBase &result)
const
241 if (is_a<CSRMatrix>(result)) {
242 auto &r = down_cast<CSRMatrix &>(result);
243 r = this->transpose();
245 throw NotImplementedError(
"Not Implemented");
// Returns the transpose as a new CSRMatrix. conjugate_transpose calls this
// with `conjugate` = true, presumably so the values are conjugated too.
// Visible structure:
//  - a pass over all nnz stored entries (presumably counting entries per
//    column to build the new row pointers `p`);
//  - a scatter pass over every row ri that places the entry in column ci at
//    position p[ci] + tmp[ci], where tmp presumably counts the entries of
//    that column already placed.
249 CSRMatrix CSRMatrix::transpose(
bool conjugate)
const
251 const auto nnz = j_.size();
255 for (
unsigned i = 0; i < nnz; ++i)
259 for (
unsigned ri = 0; ri < row_; ++ri) {
260 for (
unsigned i = p_[ri]; i < p_[ri + 1]; ++i) {
261 const auto ci = j_[i];
262 const unsigned k = p[ci] + tmp[ci];
// Conjugate transpose into `result`. For CSRMatrix results this is
// transpose(true); other result types presumably reach the
// NotImplementedError throw.
276 void CSRMatrix::conjugate_transpose(MatrixBase &result)
const
278 if (is_a<CSRMatrix>(result)) {
279 auto &r = down_cast<CSRMatrix &>(result);
280 r = this->transpose(
true);
282 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
287 void CSRMatrix::submatrix(MatrixBase &result,
unsigned row_start,
288 unsigned col_start,
unsigned row_end,
289 unsigned col_end,
unsigned row_step,
290 unsigned col_step)
const
292 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
296 void CSRMatrix::LU(MatrixBase &L, MatrixBase &U)
const
298 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
302 void CSRMatrix::LDL(MatrixBase &L, MatrixBase &D)
const
304 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
308 void CSRMatrix::LU_solve(
const MatrixBase &b, MatrixBase &x)
const
310 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
314 void CSRMatrix::FFLU(MatrixBase &LU)
const
316 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
320 void CSRMatrix::FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U)
const
322 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
326 void CSRMatrix::QR(MatrixBase &Q, MatrixBase &R)
const
328 throw NotImplementedError(
"Not Implemented");
// Not implemented for CSRMatrix: always throws NotImplementedError.
332 void CSRMatrix::cholesky(MatrixBase &L)
const
334 throw NotImplementedError(
"Not Implemented");
// Fragment of a helper whose signature is not visible here. It scans the
// stored entries row by row (row_end presumably tracks the end of the
// current row). The inner while loop walks over consecutive entries that
// share the same column index j. This looks like a duplicate-merging pass,
// presumably csr_sum_duplicates with `x` accumulating the merged value —
// TODO confirm.
342 unsigned row_end = 0;
343 unsigned jj = 0, j = 0;
344 RCP<const Basic> x = zero;
346 for (
unsigned i = 0; i < row_; i++) {
350 while (jj < row_end) {
// Consecutive equal column indices are only adjacent (and thus all merged)
// when the row's indices are sorted.
355 while (jj < row_end and j_[jj] == j) {
// Fragment, signature not visible (presumably csr_sort_indices). For each row,
// the (column, value) pairs in [p_[i], p_[i + 1]) are presumably gathered
// into `temp`. They are sorted by column index only (the comparator looks at
// .first) and written back into j_ and x_ in that order.
378 for (
unsigned i = 0; i < row_; i++) {
379 unsigned row_start = p_[i];
380 unsigned row_end = p_[i + 1];
384 for (
unsigned jj = row_start; jj < row_end; jj++) {
389 [](
const std::pair<
unsigned, RCP<const Basic>> &x,
390 const std::pair<
unsigned, RCP<const Basic>> &y) {
391 return x.first < y.first;
// temp is indexed from 0, while j_/x_ are indexed from row_start.
394 for (
unsigned jj = row_start, n = 0; jj < row_end; jj++, n++) {
395 j_[jj] = temp[n].first;
396 x_[jj] = temp[n].second;
// Fragment (presumably csr_has_duplicates): flags a row in which two adjacent
// stored column indices are equal. Adjacent comparison finds every duplicate
// only when the row's indices are sorted.
406 for (
unsigned i = 0; i < row_; i++) {
// The bound `j + 1 < p_[i + 1]` (rather than `j < p_[i + 1] - 1`) avoids
// unsigned wrap-around for empty rows.
407 for (
unsigned j = p_[i]; j + 1 < p_[i + 1]; j++) {
408 if (j_[j] == j_[j + 1])
// Fragment of the per-row sorted-index check (presumably the body of
// csr_has_sorted_indices, which csr_has_canonical_format calls). For every
// row i, each stored column index in the half-open slice [p_[i], p_[i + 1])
// is compared with its successor in the same row.
419 for (
unsigned i = 0; i < row_; i++) {
// The bound is written `jj + 1 < p_[i + 1]` instead of `jj < p_[i + 1] - 1`.
// With unsigned arithmetic the old form wraps to UINT_MAX when
// p_[i + 1] == 0 (leading empty rows). The loop would then compare entries
// of later rows and, for an all-empty matrix, read past the end of j_.
// This now matches the bound used by the duplicate check.
420 for (
unsigned jj = p_[i]; jj + 1 < p_[i + 1]; jj++) {
421 if (j_[jj] > j_[jj + 1])
// Fragment (presumably csr_has_canonical_format). Row pointers must be
// non-decreasing, and column indices must be sorted within each row and free
// of duplicates.
432 for (
unsigned i = 0; i < row_; i++) {
433 if (p_[i] > p_[i + 1])
437 return csr_has_sorted_indices(p_, j_, row_)
438 and not csr_has_duplicates(p_, j_, row_);
// Builds a row x col CSRMatrix from coordinate triplets (presumably parallel
// row-index / column-index / value vectors, with x holding the values).
// Most loop bodies are not shown. Presumably the loops:
//  - allocate nnz value slots;
//  - count entries per row and prefix-sum them into p_ (the `cumsum` loop);
//  - scatter each triplet into its row (the second nnz loop, with row_/dest_
//    as scratch);
//  - shift the row pointers back (the `last` loop over i <= row).
// The arrays are then passed through csr_sort_indices and csr_sum_duplicates,
// presumably so unsorted or repeated (row, column) input still yields a
// canonical matrix.
441 CSRMatrix CSRMatrix::from_coo(
unsigned row,
unsigned col,
447 unsigned nnz = numeric_cast<unsigned>(x.size());
450 vec_basic x_ = vec_basic(nnz);
452 for (
unsigned n = 0; n < nnz; n++) {
458 for (
unsigned i = 0, cumsum = 0; i < row; i++) {
466 unsigned row_, dest_;
467 for (
unsigned n = 0; n < nnz; n++) {
477 for (
unsigned i = 0, last = 0; i <= row; i++) {
482 csr_sort_indices(p_, j_, x_, row);
484 csr_sum_duplicates(p_, j_, x_, row);
// Sparse Jacobian of `exprs` with respect to the symbols `x`. The result is
// nrows x ncols: one row per expression, one column per symbol. Entry
// (ri, ci) is the derivative of exprs[ri] with respect to x[ci], computed
// through the shared diff_cache.
// An entry is kept unless is_zero reports it as definitely zero, so
// derivatives whose zeroness is indeterminate are stored (conservative).
491 CSRMatrix CSRMatrix::jacobian(
const vec_basic &exprs,
const vec_sym &x,
494 const unsigned nrows =
static_cast<unsigned>(exprs.size());
495 const unsigned ncols =
static_cast<unsigned>(x.size());
500 elems.reserve(nrows);
501 for (
unsigned ri = 0; ri < nrows; ++ri) {
503 for (
unsigned ci = 0; ci < ncols; ++ci) {
504 auto elem = exprs[ri]->diff(x[ci], diff_cache);
505 if (!is_true(
is_zero(*elem))) {
// Convenience overload taking column vectors: A holds the expressions and x
// the variables, and both are asserted to have a single column. Every entry
// of x must be a Symbol, otherwise SymEngineException is thrown. The work is
// then delegated to the vec_basic / vec_sym overload.
516 CSRMatrix CSRMatrix::jacobian(
const DenseMatrix &A,
const DenseMatrix &x,
519 SYMENGINE_ASSERT(A.col_ == 1);
520 SYMENGINE_ASSERT(x.col_ == 1);
522 syms.reserve(x.row_);
523 for (
const auto &dx : x.m_) {
524 if (!is_a<Symbol>(*dx)) {
525 throw SymEngineException(
"'x' must contain Symbols only");
// Safe static cast: the is_a<Symbol> check above has already passed.
527 syms.push_back(rcp_static_cast<const Symbol>(dx));
529 return CSRMatrix::jacobian(A.m_, syms, diff_cache);
// First (structural) pass of the sparse product C = A * B. For each row i of
// A it visits every column k of B reachable through A's stored entries in
// that row, presumably counting the distinct ones in row_nnz. It accumulates
// the running total nnz, presumably filling C's row pointers along the way.
532 void csr_matmat_pass1(
const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C)
539 for (
unsigned i = 0; i < A.row_; i++) {
541 unsigned row_nnz = 0;
543 for (
unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++) {
544 unsigned j = A.j_[jj];
545 for (
unsigned kk = B.p_[j]; kk < B.p_[j + 1]; kk++) {
546 unsigned k = B.j_[kk];
554 unsigned next_nnz = nnz + row_nnz;
// Unsigned addition wraps on overflow, so a smaller result means the total
// nonzero count no longer fits in `unsigned`.
557 if (next_nnz < nnz) {
// Second (numeric) pass of C = A * B. For each row i of A, it adds
// A(i, j) * B(j, k) into sums[k] for every stored B(j, k). The `length` loop
// then walks the touched columns via `head`, storing the sums that are not
// definitely zero into C (C.x_[nnz] is visible; presumably C.j_ likewise),
// and presumably resets the scratch entries.
// NOTE(review): `sums` is indexed by k = B.j_[kk], which is a column index of
// B, but it is sized with A.col_. For a valid product A.col_ equals B.row_
// (the inner dimension), not B.col_. So when B has more columns than A, this
// indexes past the end of `sums`. It should be sized with B.col_, and so
// should any companion per-column scratch array (e.g. the linked-list
// `next`, declared on a line not shown here).
568 void csr_matmat_pass2(
const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C)
571 vec_basic sums(A.col_, zero);
577 for (
unsigned i = 0; i < A.row_; i++) {
581 unsigned jj_start = A.p_[i];
582 unsigned jj_end = A.p_[i + 1];
583 for (
unsigned jj = jj_start; jj < jj_end; jj++) {
584 unsigned j = A.j_[jj];
585 RCP<const Basic> v = A.x_[jj];
587 unsigned kk_start = B.p_[j];
588 unsigned kk_end = B.p_[j + 1];
589 for (
unsigned kk = kk_start; kk < kk_end; kk++) {
590 unsigned k = B.j_[kk];
592 sums[k] =
add(sums[k],
mul(v, B.x_[kk]));
602 for (
unsigned jj = 0; jj < length; jj++) {
// Sums that cancel to a definite zero are not stored in C.
604 if (!is_true(
is_zero(*sums[head]))) {
606 C.x_[nnz] = sums[head];
610 unsigned temp = head;
// Extracts the main diagonal of A into D, an N x 1 DenseMatrix with
// N = min(A.row_, A.col_). For each i, column i is located by binary search
// within row i.
// NOTE(review): row_end is set to p_[i + 1], which is an exclusive bound, but
// the search loop runs while row_start <= row_end, an inclusive test. Once
// row_start reaches row_end, jj == p_[i + 1]. A.j_[jj] then reads the first
// entry of the next row, or one past the end of A.j_ for the last row.
// Presumably the search should be half-open (`row_start < row_end`), like
// CSRMatrix::get. Confirm against the update branches not shown here.
622 void csr_diagonal(
const CSRMatrix &A, DenseMatrix &D)
624 unsigned N =
std::min(A.row_, A.col_);
626 SYMENGINE_ASSERT(D.nrows() == N and D.ncols() == 1);
630 RCP<const Basic> diag;
632 for (
unsigned i = 0; i < N; i++) {
634 row_end = A.p_[i + 1];
638 while (row_start <= row_end) {
639 jj = (row_start + row_end) / 2;
643 }
else if (A.j_[jj] < i) {
// Scales each row i of A in place by X(i, 0). X must be an A.row_ x 1 column
// vector. Throws SymEngineException if a scaling factor is definitely zero.
// The check runs per row, just before that row is scaled, so an exception
// leaves the earlier rows already scaled.
656 void csr_scale_rows(CSRMatrix &A,
const DenseMatrix &X)
658 SYMENGINE_ASSERT(A.row_ == X.nrows() and X.ncols() == 1);
660 for (
unsigned i = 0; i < A.row_; i++) {
661 if (is_true(
is_zero(*X.get(i, 0))))
662 throw SymEngineException(
"Scaling factor can't be zero");
663 for (
unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++)
664 A.x_[jj] =
mul(A.x_[jj], X.get(i, 0));
// Scales each column c of A in place by X(c, 0). X must be an A.col_ x 1
// column vector. All factors are validated before any entry is modified, so
// a definitely-zero factor throws SymEngineException without changing A.
// Every stored entry is then multiplied by the factor of its column index.
670 void csr_scale_columns(CSRMatrix &A,
const DenseMatrix &X)
672 SYMENGINE_ASSERT(A.col_ == X.nrows() and X.ncols() == 1);
674 const unsigned nnz = A.p_[A.row_];
677 for (i = 0; i < A.col_; i++) {
678 if (is_true(
is_zero(*X.get(i, 0))))
679 throw SymEngineException(
"Scaling factor can't be zero");
// Stored entries are traversed in flat order; A.j_[i] selects the factor.
682 for (i = 0; i < nnz; i++)
683 A.x_[i] =
mul(A.x_[i], X.get(A.j_[i], 0));
// Applies bin_op entry-wise to two CSR matrices of equal shape and appends
// the result to C. C must already have that shape (asserted). For each row,
// the column indices of A and B are merged in increasing order, which relies
// on both being sorted:
//  - column stored in both: bin_op(a, b);
//  - stored only in A:      bin_op(a, 0);
//  - stored only in B:      bin_op(0, b).
// Results that are definitely zero are not stored.
// NOTE(review): columns stored in neither operand are never visited, so
// bin_op(0, 0) is assumed to be zero. This holds for add and mul, but not for
// operations such as pow.
689 void csr_binop_csr_canonical(
690 const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C,
691 RCP<const Basic> (&bin_op)(
const RCP<const Basic> &,
692 const RCP<const Basic> &))
694 SYMENGINE_ASSERT(A.row_ == B.row_ and A.col_ == B.col_ and C.row_ == A.row_
695 and C.col_ == A.col_);
702 unsigned A_pos, B_pos, A_end, B_end;
704 for (
unsigned i = 0; i < A.row_; i++) {
// Two-pointer merge of row i while both operands still have entries.
711 while (A_pos < A_end and B_pos < B_end) {
712 unsigned A_j = A.j_[A_pos];
713 unsigned B_j = B.j_[B_pos];
716 RCP<const Basic> result = bin_op(A.x_[A_pos], B.x_[B_pos]);
717 if (!is_true(
is_zero(*result))) {
719 C.x_.push_back(result);
724 }
else if (A_j < B_j) {
725 RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
726 if (!is_true(
is_zero(*result))) {
728 C.x_.push_back(result);
734 RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
735 if (!is_true(
is_zero(*result))) {
737 C.x_.push_back(result);
// Leftover entries of A in this row (B is exhausted).
745 while (A_pos < A_end) {
746 RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
747 if (!is_true(
is_zero(*result))) {
748 C.j_.push_back(A.j_[A_pos]);
749 C.x_.push_back(result);
// Leftover entries of B in this row (A is exhausted).
754 while (B_pos < B_end) {
755 RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
756 if (!is_true(
is_zero(*result))) {
757 C.j_.push_back(B.j_[B_pos]);
758 C.x_.push_back(result);
// With duplicate-free inputs the merge emits each column at most once per
// row, so this is presumably a safety net for non-canonical operands.
769 if (CSRMatrix::csr_has_duplicates(C.p_, C.j_, A.row_))
770 CSRMatrix::csr_sum_duplicates(C.p_, C.j_, C.x_, A.row_);
Classes and functions relating to the binary operation of addition.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Main namespace for SymEngine package.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
tribool is_zero(const Basic &b, const Assumptions *assumptions=nullptr)
Check if a number is zero.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.