sparse_matrix.cpp
#include <algorithm>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#include <symengine/matrix.h>
#include <symengine/add.h>
#include <symengine/functions.h>
#include <symengine/mul.h>
#include <symengine/constants.h>
#include <symengine/symengine_exception.h>
#include <symengine/visitor.h>
#include <symengine/test_visitors.h>
10 
11 namespace SymEngine
12 {
13 // ----------------------------- CSRMatrix ------------------------------------
14 CSRMatrix::CSRMatrix() {}
15 
16 CSRMatrix::CSRMatrix(unsigned row, unsigned col) : row_(row), col_(col)
17 {
18  p_ = std::vector<unsigned>(row + 1, 0);
19  SYMENGINE_ASSERT(is_canonical());
20 }
21 
22 CSRMatrix::CSRMatrix(unsigned row, unsigned col, const std::vector<unsigned> &p,
23  const std::vector<unsigned> &j, const vec_basic &x)
24  : p_{p}, j_{j}, x_{x}, row_(row), col_(col)
25 {
26  SYMENGINE_ASSERT(is_canonical());
27 }
28 
29 CSRMatrix::CSRMatrix(unsigned row, unsigned col, std::vector<unsigned> &&p,
30  std::vector<unsigned> &&j, vec_basic &&x)
31  : p_{std::move(p)}, j_{std::move(j)}, x_{std::move(x)}, row_(row), col_(col)
32 {
33  SYMENGINE_ASSERT(is_canonical());
34 }
35 
36 CSRMatrix &CSRMatrix::operator=(CSRMatrix &&other)
37 {
38  col_ = other.col_;
39  row_ = other.row_;
40  p_ = std::move(other.p_);
41  j_ = std::move(other.j_);
42  x_ = std::move(other.x_);
43  return *this;
44 }
45 
47 CSRMatrix::as_vectors() const
48 {
49  auto p = p_, j = j_;
50  auto x = x_;
51  return std::make_tuple(std::move(p), std::move(j), std::move(x));
52 }
53 
54 bool CSRMatrix::eq(const MatrixBase &other) const
55 {
56  unsigned row = this->nrows();
57  if (row != other.nrows() or this->ncols() != other.ncols())
58  return false;
59 
60  if (is_a<CSRMatrix>(other)) {
61  const CSRMatrix &o = down_cast<const CSRMatrix &>(other);
62 
63  if (this->p_[row] != o.p_[row])
64  return false;
65 
66  for (unsigned i = 0; i <= row; i++)
67  if (this->p_[i] != o.p_[i])
68  return false;
69 
70  for (unsigned i = 0; i < this->p_[row]; i++)
71  if ((this->j_[i] != o.j_[i]) or neq(*this->x_[i], *(o.x_[i])))
72  return false;
73 
74  return true;
75  } else {
76  return this->MatrixBase::eq(other);
77  }
78 }
79 
80 bool CSRMatrix::is_canonical() const
81 {
82  if (p_.size() != row_ + 1 or j_.size() != p_[row_] or x_.size() != p_[row_])
83  return false;
84 
85  if (p_[row_] != 0) // Zero matrix is in canonical format
86  return csr_has_canonical_format(p_, j_, row_);
87  return true;
88 }
89 
90 // Get and set elements
91 RCP<const Basic> CSRMatrix::get(unsigned i, unsigned j) const
92 {
93  SYMENGINE_ASSERT(i < row_ and j < col_);
94 
95  unsigned row_start = p_[i];
96  unsigned row_end = p_[i + 1];
97  unsigned k;
98 
99  if (row_start == row_end) {
100  return zero;
101  }
102 
103  while (row_start < row_end) {
104  k = (row_start + row_end) / 2;
105  if (j_[k] == j) {
106  return x_[k];
107  } else if (j_[k] < j) {
108  row_start = k + 1;
109  } else {
110  row_end = k;
111  }
112  }
113 
114  return zero;
115 }
116 
117 void CSRMatrix::set(unsigned i, unsigned j, const RCP<const Basic> &e)
118 {
119  SYMENGINE_ASSERT(i < row_ and j < col_);
120 
121  unsigned k = p_[i];
122  unsigned row_end = p_[i + 1];
123  unsigned end = p_[i + 1];
124  unsigned mid;
125 
126  while (k < end) {
127  mid = (k + end) / 2;
128  if (mid == k) {
129  if (j_[k] < j) {
130  k++;
131  }
132  break;
133  } else if (j_[mid] >= j and j_[mid - 1] < j) {
134  k = mid;
135  break;
136  } else if (j_[mid - 1] >= j) {
137  end = mid - 1;
138  } else {
139  k = mid + 1;
140  }
141  }
142 
143  if (!is_true(is_zero(*e))) {
144  if (k < row_end and j_[k] == j) {
145  x_[k] = e;
146  } else { // j_[k] > j or k is the last non-zero element
147  x_.insert(x_.begin() + k, e);
148  j_.insert(j_.begin() + k, j);
149  for (unsigned l = i + 1; l <= row_; l++)
150  p_[l]++;
151  }
152  } else { // e is zero
153  if (k < row_end and j_[k] == j) { // remove existing non-zero element
154  x_.erase(x_.begin() + k);
155  j_.erase(j_.begin() + k);
156  for (unsigned l = i + 1; l <= row_; l++)
157  p_[l]--;
158  }
159  }
160 }
161 
162 tribool CSRMatrix::is_real(const Assumptions *assumptions) const
163 {
164  RealVisitor visitor(assumptions);
165  tribool cur = tribool::tritrue;
166  for (auto &e : x_) {
167  cur = and_tribool(cur, visitor.apply(*e));
168  if (is_false(cur)) {
169  return cur;
170  }
171  }
172  return cur;
173 }
174 
175 unsigned CSRMatrix::rank() const
176 {
177  throw NotImplementedError("Not Implemented");
178 }
179 
180 RCP<const Basic> CSRMatrix::det() const
181 {
182  throw NotImplementedError("Not Implemented");
183 }
184 
185 void CSRMatrix::inv(MatrixBase &result) const
186 {
187  throw NotImplementedError("Not Implemented");
188 }
189 
190 void CSRMatrix::add_matrix(const MatrixBase &other, MatrixBase &result) const
191 {
192  throw NotImplementedError("Not Implemented");
193 }
194 
195 void CSRMatrix::mul_matrix(const MatrixBase &other, MatrixBase &result) const
196 {
197  throw NotImplementedError("Not Implemented");
198 }
199 
200 void CSRMatrix::elementwise_mul_matrix(const MatrixBase &other,
201  MatrixBase &result) const
202 {
203  if (is_a<CSRMatrix>(result)) {
204  auto &o = down_cast<const CSRMatrix &>(other);
205  auto &r = down_cast<CSRMatrix &>(result);
206  csr_binop_csr_canonical(*this, o, r, mul);
207  }
208 }
209 
210 // Add a scalar
211 void CSRMatrix::add_scalar(const RCP<const Basic> &k, MatrixBase &result) const
212 {
213  throw NotImplementedError("Not Implemented");
214 }
215 
216 // Multiply by a scalar
217 void CSRMatrix::mul_scalar(const RCP<const Basic> &k, MatrixBase &result) const
218 {
219  throw NotImplementedError("Not Implemented");
220 }
221 
222 // Matrix conjugate
223 void CSRMatrix::conjugate(MatrixBase &result) const
224 {
225  if (is_a<CSRMatrix>(result)) {
226  auto &r = down_cast<CSRMatrix &>(result);
227  std::vector<unsigned> p(p_), j(j_);
228  vec_basic x(x_.size());
229  for (unsigned i = 0; i < x_.size(); ++i) {
230  x[i] = SymEngine::conjugate(x_[i]);
231  }
232  r = CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
233  } else {
234  throw NotImplementedError("Not Implemented");
235  }
236 }
237 
238 // Matrix transpose
239 void CSRMatrix::transpose(MatrixBase &result) const
240 {
241  if (is_a<CSRMatrix>(result)) {
242  auto &r = down_cast<CSRMatrix &>(result);
243  r = this->transpose();
244  } else {
245  throw NotImplementedError("Not Implemented");
246  }
247 }
248 
249 CSRMatrix CSRMatrix::transpose(bool conjugate) const
250 {
251  const auto nnz = j_.size();
252  std::vector<unsigned> p(col_ + 1, 0), j(nnz), tmp(col_, 0);
253  vec_basic x(nnz);
254 
255  for (unsigned i = 0; i < nnz; ++i)
256  p[j_[i] + 1]++;
257  std::partial_sum(p.begin(), p.end(), p.begin());
258 
259  for (unsigned ri = 0; ri < row_; ++ri) {
260  for (unsigned i = p_[ri]; i < p_[ri + 1]; ++i) {
261  const auto ci = j_[i];
262  const unsigned k = p[ci] + tmp[ci];
263  j[k] = ri;
264  if (conjugate) {
265  x[k] = SymEngine::conjugate(x_[i]);
266  } else {
267  x[k] = x_[i];
268  }
269  tmp[ci]++;
270  }
271  }
272  return CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
273 }
274 
275 // Matrix conjugate transpose
276 void CSRMatrix::conjugate_transpose(MatrixBase &result) const
277 {
278  if (is_a<CSRMatrix>(result)) {
279  auto &r = down_cast<CSRMatrix &>(result);
280  r = this->transpose(true);
281  } else {
282  throw NotImplementedError("Not Implemented");
283  }
284 }
285 
286 // Extract out a submatrix
287 void CSRMatrix::submatrix(MatrixBase &result, unsigned row_start,
288  unsigned col_start, unsigned row_end,
289  unsigned col_end, unsigned row_step,
290  unsigned col_step) const
291 {
292  throw NotImplementedError("Not Implemented");
293 }
294 
295 // LU factorization
296 void CSRMatrix::LU(MatrixBase &L, MatrixBase &U) const
297 {
298  throw NotImplementedError("Not Implemented");
299 }
300 
301 // LDL factorization
302 void CSRMatrix::LDL(MatrixBase &L, MatrixBase &D) const
303 {
304  throw NotImplementedError("Not Implemented");
305 }
306 
307 // Solve Ax = b using LU factorization
308 void CSRMatrix::LU_solve(const MatrixBase &b, MatrixBase &x) const
309 {
310  throw NotImplementedError("Not Implemented");
311 }
312 
313 // Fraction free LU factorization
314 void CSRMatrix::FFLU(MatrixBase &LU) const
315 {
316  throw NotImplementedError("Not Implemented");
317 }
318 
319 // Fraction free LDU factorization
320 void CSRMatrix::FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const
321 {
322  throw NotImplementedError("Not Implemented");
323 }
324 
325 // QR factorization
326 void CSRMatrix::QR(MatrixBase &Q, MatrixBase &R) const
327 {
328  throw NotImplementedError("Not Implemented");
329 }
330 
331 // Cholesky decomposition
332 void CSRMatrix::cholesky(MatrixBase &L) const
333 {
334  throw NotImplementedError("Not Implemented");
335 }
336 
337 void CSRMatrix::csr_sum_duplicates(std::vector<unsigned> &p_,
338  std::vector<unsigned> &j_, vec_basic &x_,
339  unsigned row_)
340 {
341  unsigned nnz = 0;
342  unsigned row_end = 0;
343  unsigned jj = 0, j = 0;
344  RCP<const Basic> x = zero;
345 
346  for (unsigned i = 0; i < row_; i++) {
347  jj = row_end;
348  row_end = p_[i + 1];
349 
350  while (jj < row_end) {
351  j = j_[jj];
352  x = x_[jj];
353  jj++;
354 
355  while (jj < row_end and j_[jj] == j) {
356  x = add(x, x_[jj]);
357  jj++;
358  }
359 
360  j_[nnz] = j;
361  x_[nnz] = x;
362  nnz++;
363  }
364  p_[i + 1] = nnz;
365  }
366 
367  // Resize to discard unnecessary elements
368  j_.resize(nnz);
369  x_.resize(nnz);
370 }
371 
372 void CSRMatrix::csr_sort_indices(std::vector<unsigned> &p_,
373  std::vector<unsigned> &j_, vec_basic &x_,
374  unsigned row_)
375 {
377 
378  for (unsigned i = 0; i < row_; i++) {
379  unsigned row_start = p_[i];
380  unsigned row_end = p_[i + 1];
381 
382  temp.clear();
383 
384  for (unsigned jj = row_start; jj < row_end; jj++) {
385  temp.push_back(std::make_pair(j_[jj], x_[jj]));
386  }
387 
388  std::sort(temp.begin(), temp.end(),
389  [](const std::pair<unsigned, RCP<const Basic>> &x,
390  const std::pair<unsigned, RCP<const Basic>> &y) {
391  return x.first < y.first;
392  });
393 
394  for (unsigned jj = row_start, n = 0; jj < row_end; jj++, n++) {
395  j_[jj] = temp[n].first;
396  x_[jj] = temp[n].second;
397  }
398  }
399 }
400 
401 // Assumes that the indices are sorted
402 bool CSRMatrix::csr_has_duplicates(const std::vector<unsigned> &p_,
403  const std::vector<unsigned> &j_,
404  unsigned row_)
405 {
406  for (unsigned i = 0; i < row_; i++) {
407  for (unsigned j = p_[i]; j + 1 < p_[i + 1]; j++) {
408  if (j_[j] == j_[j + 1])
409  return true;
410  }
411  }
412  return false;
413 }
414 
415 bool CSRMatrix::csr_has_sorted_indices(const std::vector<unsigned> &p_,
416  const std::vector<unsigned> &j_,
417  unsigned row_)
418 {
419  for (unsigned i = 0; i < row_; i++) {
420  for (unsigned jj = p_[i]; jj < p_[i + 1] - 1; jj++) {
421  if (j_[jj] > j_[jj + 1])
422  return false;
423  }
424  }
425  return true;
426 }
427 
428 bool CSRMatrix::csr_has_canonical_format(const std::vector<unsigned> &p_,
429  const std::vector<unsigned> &j_,
430  unsigned row_)
431 {
432  for (unsigned i = 0; i < row_; i++) {
433  if (p_[i] > p_[i + 1])
434  return false;
435  }
436 
437  return csr_has_sorted_indices(p_, j_, row_)
438  and not csr_has_duplicates(p_, j_, row_);
439 }
440 
441 CSRMatrix CSRMatrix::from_coo(unsigned row, unsigned col,
442  const std::vector<unsigned> &i,
443  const std::vector<unsigned> &j,
444  const vec_basic &x)
445 {
446  // cast is okay, because CSRMatrix indices are unsigned.
447  unsigned nnz = numeric_cast<unsigned>(x.size());
450  vec_basic x_ = vec_basic(nnz);
451 
452  for (unsigned n = 0; n < nnz; n++) {
453  p_[i[n]]++;
454  }
455 
456  // cumsum the nnz per row to get p
457  unsigned temp;
458  for (unsigned i = 0, cumsum = 0; i < row; i++) {
459  temp = p_[i];
460  p_[i] = cumsum;
461  cumsum += temp;
462  }
463  p_[row] = nnz;
464 
465  // write j, x into j_, x_
466  unsigned row_, dest_;
467  for (unsigned n = 0; n < nnz; n++) {
468  row_ = i[n];
469  dest_ = p_[row_];
470 
471  j_[dest_] = j[n];
472  x_[dest_] = x[n];
473 
474  p_[row_]++;
475  }
476 
477  for (unsigned i = 0, last = 0; i <= row; i++) {
478  std::swap(p_[i], last);
479  }
480 
481  // sort indices
482  csr_sort_indices(p_, j_, x_, row);
483  // Remove duplicates
484  csr_sum_duplicates(p_, j_, x_, row);
485 
486  CSRMatrix B
487  = CSRMatrix(row, col, std::move(p_), std::move(j_), std::move(x_));
488  return B;
489 }
490 
491 CSRMatrix CSRMatrix::jacobian(const vec_basic &exprs, const vec_sym &x,
492  bool diff_cache)
493 {
494  const unsigned nrows = static_cast<unsigned>(exprs.size());
495  const unsigned ncols = static_cast<unsigned>(x.size());
496  std::vector<unsigned> p(1, 0), j;
497  vec_basic elems;
498  p.reserve(nrows + 1);
499  j.reserve(nrows);
500  elems.reserve(nrows);
501  for (unsigned ri = 0; ri < nrows; ++ri) {
502  p.push_back(p.back());
503  for (unsigned ci = 0; ci < ncols; ++ci) {
504  auto elem = exprs[ri]->diff(x[ci], diff_cache);
505  if (!is_true(is_zero(*elem))) {
506  p.back()++;
507  j.push_back(ci);
508  elems.emplace_back(std::move(elem));
509  }
510  }
511  }
512  return CSRMatrix(nrows, ncols, std::move(p), std::move(j),
513  std::move(elems));
514 }
515 
516 CSRMatrix CSRMatrix::jacobian(const DenseMatrix &A, const DenseMatrix &x,
517  bool diff_cache)
518 {
519  SYMENGINE_ASSERT(A.col_ == 1);
520  SYMENGINE_ASSERT(x.col_ == 1);
521  vec_sym syms;
522  syms.reserve(x.row_);
523  for (const auto &dx : x.m_) {
524  if (!is_a<Symbol>(*dx)) {
525  throw SymEngineException("'x' must contain Symbols only");
526  }
527  syms.push_back(rcp_static_cast<const Symbol>(dx));
528  }
529  return CSRMatrix::jacobian(A.m_, syms, diff_cache);
530 }
531 
532 void csr_matmat_pass1(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
533 {
534  // method that uses O(n) temp storage
535  std::vector<unsigned> mask(A.col_, -1);
536  C.p_[0] = 0;
537 
538  unsigned nnz = 0;
539  for (unsigned i = 0; i < A.row_; i++) {
540  // npy_intp row_nnz = 0;
541  unsigned row_nnz = 0;
542 
543  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++) {
544  unsigned j = A.j_[jj];
545  for (unsigned kk = B.p_[j]; kk < B.p_[j + 1]; kk++) {
546  unsigned k = B.j_[kk];
547  if (mask[k] != i) {
548  mask[k] = i;
549  row_nnz++;
550  }
551  }
552  }
553 
554  unsigned next_nnz = nnz + row_nnz;
555 
556  // Addition overflow: http://www.cplusplus.com/articles/DE18T05o/
557  if (next_nnz < nnz) {
558  throw std::overflow_error("nnz of the result is too large");
559  }
560 
561  nnz = next_nnz;
562  C.p_[i + 1] = nnz;
563  }
564 }
565 
566 // Pass 2 computes CSR entries for matrix C = A*B using the
567 // row pointer Cp[] computed in Pass 1.
568 void csr_matmat_pass2(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
569 {
570  std::vector<int> next(A.col_, -1);
571  vec_basic sums(A.col_, zero);
572 
573  unsigned nnz = 0;
574 
575  C.p_[0] = 0;
576 
577  for (unsigned i = 0; i < A.row_; i++) {
578  int head = -2;
579  unsigned length = 0;
580 
581  unsigned jj_start = A.p_[i];
582  unsigned jj_end = A.p_[i + 1];
583  for (unsigned jj = jj_start; jj < jj_end; jj++) {
584  unsigned j = A.j_[jj];
585  RCP<const Basic> v = A.x_[jj];
586 
587  unsigned kk_start = B.p_[j];
588  unsigned kk_end = B.p_[j + 1];
589  for (unsigned kk = kk_start; kk < kk_end; kk++) {
590  unsigned k = B.j_[kk];
591 
592  sums[k] = add(sums[k], mul(v, B.x_[kk]));
593 
594  if (next[k] == -1) {
595  next[k] = head;
596  head = k;
597  length++;
598  }
599  }
600  }
601 
602  for (unsigned jj = 0; jj < length; jj++) {
603 
604  if (!is_true(is_zero(*sums[head]))) {
605  C.j_[nnz] = head;
606  C.x_[nnz] = sums[head];
607  nnz++;
608  }
609 
610  unsigned temp = head;
611  head = next[head];
612 
613  next[temp] = -1; // clear arrays
614  sums[temp] = zero;
615  }
616 
617  C.p_[i + 1] = nnz;
618  }
619 }
620 
621 // Extract main diagonal of CSR matrix A
622 void csr_diagonal(const CSRMatrix &A, DenseMatrix &D)
623 {
624  unsigned N = std::min(A.row_, A.col_);
625 
626  SYMENGINE_ASSERT(D.nrows() == N and D.ncols() == 1);
627 
628  unsigned row_start;
629  unsigned row_end;
630  RCP<const Basic> diag;
631 
632  for (unsigned i = 0; i < N; i++) {
633  row_start = A.p_[i];
634  row_end = A.p_[i + 1];
635  diag = zero;
636  unsigned jj;
637 
638  while (row_start <= row_end) {
639  jj = (row_start + row_end) / 2;
640  if (A.j_[jj] == i) {
641  diag = A.x_[jj];
642  break;
643  } else if (A.j_[jj] < i) {
644  row_start = jj + 1;
645  } else {
646  row_end = jj - 1;
647  }
648  }
649 
650  D.set(i, 0, diag);
651  }
652 }
653 
654 // Scale the rows of a CSR matrix *in place*
655 // A[i, :] *= X[i]
656 void csr_scale_rows(CSRMatrix &A, const DenseMatrix &X)
657 {
658  SYMENGINE_ASSERT(A.row_ == X.nrows() and X.ncols() == 1);
659 
660  for (unsigned i = 0; i < A.row_; i++) {
661  if (is_true(is_zero(*X.get(i, 0))))
662  throw SymEngineException("Scaling factor can't be zero");
663  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++)
664  A.x_[jj] = mul(A.x_[jj], X.get(i, 0));
665  }
666 }
667 
668 // Scale the columns of a CSR matrix *in place*
669 // A[:, i] *= X[i]
670 void csr_scale_columns(CSRMatrix &A, const DenseMatrix &X)
671 {
672  SYMENGINE_ASSERT(A.col_ == X.nrows() and X.ncols() == 1);
673 
674  const unsigned nnz = A.p_[A.row_];
675  unsigned i;
676 
677  for (i = 0; i < A.col_; i++) {
678  if (is_true(is_zero(*X.get(i, 0))))
679  throw SymEngineException("Scaling factor can't be zero");
680  }
681 
682  for (i = 0; i < nnz; i++)
683  A.x_[i] = mul(A.x_[i], X.get(A.j_[i], 0));
684 }
685 
686 // Compute C = A (binary_op) B for CSR matrices that are in the
687 // canonical CSR format. Matrix dimensions of A and B should be the
688 // same. C will be in canonical format as well.
689 void csr_binop_csr_canonical(
690  const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C,
691  RCP<const Basic> (&bin_op)(const RCP<const Basic> &,
692  const RCP<const Basic> &))
693 {
694  SYMENGINE_ASSERT(A.row_ == B.row_ and A.col_ == B.col_ and C.row_ == A.row_
695  and C.col_ == A.col_);
696 
697  // Method that works for canonical CSR matrices
698  C.p_[0] = 0;
699  C.j_.clear();
700  C.x_.clear();
701  unsigned nnz = 0;
702  unsigned A_pos, B_pos, A_end, B_end;
703 
704  for (unsigned i = 0; i < A.row_; i++) {
705  A_pos = A.p_[i];
706  B_pos = B.p_[i];
707  A_end = A.p_[i + 1];
708  B_end = B.p_[i + 1];
709 
710  // while not finished with either row
711  while (A_pos < A_end and B_pos < B_end) {
712  unsigned A_j = A.j_[A_pos];
713  unsigned B_j = B.j_[B_pos];
714 
715  if (A_j == B_j) {
716  RCP<const Basic> result = bin_op(A.x_[A_pos], B.x_[B_pos]);
717  if (!is_true(is_zero(*result))) {
718  C.j_.push_back(A_j);
719  C.x_.push_back(result);
720  nnz++;
721  }
722  A_pos++;
723  B_pos++;
724  } else if (A_j < B_j) {
725  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
726  if (!is_true(is_zero(*result))) {
727  C.j_.push_back(A_j);
728  C.x_.push_back(result);
729  nnz++;
730  }
731  A_pos++;
732  } else {
733  // B_j < A_j
734  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
735  if (!is_true(is_zero(*result))) {
736  C.j_.push_back(B_j);
737  C.x_.push_back(result);
738  nnz++;
739  }
740  B_pos++;
741  }
742  }
743 
744  // tail
745  while (A_pos < A_end) {
746  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
747  if (!is_true(is_zero(*result))) {
748  C.j_.push_back(A.j_[A_pos]);
749  C.x_.push_back(result);
750  nnz++;
751  }
752  A_pos++;
753  }
754  while (B_pos < B_end) {
755  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
756  if (!is_true(is_zero(*result))) {
757  C.j_.push_back(B.j_[B_pos]);
758  C.x_.push_back(result);
759  nnz++;
760  }
761  B_pos++;
762  }
763 
764  C.p_[i + 1] = nnz;
765  }
766 
767  // It's enough to check for duplicates as the column indices
768  // remain sorted after the above operations
769  if (CSRMatrix::csr_has_duplicates(C.p_, C.j_, A.row_))
770  CSRMatrix::csr_sum_duplicates(C.p_, C.j_, C.x_, A.row_);
771 }
772 
773 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
T back(T... args)
T begin(T... args)
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
T clear(T... args)
T end(T... args)
T make_pair(T... args)
T make_tuple(T... args)
T min(T... args)
T move(T... args)
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
tribool is_zero(const Basic &b, const Assumptions *assumptions=nullptr)
Check if a number is zero.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.
Definition: functions.cpp:149
STL namespace.
T next(T... args)
T partial_sum(T... args)
T push_back(T... args)
T reserve(T... args)
T resize(T... args)
T sort(T... args)
T swap(T... args)