2 #include <symengine/matrix.h>
7 #include <symengine/symengine_exception.h>
9 #include <symengine/test_visitors.h>
14 CSRMatrix::CSRMatrix() {}
16 CSRMatrix::CSRMatrix(
unsigned row,
unsigned col) : row_(row), col_(col)
18 p_ = std::vector<unsigned>(row + 1, 0);
19 SYMENGINE_ASSERT(is_canonical());
22 CSRMatrix::CSRMatrix(
unsigned row,
unsigned col,
const std::vector<unsigned> &p,
23 const std::vector<unsigned> &j,
const vec_basic &x)
24 : p_{p}, j_{j}, x_{x}, row_(row), col_(col)
26 SYMENGINE_ASSERT(is_canonical());
29 CSRMatrix::CSRMatrix(
unsigned row,
unsigned col, std::vector<unsigned> &&p,
30 std::vector<unsigned> &&j, vec_basic &&x)
31 : p_{std::move(p)}, j_{std::move(j)}, x_{std::move(x)}, row_(row), col_(col)
33 SYMENGINE_ASSERT(is_canonical());
// Move assignment: takes over the storage of other.p_, other.j_ and
// other.x_ via std::move.
// NOTE(review): several lines of this definition are not present in this
// listing (opening brace, the statements between the signature and the
// moves, and the return). Confirm against the full source that row_ and
// col_ are also transferred and that *this is returned.
36 CSRMatrix &CSRMatrix::operator=(CSRMatrix &&other)
40 p_ = std::move(other.p_);
41 j_ = std::move(other.j_);
42 x_ = std::move(other.x_);
// Returns the CSR arrays as a tuple of (row pointers, column indices,
// values).
// NOTE(review): p, j and x are locals declared on lines not shown in this
// listing. Because the method is const they are presumably copies of p_,
// j_ and x_ that are then moved into the tuple -- confirm.
46 std::tuple<std::vector<unsigned>, std::vector<unsigned>, vec_basic>
47 CSRMatrix::as_vectors()
const
51 return std::make_tuple(std::move(p), std::move(j), std::move(x));
// Equality test against any MatrixBase.
// The shapes (nrows/ncols) are compared first. When `other` is also a
// CSRMatrix, the stored structure is compared directly, in this order:
//   1. the total stored-entry count p_[row];
//   2. every row pointer;
//   3. each column index and each value (values compared with neq).
// For any other matrix type, the generic MatrixBase::eq is used.
// NOTE(review): the early-return statements of this function are on lines
// not shown in this listing.
54 bool CSRMatrix::eq(
const MatrixBase &other)
const
56 unsigned row = this->nrows();
57 if (row != other.nrows() or this->ncols() != other.ncols())
60 if (is_a<CSRMatrix>(other)) {
61 const CSRMatrix &o = down_cast<const CSRMatrix &>(other);
63 if (this->p_[row] != o.p_[row])
66 for (
unsigned i = 0; i <= row; i++)
67 if (this->p_[i] != o.p_[i])
70 for (
unsigned i = 0; i < this->p_[row]; i++)
71 if ((this->j_[i] != o.j_[i]) or
neq(*this->x_[i], *(o.x_[i])))
76 return this->MatrixBase::eq(other);
// Checks the structural invariants of the CSR representation.
// p_ must have row_ + 1 entries, and both j_ and x_ must hold exactly
// p_[row_] entries. The remaining checks are delegated to
// csr_has_canonical_format: non-decreasing row pointers, plus sorted,
// duplicate-free column indices within each row.
// NOTE(review): the statement taken when the size check fails is on a line
// not shown in this listing (presumably `return false;`).
// NOTE(review): no check that column indices are < col_ is visible here.
80 bool CSRMatrix::is_canonical()
const
82 if (p_.size() != row_ + 1 or j_.size() != p_[row_] or x_.size() != p_[row_])
86 return csr_has_canonical_format(p_, j_, row_);
// Returns the (i, j) entry of the matrix (bounds asserted with
// SYMENGINE_ASSERT).
// Column j is looked up among row i's stored column indices, i.e.
// j_[p_[i]] .. j_[p_[i + 1] - 1], by bisection on the midpoint k. This
// relies on the indices being sorted within the row.
// NOTE(review): the body of the empty-row branch, the remaining comparison
// branches and the final return are on lines not shown in this listing;
// presumably an entry that is not stored yields zero -- confirm.
91 RCP<const Basic> CSRMatrix::get(
unsigned i,
unsigned j)
const
93 SYMENGINE_ASSERT(i < row_ and j < col_);
95 unsigned row_start = p_[i];
96 unsigned row_end = p_[i + 1];
99 if (row_start == row_end) {
103 while (row_start < row_end) {
104 k = (row_start + row_end) / 2;
107 }
else if (j_[k] < j) {
// Sets the (i, j) entry of the matrix to e while keeping the CSR layout.
// A bisection over row i's column indices finds the position k where
// column j is, or would be, stored. Two paths follow:
//   - Insert path: an existing entry at column j is handled on a line not
//     shown (presumably overwritten). Otherwise (j, e) is inserted into j_
//     and x_ at k, and a loop visits the row pointers of rows i+1 .. row_.
//   - Erase path: an existing entry at column j is removed from x_ and j_,
//     and the same rows' pointers are visited again.
// NOTE(review): the condition choosing between the two paths (presumably
// whether e is zero) and the bodies of the row-pointer loops (presumably
// ++p_[l] / --p_[l]) are on lines not shown in this listing -- confirm.
// NOTE(review): vector insert/erase in the middle of j_ and x_ shifts all
// later entries, so each call costs time linear in the stored-entry count.
117 void CSRMatrix::set(
unsigned i,
unsigned j,
const RCP<const Basic> &e)
119 SYMENGINE_ASSERT(i < row_ and j < col_);
122 unsigned row_end = p_[i + 1];
123 unsigned end = p_[i + 1];
134 SYMENGINE_ASSERT(mid > 0);
135 if (j_[mid] >= j and j_[mid - 1] < j) {
138 }
else if (j_[mid - 1] >= j) {
146 if (k < row_end and j_[k] == j) {
149 x_.insert(x_.begin() + k, e);
150 j_.insert(j_.begin() + k, j);
151 for (
unsigned l = i + 1; l <= row_; l++)
155 if (k < row_end and j_[k] == j) {
156 x_.erase(x_.begin() + k);
157 j_.erase(j_.begin() + k);
158 for (
unsigned l = i + 1; l <= row_; l++)
// Computes a tribool "all entries are real". Starting from tritrue, it
// folds RealVisitor(assumptions) results for each element `e` with
// and_tribool.
// NOTE(review): the loop that binds `e` and the return statement are on
// lines not shown in this listing; presumably `e` ranges over the stored
// values x_ (unstored entries are zero, hence real) -- confirm.
164 tribool CSRMatrix::is_real(
const Assumptions *assumptions)
const
166 RealVisitor visitor(assumptions);
167 tribool cur = tribool::tritrue;
169 cur = and_tribool(cur, visitor.apply(*e));
177 unsigned CSRMatrix::rank()
const
179 throw NotImplementedError(
"Not Implemented");
182 RCP<const Basic> CSRMatrix::det()
const
184 throw NotImplementedError(
"Not Implemented");
187 void CSRMatrix::inv(MatrixBase &result)
const
189 throw NotImplementedError(
"Not Implemented");
192 void CSRMatrix::add_matrix(
const MatrixBase &other, MatrixBase &result)
const
194 throw NotImplementedError(
"Not Implemented");
197 void CSRMatrix::mul_matrix(
const MatrixBase &other, MatrixBase &result)
const
199 throw NotImplementedError(
"Not Implemented");
// Element-wise product of this matrix and `other`, written into `result`.
// It is computed by csr_binop_csr_canonical with `mul` as the binary
// operation.
// NOTE(review): only `result` is checked with is_a<CSRMatrix>; `other` is
// down_cast to CSRMatrix without any check. Confirm callers never pass a
// non-CSR `other`; an is_a<CSRMatrix>(other) guard would be safer.
// NOTE(review): the handling of a non-CSR `result` and the closing braces
// are on lines not shown in this listing.
202 void CSRMatrix::elementwise_mul_matrix(
const MatrixBase &other,
203 MatrixBase &result)
const
205 if (is_a<CSRMatrix>(result)) {
206 auto &o = down_cast<const CSRMatrix &>(other);
207 auto &r = down_cast<CSRMatrix &>(result);
208 csr_binop_csr_canonical(*
this, o, r, mul);
213 void CSRMatrix::add_scalar(
const RCP<const Basic> &k, MatrixBase &result)
const
215 throw NotImplementedError(
"Not Implemented");
219 void CSRMatrix::mul_scalar(
const RCP<const Basic> &k, MatrixBase &result)
const
221 throw NotImplementedError(
"Not Implemented");
225 void CSRMatrix::conjugate(MatrixBase &result)
const
227 if (is_a<CSRMatrix>(result)) {
228 auto &r = down_cast<CSRMatrix &>(result);
229 std::vector<unsigned> p(p_), j(j_);
230 vec_basic x(x_.size());
231 for (
unsigned i = 0; i < x_.size(); ++i) {
234 r = CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
236 throw NotImplementedError(
"Not Implemented");
241 void CSRMatrix::transpose(MatrixBase &result)
const
243 if (is_a<CSRMatrix>(result)) {
244 auto &r = down_cast<CSRMatrix &>(result);
245 r = this->transpose();
247 throw NotImplementedError(
"Not Implemented");
// Returns the transpose of this matrix as a new col_ x row_ CSRMatrix.
// p (col_ + 1 entries) is turned into per-column start offsets with
// std::partial_sum. Each entry of row ri with column ci is then placed at
// position p[ci] + tmp[ci] of the output.
// NOTE(review): the body of the first counting loop (presumably counting
// entries per column into p), the update of tmp (presumably the number of
// entries of column ci placed so far), the declaration of x and the use of
// the `conjugate` flag (presumably conjugating each value when true) are on
// lines not shown in this listing -- confirm against the full source.
251 CSRMatrix CSRMatrix::transpose(
bool conjugate)
const
253 const auto nnz = j_.size();
254 std::vector<unsigned> p(col_ + 1, 0), j(nnz), tmp(col_, 0);
257 for (
unsigned i = 0; i < nnz; ++i)
259 std::partial_sum(p.begin(), p.end(), p.begin());
261 for (
unsigned ri = 0; ri < row_; ++ri) {
262 for (
unsigned i = p_[ri]; i < p_[ri + 1]; ++i) {
263 const auto ci = j_[i];
264 const unsigned k = p[ci] + tmp[ci];
274 return CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
278 void CSRMatrix::conjugate_transpose(MatrixBase &result)
const
280 if (is_a<CSRMatrix>(result)) {
281 auto &r = down_cast<CSRMatrix &>(result);
282 r = this->transpose(
true);
284 throw NotImplementedError(
"Not Implemented");
289 void CSRMatrix::submatrix(MatrixBase &result,
unsigned row_start,
290 unsigned col_start,
unsigned row_end,
291 unsigned col_end,
unsigned row_step,
292 unsigned col_step)
const
294 throw NotImplementedError(
"Not Implemented");
298 void CSRMatrix::LU(MatrixBase &L, MatrixBase &U)
const
300 throw NotImplementedError(
"Not Implemented");
304 void CSRMatrix::LDL(MatrixBase &L, MatrixBase &D)
const
306 throw NotImplementedError(
"Not Implemented");
310 void CSRMatrix::LU_solve(
const MatrixBase &b, MatrixBase &x)
const
312 throw NotImplementedError(
"Not Implemented");
316 void CSRMatrix::FFLU(MatrixBase &LU)
const
318 throw NotImplementedError(
"Not Implemented");
322 void CSRMatrix::FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U)
const
324 throw NotImplementedError(
"Not Implemented");
328 void CSRMatrix::QR(MatrixBase &Q, MatrixBase &R)
const
330 throw NotImplementedError(
"Not Implemented");
334 void CSRMatrix::cholesky(MatrixBase &L)
const
336 throw NotImplementedError(
"Not Implemented");
// Merges runs of entries that share the same column index within a row
// (`while (jj < row_end and j_[jj] == j)`) and compacts p_, j_ and x_ in
// place. Only adjacent duplicates are merged, so column indices must
// already be sorted within each row; from_coo calls csr_sort_indices first.
// NOTE(review): the last parameter (presumably `unsigned row_`), the
// accumulation into x (which starts at zero; presumably the run's values
// are summed) and the write-back/resize of j_ and x_ are on lines not shown
// in this listing -- confirm.
// NOTE(review): the parameters p_, j_, x_ reuse the data-member names;
// inside this function they refer to the arguments.
339 void CSRMatrix::csr_sum_duplicates(std::vector<unsigned> &p_,
340 std::vector<unsigned> &j_, vec_basic &x_,
344 unsigned row_end = 0;
345 unsigned jj = 0, j = 0;
346 RCP<const Basic> x = zero;
348 for (
unsigned i = 0; i < row_; i++) {
352 while (jj < row_end) {
357 while (jj < row_end and j_[jj] == j) {
// Sorts the column indices of every row in ascending order and permutes
// the values in x_ alongside. For each row, its (column, value) pairs are
// copied into `temp`, sorted by column and written back in place.
// NOTE(review): the last parameter (presumably `unsigned row_`) and any
// reset of `temp` between rows are on lines not shown in this listing.
// Confirm temp is cleared per row, otherwise temp[n] would read pairs left
// over from earlier rows.
// NOTE(review): std::sort is not stable, so entries with equal column
// indices end up in unspecified relative order. This is fine when
// duplicates are summed afterwards.
374 void CSRMatrix::csr_sort_indices(std::vector<unsigned> &p_,
375 std::vector<unsigned> &j_, vec_basic &x_,
378 std::vector<std::pair<unsigned, RCP<const Basic>>> temp;
380 for (
unsigned i = 0; i < row_; i++) {
381 unsigned row_start = p_[i];
382 unsigned row_end = p_[i + 1];
386 for (
unsigned jj = row_start; jj < row_end; jj++) {
387 temp.push_back(std::make_pair(j_[jj], x_[jj]));
390 std::sort(temp.begin(), temp.end(),
391 [](
const std::pair<
unsigned, RCP<const Basic>> &x,
392 const std::pair<
unsigned, RCP<const Basic>> &y) {
393 return x.first < y.first;
396 for (
unsigned jj = row_start, n = 0; jj < row_end; jj++, n++) {
397 j_[jj] = temp[n].first;
398 x_[jj] = temp[n].second;
// Returns whether any row contains two adjacent entries with the same
// column index. Only neighbouring entries are compared, so the answer is
// reliable only when indices are already sorted within each row.
// NOTE(review): the last parameter and the return statements are on lines
// not shown in this listing.
404 bool CSRMatrix::csr_has_duplicates(
const std::vector<unsigned> &p_,
405 const std::vector<unsigned> &j_,
408 for (
unsigned i = 0; i < row_; i++) {
409 for (
unsigned j = p_[i]; j + 1 < p_[i + 1]; j++) {
410 if (j_[j] == j_[j + 1])
// Returns whether the column indices of every row are in non-decreasing
// order. Adjacent equal indices count as sorted; use csr_has_duplicates to
// detect those.
// NOTE(review): the last parameter and the return statements are on lines
// not shown in this listing.
417 bool CSRMatrix::csr_has_sorted_indices(
const std::vector<unsigned> &p_,
418 const std::vector<unsigned> &j_,
421 for (
unsigned i = 0; i < row_; i++) {
422 for (
unsigned jj = p_[i]; jj + 1 < p_[i + 1]; jj++) {
423 if (j_[jj] > j_[jj + 1])
// Returns whether p_ and j_ describe a canonical CSR structure: row
// pointers never decrease, column indices are sorted within each row, and
// no row contains duplicate column indices.
// NOTE(review): the early return for a decreasing row pointer is on a line
// not shown in this listing. Nothing visible here checks p_[0] == 0 or that
// column indices are below the column count.
430 bool CSRMatrix::csr_has_canonical_format(
const std::vector<unsigned> &p_,
431 const std::vector<unsigned> &j_,
434 for (
unsigned i = 0; i < row_; i++) {
435 if (p_[i] > p_[i + 1])
439 return csr_has_sorted_indices(p_, j_, row_)
440 and not csr_has_duplicates(p_, j_, row_);
// Builds a row x col CSRMatrix from coordinate (COO) triplets
// (i[n], j[n], x[n]). Several loop bodies are not shown, so these steps are
// partly inferred:
//   1. count entries per row into p_;
//   2. prefix-sum (`cumsum`) the counts;
//   3. scatter each triplet to its row slot (`row_`, `dest_`);
//   4. shift p_ back to row-start form with the `last` swap loop;
//   5. sort each row's column indices (csr_sort_indices);
//   6. merge duplicate (row, column) entries (csr_sum_duplicates).
// NOTE(review): the parameter x and the variable receiving the final
// CSRMatrix are on lines not shown in this listing.
// NOTE(review): nnz is taken from x.size() only. Nothing visible checks
// that i and j have the same length, or that indices are < row / col.
// NOTE(review): the loop counters named `i` shadow the parameter vector
// `i`, and the locals p_, j_, x_, row_, dest_ reuse member-style names;
// both are easy to misread.
443 CSRMatrix CSRMatrix::from_coo(
unsigned row,
unsigned col,
444 const std::vector<unsigned> &i,
445 const std::vector<unsigned> &j,
449 unsigned nnz = numeric_cast<unsigned>(x.size());
450 std::vector<unsigned> p_ = std::vector<unsigned>(row + 1, 0);
451 std::vector<unsigned> j_ = std::vector<unsigned>(nnz);
452 vec_basic x_ = vec_basic(nnz);
454 for (
unsigned n = 0; n < nnz; n++) {
460 for (
unsigned i = 0, cumsum = 0; i < row; i++) {
468 unsigned row_, dest_;
469 for (
unsigned n = 0; n < nnz; n++) {
479 for (
unsigned i = 0, last = 0; i <= row; i++) {
480 std::swap(p_[i], last);
484 csr_sort_indices(p_, j_, x_, row);
486 csr_sum_duplicates(p_, j_, x_, row);
489 = CSRMatrix(row, col, std::move(p_), std::move(j_), std::move(x_));
// Builds the Jacobian of `exprs` with respect to the symbols `x` as a
// sparse matrix with exprs.size() rows and x.size() columns.
// Entry (ri, ci) is exprs[ri]->diff(x[ci], diff_cache). It is stored only
// when it is not provably zero (`!is_true(is_zero(...))`), so entries whose
// zero-ness is unknown are kept.
// NOTE(review): the diff_cache parameter, the declaration of elems and the
// statements that bump p.back() and append ci to j are on lines not shown
// in this listing.
// NOTE(review): elems.reserve(nrows) reserves one slot per row, which is
// unrelated to the number of stored entries.
493 CSRMatrix CSRMatrix::jacobian(
const vec_basic &exprs,
const vec_sym &x,
496 const unsigned nrows =
static_cast<unsigned>(exprs.size());
497 const unsigned ncols =
static_cast<unsigned>(x.size());
498 std::vector<unsigned> p(1, 0), j;
500 p.reserve(nrows + 1);
502 elems.reserve(nrows);
503 for (
unsigned ri = 0; ri < nrows; ++ri) {
504 p.push_back(p.back());
505 for (
unsigned ci = 0; ci < ncols; ++ci) {
506 auto elem = exprs[ri]->diff(x[ci], diff_cache);
507 if (!is_true(
is_zero(*elem))) {
510 elems.emplace_back(std::move(elem));
514 return CSRMatrix(nrows, ncols, std::move(p), std::move(j),
// DenseMatrix front end for the vec_basic/vec_sym overload.
// A and x must be column vectors; this is only checked with
// SYMENGINE_ASSERT. Every entry of x must be a Symbol, otherwise a
// SymEngineException is thrown. The symbols are collected into syms and
// the work is delegated to CSRMatrix::jacobian(A.m_, syms, diff_cache).
// NOTE(review): the diff_cache parameter and the declaration of syms are
// on lines not shown in this listing.
518 CSRMatrix CSRMatrix::jacobian(
const DenseMatrix &A,
const DenseMatrix &x,
521 SYMENGINE_ASSERT(A.col_ == 1);
522 SYMENGINE_ASSERT(x.col_ == 1);
524 syms.reserve(x.row_);
525 for (
const auto &dx : x.m_) {
526 if (!is_a<Symbol>(*dx)) {
527 throw SymEngineException(
"'x' must contain Symbols only");
529 syms.push_back(rcp_static_cast<const Symbol>(dx));
531 return CSRMatrix::jacobian(A.m_, syms, diff_cache);
// First pass of the sparse product C = A * B. For each row i of A it walks
// the stored entries A(i, j) and, through them, the stored columns k of
// row j of B, counting the row's result entries in row_nnz (presumably the
// distinct k values, using `mask`). It accumulates the total, and
// std::overflow_error is thrown if the unsigned total wraps around
// (next_nnz < nnz).
// NOTE(review): `mask` is sized A.col_, but k = B.j_[kk] is a column of B
// and is bounded by B.col_. If mask is indexed by k (that use is on lines
// not shown here), it is out of bounds whenever B.col_ > A.col_; B.col_
// (== C.col_) looks like the intended size.
// NOTE(review): mask is unsigned and initialised from -1, i.e. the largest
// unsigned value, presumably as a "not seen in this row" sentinel.
534 void csr_matmat_pass1(
const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C)
537 std::vector<unsigned> mask(A.col_, -1);
541 for (
unsigned i = 0; i < A.row_; i++) {
543 unsigned row_nnz = 0;
545 for (
unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++) {
546 unsigned j = A.j_[jj];
547 for (
unsigned kk = B.p_[j]; kk < B.p_[j + 1]; kk++) {
548 unsigned k = B.j_[kk];
556 unsigned next_nnz = nnz + row_nnz;
559 if (next_nnz < nnz) {
560 throw std::overflow_error(
"nnz of the result is too large");
// Second pass of C = A * B. For each row i, it accumulates
// sums[k] = add(sums[k], mul(A(i, j), B(j, k))) over the stored entries of
// A's row i and of B's row j. It then walks the touched columns (presumably
// a linked list threaded through `next`, starting at `head`) and copies the
// sums that are not provably zero into C.x_.
// NOTE(review): `next` and `sums` are sized A.col_ but indexed by
// k = B.j_[kk], a column of B. This is out of bounds whenever
// B.col_ > A.col_ (e.g. a 2x1 times 1x3 product); they should be sized
// B.col_ (== C.col_). Fix together with the mask in csr_matmat_pass1.
// NOTE(review): the linked-list bookkeeping, the writes to C.j_ / C.p_ and
// the reset of sums are on lines not shown in this listing.
570 void csr_matmat_pass2(
const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C)
572 std::vector<int> next(A.col_, -1);
573 vec_basic sums(A.col_, zero);
579 for (
unsigned i = 0; i < A.row_; i++) {
583 unsigned jj_start = A.p_[i];
584 unsigned jj_end = A.p_[i + 1];
585 for (
unsigned jj = jj_start; jj < jj_end; jj++) {
586 unsigned j = A.j_[jj];
587 RCP<const Basic> v = A.x_[jj];
589 unsigned kk_start = B.p_[j];
590 unsigned kk_end = B.p_[j + 1];
591 for (
unsigned kk = kk_start; kk < kk_end; kk++) {
592 unsigned k = B.j_[kk];
594 sums[k] =
add(sums[k],
mul(v, B.x_[kk]));
604 for (
unsigned jj = 0; jj < length; jj++) {
606 if (!is_true(
is_zero(*sums[head]))) {
608 C.x_[nnz] = sums[head];
612 unsigned temp = head;
// Extracts the main diagonal of A (length min(A.row_, A.col_)) into the
// column vector D by bisecting each row's column indices for column i.
// NOTE(review): row_end is A.p_[i + 1], one past the last entry of row i,
// yet the loop runs while `row_start <= row_end`. So jj can reach
// A.p_[i + 1] and A.j_[jj] is read from the next row, or past the end of
// j_ when that is nnz. If the branch on the line not shown tests
// A.j_[jj] == i, an empty row i followed by a row whose first column is i
// reports that entry as the diagonal. A half-open search
// (`row_start < row_end`, `row_end = jj` on the upper branch) avoids this
// and also the `jj - 1` underflow guarded by the assert below.
// NOTE(review): row_start's initialisation, diag's default, the found
// branch and the write into D are on lines not shown in this listing.
624 void csr_diagonal(
const CSRMatrix &A, DenseMatrix &D)
626 unsigned N = std::min(A.row_, A.col_);
628 SYMENGINE_ASSERT(D.nrows() == N and D.ncols() == 1);
632 RCP<const Basic> diag;
634 for (
unsigned i = 0; i < N; i++) {
636 row_end = A.p_[i + 1];
640 while (row_start <= row_end) {
641 jj = (row_start + row_end) / 2;
645 }
else if (A.j_[jj] < i) {
648 SYMENGINE_ASSERT(jj > 0);
659 void csr_scale_rows(CSRMatrix &A,
const DenseMatrix &X)
661 SYMENGINE_ASSERT(A.row_ == X.nrows() and X.ncols() == 1);
663 for (
unsigned i = 0; i < A.row_; i++) {
664 if (is_true(
is_zero(*X.get(i, 0))))
665 throw SymEngineException(
"Scaling factor can't be zero");
666 for (
unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++)
667 A.x_[jj] =
mul(A.x_[jj], X.get(i, 0));
673 void csr_scale_columns(CSRMatrix &A,
const DenseMatrix &X)
675 SYMENGINE_ASSERT(A.col_ == X.nrows() and X.ncols() == 1);
677 const unsigned nnz = A.p_[A.row_];
680 for (i = 0; i < A.col_; i++) {
681 if (is_true(
is_zero(*X.get(i, 0))))
682 throw SymEngineException(
"Scaling factor can't be zero");
685 for (i = 0; i < nnz; i++)
686 A.x_[i] =
mul(A.x_[i], X.get(A.j_[i], 0));
// Computes C = bin_op(A, B) entry-wise for two canonical CSR matrices of
// the same shape (asserted, together with C's shape). Each row is handled
// by merging the sorted column indices of A and B:
//   - matching columns (presumably A_j == B_j) give bin_op(a, b);
//   - an entry present only in A gives bin_op(a, zero);
//   - an entry present only in B gives bin_op(zero, b).
// Results that are provably zero are not stored. Finally, if C has
// duplicate column indices, they are merged with csr_sum_duplicates.
// NOTE(review): positions absent from both A and B are never visited, so
// this is only correct for operations with bin_op(0, 0) == 0 (e.g. mul,
// add, sub).
// NOTE(review): results are appended with push_back, so C.j_ and C.x_ are
// presumably expected to start empty -- confirm at the call sites.
// NOTE(review): the pointer initialisation/advances, the C.j_ pushes in the
// merge loop and the C.p_ updates are on lines not shown in this listing.
692 void csr_binop_csr_canonical(
693 const CSRMatrix &A,
const CSRMatrix &B, CSRMatrix &C,
694 RCP<const Basic> (&bin_op)(
const RCP<const Basic> &,
695 const RCP<const Basic> &))
697 SYMENGINE_ASSERT(A.row_ == B.row_ and A.col_ == B.col_ and C.row_ == A.row_
698 and C.col_ == A.col_);
705 unsigned A_pos, B_pos, A_end, B_end;
707 for (
unsigned i = 0; i < A.row_; i++) {
714 while (A_pos < A_end and B_pos < B_end) {
715 unsigned A_j = A.j_[A_pos];
716 unsigned B_j = B.j_[B_pos];
719 RCP<const Basic> result = bin_op(A.x_[A_pos], B.x_[B_pos]);
720 if (!is_true(
is_zero(*result))) {
722 C.x_.push_back(result);
727 }
else if (A_j < B_j) {
728 RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
729 if (!is_true(
is_zero(*result))) {
731 C.x_.push_back(result);
737 RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
738 if (!is_true(
is_zero(*result))) {
740 C.x_.push_back(result);
748 while (A_pos < A_end) {
749 RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
750 if (!is_true(
is_zero(*result))) {
751 C.j_.push_back(A.j_[A_pos]);
752 C.x_.push_back(result);
757 while (B_pos < B_end) {
758 RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
759 if (!is_true(
is_zero(*result))) {
760 C.j_.push_back(B.j_[B_pos]);
761 C.x_.push_back(result);
772 if (CSRMatrix::csr_has_duplicates(C.p_, C.j_, A.row_))
773 CSRMatrix::csr_sum_duplicates(C.p_, C.j_, C.x_, A.row_);
Classes and functions relating to the binary operation of addition.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Main namespace for SymEngine package.
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
tribool is_zero(const Basic &b, const Assumptions *assumptions=nullptr)
Check if a number is zero.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.