sparse_matrix.cpp
#include <algorithm>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#include <symengine/matrix.h>
#include <symengine/add.h>
#include <symengine/functions.h>
#include <symengine/mul.h>
#include <symengine/constants.h>
#include <symengine/symengine_exception.h>
#include <symengine/visitor.h>
#include <symengine/test_visitors.h>
10 
11 namespace SymEngine
12 {
13 // ----------------------------- CSRMatrix ------------------------------------
14 CSRMatrix::CSRMatrix() {}
15 
16 CSRMatrix::CSRMatrix(unsigned row, unsigned col) : row_(row), col_(col)
17 {
18  p_ = std::vector<unsigned>(row + 1, 0);
19  SYMENGINE_ASSERT(is_canonical());
20 }
21 
22 CSRMatrix::CSRMatrix(unsigned row, unsigned col, const std::vector<unsigned> &p,
23  const std::vector<unsigned> &j, const vec_basic &x)
24  : p_{p}, j_{j}, x_{x}, row_(row), col_(col)
25 {
26  SYMENGINE_ASSERT(is_canonical());
27 }
28 
29 CSRMatrix::CSRMatrix(unsigned row, unsigned col, std::vector<unsigned> &&p,
30  std::vector<unsigned> &&j, vec_basic &&x)
31  : p_{std::move(p)}, j_{std::move(j)}, x_{std::move(x)}, row_(row), col_(col)
32 {
33  SYMENGINE_ASSERT(is_canonical());
34 }
35 
36 CSRMatrix &CSRMatrix::operator=(CSRMatrix &&other)
37 {
38  col_ = other.col_;
39  row_ = other.row_;
40  p_ = std::move(other.p_);
41  j_ = std::move(other.j_);
42  x_ = std::move(other.x_);
43  return *this;
44 }
45 
47 CSRMatrix::as_vectors() const
48 {
49  auto p = p_, j = j_;
50  auto x = x_;
51  return std::make_tuple(std::move(p), std::move(j), std::move(x));
52 }
53 
54 bool CSRMatrix::eq(const MatrixBase &other) const
55 {
56  unsigned row = this->nrows();
57  if (row != other.nrows() or this->ncols() != other.ncols())
58  return false;
59 
60  if (is_a<CSRMatrix>(other)) {
61  const CSRMatrix &o = down_cast<const CSRMatrix &>(other);
62 
63  if (this->p_[row] != o.p_[row])
64  return false;
65 
66  for (unsigned i = 0; i <= row; i++)
67  if (this->p_[i] != o.p_[i])
68  return false;
69 
70  for (unsigned i = 0; i < this->p_[row]; i++)
71  if ((this->j_[i] != o.j_[i]) or neq(*this->x_[i], *(o.x_[i])))
72  return false;
73 
74  return true;
75  } else {
76  return this->MatrixBase::eq(other);
77  }
78 }
79 
80 bool CSRMatrix::is_canonical() const
81 {
82  if (p_.size() != row_ + 1 or j_.size() != p_[row_] or x_.size() != p_[row_])
83  return false;
84 
85  if (p_[row_] != 0) // Zero matrix is in canonical format
86  return csr_has_canonical_format(p_, j_, row_);
87  return true;
88 }
89 
90 // Get and set elements
91 RCP<const Basic> CSRMatrix::get(unsigned i, unsigned j) const
92 {
93  SYMENGINE_ASSERT(i < row_ and j < col_);
94 
95  unsigned row_start = p_[i];
96  unsigned row_end = p_[i + 1];
97  unsigned k;
98 
99  if (row_start == row_end) {
100  return zero;
101  }
102 
103  while (row_start < row_end) {
104  k = (row_start + row_end) / 2;
105  if (j_[k] == j) {
106  return x_[k];
107  } else if (j_[k] < j) {
108  row_start = k + 1;
109  } else {
110  row_end = k;
111  }
112  }
113 
114  return zero;
115 }
116 
117 void CSRMatrix::set(unsigned i, unsigned j, const RCP<const Basic> &e)
118 {
119  SYMENGINE_ASSERT(i < row_ and j < col_);
120 
121  unsigned k = p_[i];
122  unsigned row_end = p_[i + 1];
123  unsigned end = p_[i + 1];
124  unsigned mid;
125 
126  while (k < end) {
127  mid = (k + end) / 2;
128  if (mid == k) {
129  if (j_[k] < j) {
130  k++;
131  }
132  break;
133  } else if (j_[mid] >= j and j_[mid - 1] < j) {
134  k = mid;
135  break;
136  } else if (j_[mid - 1] >= j) {
137  end = mid - 1;
138  } else {
139  k = mid + 1;
140  }
141  }
142 
143  if (!is_true(is_zero(*e))) {
144  if (k < row_end and j_[k] == j) {
145  x_[k] = e;
146  } else { // j_[k] > j or k is the last non-zero element
147  x_.insert(x_.begin() + k, e);
148  j_.insert(j_.begin() + k, j);
149  for (unsigned l = i + 1; l <= row_; l++)
150  p_[l]++;
151  }
152  } else { // e is zero
153  if (k < row_end and j_[k] == j) { // remove existing non-zero element
154  x_.erase(x_.begin() + k);
155  j_.erase(j_.begin() + k);
156  for (unsigned l = i + 1; l <= row_; l++)
157  p_[l]--;
158  }
159  }
160 }
161 
162 unsigned CSRMatrix::rank() const
163 {
164  throw NotImplementedError("Not Implemented");
165 }
166 
167 RCP<const Basic> CSRMatrix::det() const
168 {
169  throw NotImplementedError("Not Implemented");
170 }
171 
172 void CSRMatrix::inv(MatrixBase &result) const
173 {
174  throw NotImplementedError("Not Implemented");
175 }
176 
177 void CSRMatrix::add_matrix(const MatrixBase &other, MatrixBase &result) const
178 {
179  throw NotImplementedError("Not Implemented");
180 }
181 
182 void CSRMatrix::mul_matrix(const MatrixBase &other, MatrixBase &result) const
183 {
184  throw NotImplementedError("Not Implemented");
185 }
186 
187 void CSRMatrix::elementwise_mul_matrix(const MatrixBase &other,
188  MatrixBase &result) const
189 {
190  if (is_a<CSRMatrix>(result)) {
191  auto &o = down_cast<const CSRMatrix &>(other);
192  auto &r = down_cast<CSRMatrix &>(result);
193  csr_binop_csr_canonical(*this, o, r, mul);
194  }
195 }
196 
197 // Add a scalar
198 void CSRMatrix::add_scalar(const RCP<const Basic> &k, MatrixBase &result) const
199 {
200  throw NotImplementedError("Not Implemented");
201 }
202 
203 // Multiply by a scalar
204 void CSRMatrix::mul_scalar(const RCP<const Basic> &k, MatrixBase &result) const
205 {
206  throw NotImplementedError("Not Implemented");
207 }
208 
209 // Matrix conjugate
210 void CSRMatrix::conjugate(MatrixBase &result) const
211 {
212  if (is_a<CSRMatrix>(result)) {
213  auto &r = down_cast<CSRMatrix &>(result);
214  std::vector<unsigned> p(p_), j(j_);
215  vec_basic x(x_.size());
216  for (unsigned i = 0; i < x_.size(); ++i) {
217  x[i] = SymEngine::conjugate(x_[i]);
218  }
219  r = CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
220  } else {
221  throw NotImplementedError("Not Implemented");
222  }
223 }
224 
225 // Matrix transpose
226 void CSRMatrix::transpose(MatrixBase &result) const
227 {
228  if (is_a<CSRMatrix>(result)) {
229  auto &r = down_cast<CSRMatrix &>(result);
230  r = this->transpose();
231  } else {
232  throw NotImplementedError("Not Implemented");
233  }
234 }
235 
236 CSRMatrix CSRMatrix::transpose(bool conjugate) const
237 {
238  const auto nnz = j_.size();
239  std::vector<unsigned> p(col_ + 1, 0), j(nnz), tmp(col_, 0);
240  vec_basic x(nnz);
241 
242  for (unsigned i = 0; i < nnz; ++i)
243  p[j_[i] + 1]++;
244  std::partial_sum(p.begin(), p.end(), p.begin());
245 
246  for (unsigned ri = 0; ri < row_; ++ri) {
247  for (unsigned i = p_[ri]; i < p_[ri + 1]; ++i) {
248  const auto ci = j_[i];
249  const unsigned k = p[ci] + tmp[ci];
250  j[k] = ri;
251  if (conjugate) {
252  x[k] = SymEngine::conjugate(x_[i]);
253  } else {
254  x[k] = x_[i];
255  }
256  tmp[ci]++;
257  }
258  }
259  return CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
260 }
261 
262 // Matrix conjugate transpose
263 void CSRMatrix::conjugate_transpose(MatrixBase &result) const
264 {
265  if (is_a<CSRMatrix>(result)) {
266  auto &r = down_cast<CSRMatrix &>(result);
267  r = this->transpose(true);
268  } else {
269  throw NotImplementedError("Not Implemented");
270  }
271 }
272 
273 // Extract out a submatrix
274 void CSRMatrix::submatrix(MatrixBase &result, unsigned row_start,
275  unsigned col_start, unsigned row_end,
276  unsigned col_end, unsigned row_step,
277  unsigned col_step) const
278 {
279  throw NotImplementedError("Not Implemented");
280 }
281 
282 // LU factorization
283 void CSRMatrix::LU(MatrixBase &L, MatrixBase &U) const
284 {
285  throw NotImplementedError("Not Implemented");
286 }
287 
288 // LDL factorization
289 void CSRMatrix::LDL(MatrixBase &L, MatrixBase &D) const
290 {
291  throw NotImplementedError("Not Implemented");
292 }
293 
294 // Solve Ax = b using LU factorization
295 void CSRMatrix::LU_solve(const MatrixBase &b, MatrixBase &x) const
296 {
297  throw NotImplementedError("Not Implemented");
298 }
299 
300 // Fraction free LU factorization
301 void CSRMatrix::FFLU(MatrixBase &LU) const
302 {
303  throw NotImplementedError("Not Implemented");
304 }
305 
306 // Fraction free LDU factorization
307 void CSRMatrix::FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const
308 {
309  throw NotImplementedError("Not Implemented");
310 }
311 
312 // QR factorization
313 void CSRMatrix::QR(MatrixBase &Q, MatrixBase &R) const
314 {
315  throw NotImplementedError("Not Implemented");
316 }
317 
318 // Cholesky decomposition
319 void CSRMatrix::cholesky(MatrixBase &L) const
320 {
321  throw NotImplementedError("Not Implemented");
322 }
323 
324 void CSRMatrix::csr_sum_duplicates(std::vector<unsigned> &p_,
325  std::vector<unsigned> &j_, vec_basic &x_,
326  unsigned row_)
327 {
328  unsigned nnz = 0;
329  unsigned row_end = 0;
330  unsigned jj = 0, j = 0;
331  RCP<const Basic> x = zero;
332 
333  for (unsigned i = 0; i < row_; i++) {
334  jj = row_end;
335  row_end = p_[i + 1];
336 
337  while (jj < row_end) {
338  j = j_[jj];
339  x = x_[jj];
340  jj++;
341 
342  while (jj < row_end and j_[jj] == j) {
343  x = add(x, x_[jj]);
344  jj++;
345  }
346 
347  j_[nnz] = j;
348  x_[nnz] = x;
349  nnz++;
350  }
351  p_[i + 1] = nnz;
352  }
353 
354  // Resize to discard unnecessary elements
355  j_.resize(nnz);
356  x_.resize(nnz);
357 }
358 
359 void CSRMatrix::csr_sort_indices(std::vector<unsigned> &p_,
360  std::vector<unsigned> &j_, vec_basic &x_,
361  unsigned row_)
362 {
364 
365  for (unsigned i = 0; i < row_; i++) {
366  unsigned row_start = p_[i];
367  unsigned row_end = p_[i + 1];
368 
369  temp.clear();
370 
371  for (unsigned jj = row_start; jj < row_end; jj++) {
372  temp.push_back(std::make_pair(j_[jj], x_[jj]));
373  }
374 
375  std::sort(temp.begin(), temp.end(),
376  [](const std::pair<unsigned, RCP<const Basic>> &x,
377  const std::pair<unsigned, RCP<const Basic>> &y) {
378  return x.first < y.first;
379  });
380 
381  for (unsigned jj = row_start, n = 0; jj < row_end; jj++, n++) {
382  j_[jj] = temp[n].first;
383  x_[jj] = temp[n].second;
384  }
385  }
386 }
387 
388 // Assumes that the indices are sorted
389 bool CSRMatrix::csr_has_duplicates(const std::vector<unsigned> &p_,
390  const std::vector<unsigned> &j_,
391  unsigned row_)
392 {
393  for (unsigned i = 0; i < row_; i++) {
394  for (unsigned j = p_[i]; j + 1 < p_[i + 1]; j++) {
395  if (j_[j] == j_[j + 1])
396  return true;
397  }
398  }
399  return false;
400 }
401 
402 bool CSRMatrix::csr_has_sorted_indices(const std::vector<unsigned> &p_,
403  const std::vector<unsigned> &j_,
404  unsigned row_)
405 {
406  for (unsigned i = 0; i < row_; i++) {
407  for (unsigned jj = p_[i]; jj < p_[i + 1] - 1; jj++) {
408  if (j_[jj] > j_[jj + 1])
409  return false;
410  }
411  }
412  return true;
413 }
414 
415 bool CSRMatrix::csr_has_canonical_format(const std::vector<unsigned> &p_,
416  const std::vector<unsigned> &j_,
417  unsigned row_)
418 {
419  for (unsigned i = 0; i < row_; i++) {
420  if (p_[i] > p_[i + 1])
421  return false;
422  }
423 
424  return csr_has_sorted_indices(p_, j_, row_)
425  and not csr_has_duplicates(p_, j_, row_);
426 }
427 
428 CSRMatrix CSRMatrix::from_coo(unsigned row, unsigned col,
429  const std::vector<unsigned> &i,
430  const std::vector<unsigned> &j,
431  const vec_basic &x)
432 {
433  // cast is okay, because CSRMatrix indices are unsigned.
434  unsigned nnz = numeric_cast<unsigned>(x.size());
437  vec_basic x_ = vec_basic(nnz);
438 
439  for (unsigned n = 0; n < nnz; n++) {
440  p_[i[n]]++;
441  }
442 
443  // cumsum the nnz per row to get p
444  unsigned temp;
445  for (unsigned i = 0, cumsum = 0; i < row; i++) {
446  temp = p_[i];
447  p_[i] = cumsum;
448  cumsum += temp;
449  }
450  p_[row] = nnz;
451 
452  // write j, x into j_, x_
453  unsigned row_, dest_;
454  for (unsigned n = 0; n < nnz; n++) {
455  row_ = i[n];
456  dest_ = p_[row_];
457 
458  j_[dest_] = j[n];
459  x_[dest_] = x[n];
460 
461  p_[row_]++;
462  }
463 
464  for (unsigned i = 0, last = 0; i <= row; i++) {
465  std::swap(p_[i], last);
466  }
467 
468  // sort indices
469  csr_sort_indices(p_, j_, x_, row);
470  // Remove duplicates
471  csr_sum_duplicates(p_, j_, x_, row);
472 
473  CSRMatrix B
474  = CSRMatrix(row, col, std::move(p_), std::move(j_), std::move(x_));
475  return B;
476 }
477 
478 CSRMatrix CSRMatrix::jacobian(const vec_basic &exprs, const vec_sym &x,
479  bool diff_cache)
480 {
481  const unsigned nrows = static_cast<unsigned>(exprs.size());
482  const unsigned ncols = static_cast<unsigned>(x.size());
483  std::vector<unsigned> p(1, 0), j;
484  vec_basic elems;
485  p.reserve(nrows + 1);
486  j.reserve(nrows);
487  elems.reserve(nrows);
488  for (unsigned ri = 0; ri < nrows; ++ri) {
489  p.push_back(p.back());
490  for (unsigned ci = 0; ci < ncols; ++ci) {
491  auto elem = exprs[ri]->diff(x[ci], diff_cache);
492  if (!is_true(is_zero(*elem))) {
493  p.back()++;
494  j.push_back(ci);
495  elems.emplace_back(std::move(elem));
496  }
497  }
498  }
499  return CSRMatrix(nrows, ncols, std::move(p), std::move(j),
500  std::move(elems));
501 }
502 
503 CSRMatrix CSRMatrix::jacobian(const DenseMatrix &A, const DenseMatrix &x,
504  bool diff_cache)
505 {
506  SYMENGINE_ASSERT(A.col_ == 1);
507  SYMENGINE_ASSERT(x.col_ == 1);
508  vec_sym syms;
509  syms.reserve(x.row_);
510  for (const auto &dx : x.m_) {
511  if (!is_a<Symbol>(*dx)) {
512  throw SymEngineException("'x' must contain Symbols only");
513  }
514  syms.push_back(rcp_static_cast<const Symbol>(dx));
515  }
516  return CSRMatrix::jacobian(A.m_, syms, diff_cache);
517 }
518 
519 void csr_matmat_pass1(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
520 {
521  // method that uses O(n) temp storage
522  std::vector<unsigned> mask(A.col_, -1);
523  C.p_[0] = 0;
524 
525  unsigned nnz = 0;
526  for (unsigned i = 0; i < A.row_; i++) {
527  // npy_intp row_nnz = 0;
528  unsigned row_nnz = 0;
529 
530  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++) {
531  unsigned j = A.j_[jj];
532  for (unsigned kk = B.p_[j]; kk < B.p_[j + 1]; kk++) {
533  unsigned k = B.j_[kk];
534  if (mask[k] != i) {
535  mask[k] = i;
536  row_nnz++;
537  }
538  }
539  }
540 
541  unsigned next_nnz = nnz + row_nnz;
542 
543  // Addition overflow: http://www.cplusplus.com/articles/DE18T05o/
544  if (next_nnz < nnz) {
545  throw std::overflow_error("nnz of the result is too large");
546  }
547 
548  nnz = next_nnz;
549  C.p_[i + 1] = nnz;
550  }
551 }
552 
553 // Pass 2 computes CSR entries for matrix C = A*B using the
554 // row pointer Cp[] computed in Pass 1.
555 void csr_matmat_pass2(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
556 {
557  std::vector<int> next(A.col_, -1);
558  vec_basic sums(A.col_, zero);
559 
560  unsigned nnz = 0;
561 
562  C.p_[0] = 0;
563 
564  for (unsigned i = 0; i < A.row_; i++) {
565  int head = -2;
566  unsigned length = 0;
567 
568  unsigned jj_start = A.p_[i];
569  unsigned jj_end = A.p_[i + 1];
570  for (unsigned jj = jj_start; jj < jj_end; jj++) {
571  unsigned j = A.j_[jj];
572  RCP<const Basic> v = A.x_[jj];
573 
574  unsigned kk_start = B.p_[j];
575  unsigned kk_end = B.p_[j + 1];
576  for (unsigned kk = kk_start; kk < kk_end; kk++) {
577  unsigned k = B.j_[kk];
578 
579  sums[k] = add(sums[k], mul(v, B.x_[kk]));
580 
581  if (next[k] == -1) {
582  next[k] = head;
583  head = k;
584  length++;
585  }
586  }
587  }
588 
589  for (unsigned jj = 0; jj < length; jj++) {
590 
591  if (!is_true(is_zero(*sums[head]))) {
592  C.j_[nnz] = head;
593  C.x_[nnz] = sums[head];
594  nnz++;
595  }
596 
597  unsigned temp = head;
598  head = next[head];
599 
600  next[temp] = -1; // clear arrays
601  sums[temp] = zero;
602  }
603 
604  C.p_[i + 1] = nnz;
605  }
606 }
607 
608 // Extract main diagonal of CSR matrix A
609 void csr_diagonal(const CSRMatrix &A, DenseMatrix &D)
610 {
611  unsigned N = std::min(A.row_, A.col_);
612 
613  SYMENGINE_ASSERT(D.nrows() == N and D.ncols() == 1);
614 
615  unsigned row_start;
616  unsigned row_end;
617  RCP<const Basic> diag;
618 
619  for (unsigned i = 0; i < N; i++) {
620  row_start = A.p_[i];
621  row_end = A.p_[i + 1];
622  diag = zero;
623  unsigned jj;
624 
625  while (row_start <= row_end) {
626  jj = (row_start + row_end) / 2;
627  if (A.j_[jj] == i) {
628  diag = A.x_[jj];
629  break;
630  } else if (A.j_[jj] < i) {
631  row_start = jj + 1;
632  } else {
633  row_end = jj - 1;
634  }
635  }
636 
637  D.set(i, 0, diag);
638  }
639 }
640 
641 // Scale the rows of a CSR matrix *in place*
642 // A[i, :] *= X[i]
643 void csr_scale_rows(CSRMatrix &A, const DenseMatrix &X)
644 {
645  SYMENGINE_ASSERT(A.row_ == X.nrows() and X.ncols() == 1);
646 
647  for (unsigned i = 0; i < A.row_; i++) {
648  if (is_true(is_zero(*X.get(i, 0))))
649  throw SymEngineException("Scaling factor can't be zero");
650  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++)
651  A.x_[jj] = mul(A.x_[jj], X.get(i, 0));
652  }
653 }
654 
655 // Scale the columns of a CSR matrix *in place*
656 // A[:, i] *= X[i]
657 void csr_scale_columns(CSRMatrix &A, const DenseMatrix &X)
658 {
659  SYMENGINE_ASSERT(A.col_ == X.nrows() and X.ncols() == 1);
660 
661  const unsigned nnz = A.p_[A.row_];
662  unsigned i;
663 
664  for (i = 0; i < A.col_; i++) {
665  if (is_true(is_zero(*X.get(i, 0))))
666  throw SymEngineException("Scaling factor can't be zero");
667  }
668 
669  for (i = 0; i < nnz; i++)
670  A.x_[i] = mul(A.x_[i], X.get(A.j_[i], 0));
671 }
672 
673 // Compute C = A (binary_op) B for CSR matrices that are in the
674 // canonical CSR format. Matrix dimensions of A and B should be the
675 // same. C will be in canonical format as well.
676 void csr_binop_csr_canonical(
677  const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C,
678  RCP<const Basic> (&bin_op)(const RCP<const Basic> &,
679  const RCP<const Basic> &))
680 {
681  SYMENGINE_ASSERT(A.row_ == B.row_ and A.col_ == B.col_ and C.row_ == A.row_
682  and C.col_ == A.col_);
683 
684  // Method that works for canonical CSR matrices
685  C.p_[0] = 0;
686  C.j_.clear();
687  C.x_.clear();
688  unsigned nnz = 0;
689  unsigned A_pos, B_pos, A_end, B_end;
690 
691  for (unsigned i = 0; i < A.row_; i++) {
692  A_pos = A.p_[i];
693  B_pos = B.p_[i];
694  A_end = A.p_[i + 1];
695  B_end = B.p_[i + 1];
696 
697  // while not finished with either row
698  while (A_pos < A_end and B_pos < B_end) {
699  unsigned A_j = A.j_[A_pos];
700  unsigned B_j = B.j_[B_pos];
701 
702  if (A_j == B_j) {
703  RCP<const Basic> result = bin_op(A.x_[A_pos], B.x_[B_pos]);
704  if (!is_true(is_zero(*result))) {
705  C.j_.push_back(A_j);
706  C.x_.push_back(result);
707  nnz++;
708  }
709  A_pos++;
710  B_pos++;
711  } else if (A_j < B_j) {
712  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
713  if (!is_true(is_zero(*result))) {
714  C.j_.push_back(A_j);
715  C.x_.push_back(result);
716  nnz++;
717  }
718  A_pos++;
719  } else {
720  // B_j < A_j
721  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
722  if (!is_true(is_zero(*result))) {
723  C.j_.push_back(B_j);
724  C.x_.push_back(result);
725  nnz++;
726  }
727  B_pos++;
728  }
729  }
730 
731  // tail
732  while (A_pos < A_end) {
733  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
734  if (!is_true(is_zero(*result))) {
735  C.j_.push_back(A.j_[A_pos]);
736  C.x_.push_back(result);
737  nnz++;
738  }
739  A_pos++;
740  }
741  while (B_pos < B_end) {
742  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
743  if (!is_true(is_zero(*result))) {
744  C.j_.push_back(B.j_[B_pos]);
745  C.x_.push_back(result);
746  nnz++;
747  }
748  B_pos++;
749  }
750 
751  C.p_[i + 1] = nnz;
752  }
753 
754  // It's enough to check for duplicates as the column indices
755  // remain sorted after the above operations
756  if (CSRMatrix::csr_has_duplicates(C.p_, C.j_, A.row_))
757  CSRMatrix::csr_sum_duplicates(C.p_, C.j_, C.x_, A.row_);
758 }
759 
760 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
T back(T... args)
T begin(T... args)
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
T clear(T... args)
T end(T... args)
T make_pair(T... args)
T make_tuple(T... args)
T min(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:347
tribool is_zero(const Basic &b, const Assumptions *assumptions=nullptr)
Check if a number is zero.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.
Definition: functions.cpp:106
STL namespace.
T next(T... args)
T partial_sum(T... args)
T push_back(T... args)
T reserve(T... args)
T resize(T... args)
T sort(T... args)
T swap(T... args)