sparse_matrix.cpp
1 #include <numeric>
2 #include <symengine/matrix.h>
3 #include <symengine/add.h>
4 #include <symengine/functions.h>
5 #include <symengine/mul.h>
6 #include <symengine/constants.h>
7 #include <symengine/symengine_exception.h>
8 #include <symengine/visitor.h>
9 #include <symengine/test_visitors.h>
10 
11 namespace SymEngine
12 {
13 // ----------------------------- CSRMatrix ------------------------------------
14 CSRMatrix::CSRMatrix() {}
15 
16 CSRMatrix::CSRMatrix(unsigned row, unsigned col) : row_(row), col_(col)
17 {
18  p_ = std::vector<unsigned>(row + 1, 0);
19  SYMENGINE_ASSERT(is_canonical());
20 }
21 
22 CSRMatrix::CSRMatrix(unsigned row, unsigned col, const std::vector<unsigned> &p,
23  const std::vector<unsigned> &j, const vec_basic &x)
24  : p_{p}, j_{j}, x_{x}, row_(row), col_(col)
25 {
26  SYMENGINE_ASSERT(is_canonical());
27 }
28 
29 CSRMatrix::CSRMatrix(unsigned row, unsigned col, std::vector<unsigned> &&p,
30  std::vector<unsigned> &&j, vec_basic &&x)
31  : p_{std::move(p)}, j_{std::move(j)}, x_{std::move(x)}, row_(row), col_(col)
32 {
33  SYMENGINE_ASSERT(is_canonical());
34 }
35 
36 CSRMatrix &CSRMatrix::operator=(CSRMatrix &&other)
37 {
38  col_ = other.col_;
39  row_ = other.row_;
40  p_ = std::move(other.p_);
41  j_ = std::move(other.j_);
42  x_ = std::move(other.x_);
43  return *this;
44 }
45 
46 std::tuple<std::vector<unsigned>, std::vector<unsigned>, vec_basic>
47 CSRMatrix::as_vectors() const
48 {
49  auto p = p_, j = j_;
50  auto x = x_;
51  return std::make_tuple(std::move(p), std::move(j), std::move(x));
52 }
53 
54 bool CSRMatrix::eq(const MatrixBase &other) const
55 {
56  unsigned row = this->nrows();
57  if (row != other.nrows() or this->ncols() != other.ncols())
58  return false;
59 
60  if (is_a<CSRMatrix>(other)) {
61  const CSRMatrix &o = down_cast<const CSRMatrix &>(other);
62 
63  if (this->p_[row] != o.p_[row])
64  return false;
65 
66  for (unsigned i = 0; i <= row; i++)
67  if (this->p_[i] != o.p_[i])
68  return false;
69 
70  for (unsigned i = 0; i < this->p_[row]; i++)
71  if ((this->j_[i] != o.j_[i]) or neq(*this->x_[i], *(o.x_[i])))
72  return false;
73 
74  return true;
75  } else {
76  return this->MatrixBase::eq(other);
77  }
78 }
79 
80 bool CSRMatrix::is_canonical() const
81 {
82  if (p_.size() != row_ + 1 or j_.size() != p_[row_] or x_.size() != p_[row_])
83  return false;
84 
85  if (p_[row_] != 0) // Zero matrix is in canonical format
86  return csr_has_canonical_format(p_, j_, row_);
87  return true;
88 }
89 
90 // Get and set elements
91 RCP<const Basic> CSRMatrix::get(unsigned i, unsigned j) const
92 {
93  SYMENGINE_ASSERT(i < row_ and j < col_);
94 
95  unsigned row_start = p_[i];
96  unsigned row_end = p_[i + 1];
97  unsigned k;
98 
99  if (row_start == row_end) {
100  return zero;
101  }
102 
103  while (row_start < row_end) {
104  k = (row_start + row_end) / 2;
105  if (j_[k] == j) {
106  return x_[k];
107  } else if (j_[k] < j) {
108  row_start = k + 1;
109  } else {
110  row_end = k;
111  }
112  }
113 
114  return zero;
115 }
116 
117 void CSRMatrix::set(unsigned i, unsigned j, const RCP<const Basic> &e)
118 {
119  SYMENGINE_ASSERT(i < row_ and j < col_);
120 
121  unsigned k = p_[i];
122  unsigned row_end = p_[i + 1];
123  unsigned end = p_[i + 1];
124  unsigned mid;
125 
126  while (k < end) {
127  mid = (k + end) / 2;
128  if (mid == k) {
129  if (j_[k] < j) {
130  k++;
131  }
132  break;
133  }
134  SYMENGINE_ASSERT(mid > 0);
135  if (j_[mid] >= j and j_[mid - 1] < j) {
136  k = mid;
137  break;
138  } else if (j_[mid - 1] >= j) {
139  end = mid - 1;
140  } else {
141  k = mid + 1;
142  }
143  }
144 
145  if (!is_true(is_zero(*e))) {
146  if (k < row_end and j_[k] == j) {
147  x_[k] = e;
148  } else { // j_[k] > j or k is the last non-zero element
149  x_.insert(x_.begin() + k, e);
150  j_.insert(j_.begin() + k, j);
151  for (unsigned l = i + 1; l <= row_; l++)
152  p_[l]++;
153  }
154  } else { // e is zero
155  if (k < row_end and j_[k] == j) { // remove existing non-zero element
156  x_.erase(x_.begin() + k);
157  j_.erase(j_.begin() + k);
158  for (unsigned l = i + 1; l <= row_; l++)
159  p_[l]--;
160  }
161  }
162 }
163 
164 tribool CSRMatrix::is_real(const Assumptions *assumptions) const
165 {
166  RealVisitor visitor(assumptions);
167  tribool cur = tribool::tritrue;
168  for (auto &e : x_) {
169  cur = and_tribool(cur, visitor.apply(*e));
170  if (is_false(cur)) {
171  return cur;
172  }
173  }
174  return cur;
175 }
176 
177 unsigned CSRMatrix::rank() const
178 {
179  throw NotImplementedError("Not Implemented");
180 }
181 
182 RCP<const Basic> CSRMatrix::det() const
183 {
184  throw NotImplementedError("Not Implemented");
185 }
186 
187 void CSRMatrix::inv(MatrixBase &result) const
188 {
189  throw NotImplementedError("Not Implemented");
190 }
191 
192 void CSRMatrix::add_matrix(const MatrixBase &other, MatrixBase &result) const
193 {
194  throw NotImplementedError("Not Implemented");
195 }
196 
197 void CSRMatrix::mul_matrix(const MatrixBase &other, MatrixBase &result) const
198 {
199  throw NotImplementedError("Not Implemented");
200 }
201 
202 void CSRMatrix::elementwise_mul_matrix(const MatrixBase &other,
203  MatrixBase &result) const
204 {
205  if (is_a<CSRMatrix>(result)) {
206  auto &o = down_cast<const CSRMatrix &>(other);
207  auto &r = down_cast<CSRMatrix &>(result);
208  csr_binop_csr_canonical(*this, o, r, mul);
209  }
210 }
211 
212 // Add a scalar
213 void CSRMatrix::add_scalar(const RCP<const Basic> &k, MatrixBase &result) const
214 {
215  throw NotImplementedError("Not Implemented");
216 }
217 
218 // Multiply by a scalar
219 void CSRMatrix::mul_scalar(const RCP<const Basic> &k, MatrixBase &result) const
220 {
221  throw NotImplementedError("Not Implemented");
222 }
223 
224 // Matrix conjugate
225 void CSRMatrix::conjugate(MatrixBase &result) const
226 {
227  if (is_a<CSRMatrix>(result)) {
228  auto &r = down_cast<CSRMatrix &>(result);
229  std::vector<unsigned> p(p_), j(j_);
230  vec_basic x(x_.size());
231  for (unsigned i = 0; i < x_.size(); ++i) {
232  x[i] = SymEngine::conjugate(x_[i]);
233  }
234  r = CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
235  } else {
236  throw NotImplementedError("Not Implemented");
237  }
238 }
239 
240 // Matrix transpose
241 void CSRMatrix::transpose(MatrixBase &result) const
242 {
243  if (is_a<CSRMatrix>(result)) {
244  auto &r = down_cast<CSRMatrix &>(result);
245  r = this->transpose();
246  } else {
247  throw NotImplementedError("Not Implemented");
248  }
249 }
250 
251 CSRMatrix CSRMatrix::transpose(bool conjugate) const
252 {
253  const auto nnz = j_.size();
254  std::vector<unsigned> p(col_ + 1, 0), j(nnz), tmp(col_, 0);
255  vec_basic x(nnz);
256 
257  for (unsigned i = 0; i < nnz; ++i)
258  p[j_[i] + 1]++;
259  std::partial_sum(p.begin(), p.end(), p.begin());
260 
261  for (unsigned ri = 0; ri < row_; ++ri) {
262  for (unsigned i = p_[ri]; i < p_[ri + 1]; ++i) {
263  const auto ci = j_[i];
264  const unsigned k = p[ci] + tmp[ci];
265  j[k] = ri;
266  if (conjugate) {
267  x[k] = SymEngine::conjugate(x_[i]);
268  } else {
269  x[k] = x_[i];
270  }
271  tmp[ci]++;
272  }
273  }
274  return CSRMatrix(col_, row_, std::move(p), std::move(j), std::move(x));
275 }
276 
277 // Matrix conjugate transpose
278 void CSRMatrix::conjugate_transpose(MatrixBase &result) const
279 {
280  if (is_a<CSRMatrix>(result)) {
281  auto &r = down_cast<CSRMatrix &>(result);
282  r = this->transpose(true);
283  } else {
284  throw NotImplementedError("Not Implemented");
285  }
286 }
287 
288 // Extract out a submatrix
289 void CSRMatrix::submatrix(MatrixBase &result, unsigned row_start,
290  unsigned col_start, unsigned row_end,
291  unsigned col_end, unsigned row_step,
292  unsigned col_step) const
293 {
294  throw NotImplementedError("Not Implemented");
295 }
296 
297 // LU factorization
298 void CSRMatrix::LU(MatrixBase &L, MatrixBase &U) const
299 {
300  throw NotImplementedError("Not Implemented");
301 }
302 
303 // LDL factorization
304 void CSRMatrix::LDL(MatrixBase &L, MatrixBase &D) const
305 {
306  throw NotImplementedError("Not Implemented");
307 }
308 
309 // Solve Ax = b using LU factorization
310 void CSRMatrix::LU_solve(const MatrixBase &b, MatrixBase &x) const
311 {
312  throw NotImplementedError("Not Implemented");
313 }
314 
315 // Fraction free LU factorization
316 void CSRMatrix::FFLU(MatrixBase &LU) const
317 {
318  throw NotImplementedError("Not Implemented");
319 }
320 
321 // Fraction free LDU factorization
322 void CSRMatrix::FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const
323 {
324  throw NotImplementedError("Not Implemented");
325 }
326 
327 // QR factorization
328 void CSRMatrix::QR(MatrixBase &Q, MatrixBase &R) const
329 {
330  throw NotImplementedError("Not Implemented");
331 }
332 
333 // Cholesky decomposition
334 void CSRMatrix::cholesky(MatrixBase &L) const
335 {
336  throw NotImplementedError("Not Implemented");
337 }
338 
339 void CSRMatrix::csr_sum_duplicates(std::vector<unsigned> &p_,
340  std::vector<unsigned> &j_, vec_basic &x_,
341  unsigned row_)
342 {
343  unsigned nnz = 0;
344  unsigned row_end = 0;
345  unsigned jj = 0, j = 0;
346  RCP<const Basic> x = zero;
347 
348  for (unsigned i = 0; i < row_; i++) {
349  jj = row_end;
350  row_end = p_[i + 1];
351 
352  while (jj < row_end) {
353  j = j_[jj];
354  x = x_[jj];
355  jj++;
356 
357  while (jj < row_end and j_[jj] == j) {
358  x = add(x, x_[jj]);
359  jj++;
360  }
361 
362  j_[nnz] = j;
363  x_[nnz] = x;
364  nnz++;
365  }
366  p_[i + 1] = nnz;
367  }
368 
369  // Resize to discard unnecessary elements
370  j_.resize(nnz);
371  x_.resize(nnz);
372 }
373 
374 void CSRMatrix::csr_sort_indices(std::vector<unsigned> &p_,
375  std::vector<unsigned> &j_, vec_basic &x_,
376  unsigned row_)
377 {
378  std::vector<std::pair<unsigned, RCP<const Basic>>> temp;
379 
380  for (unsigned i = 0; i < row_; i++) {
381  unsigned row_start = p_[i];
382  unsigned row_end = p_[i + 1];
383 
384  temp.clear();
385 
386  for (unsigned jj = row_start; jj < row_end; jj++) {
387  temp.push_back(std::make_pair(j_[jj], x_[jj]));
388  }
389 
390  std::sort(temp.begin(), temp.end(),
391  [](const std::pair<unsigned, RCP<const Basic>> &x,
392  const std::pair<unsigned, RCP<const Basic>> &y) {
393  return x.first < y.first;
394  });
395 
396  for (unsigned jj = row_start, n = 0; jj < row_end; jj++, n++) {
397  j_[jj] = temp[n].first;
398  x_[jj] = temp[n].second;
399  }
400  }
401 }
402 
403 // Assumes that the indices are sorted
404 bool CSRMatrix::csr_has_duplicates(const std::vector<unsigned> &p_,
405  const std::vector<unsigned> &j_,
406  unsigned row_)
407 {
408  for (unsigned i = 0; i < row_; i++) {
409  for (unsigned j = p_[i]; j + 1 < p_[i + 1]; j++) {
410  if (j_[j] == j_[j + 1])
411  return true;
412  }
413  }
414  return false;
415 }
416 
417 bool CSRMatrix::csr_has_sorted_indices(const std::vector<unsigned> &p_,
418  const std::vector<unsigned> &j_,
419  unsigned row_)
420 {
421  for (unsigned i = 0; i < row_; i++) {
422  for (unsigned jj = p_[i]; jj + 1 < p_[i + 1]; jj++) {
423  if (j_[jj] > j_[jj + 1])
424  return false;
425  }
426  }
427  return true;
428 }
429 
430 bool CSRMatrix::csr_has_canonical_format(const std::vector<unsigned> &p_,
431  const std::vector<unsigned> &j_,
432  unsigned row_)
433 {
434  for (unsigned i = 0; i < row_; i++) {
435  if (p_[i] > p_[i + 1])
436  return false;
437  }
438 
439  return csr_has_sorted_indices(p_, j_, row_)
440  and not csr_has_duplicates(p_, j_, row_);
441 }
442 
443 CSRMatrix CSRMatrix::from_coo(unsigned row, unsigned col,
444  const std::vector<unsigned> &i,
445  const std::vector<unsigned> &j,
446  const vec_basic &x)
447 {
448  // cast is okay, because CSRMatrix indices are unsigned.
449  unsigned nnz = numeric_cast<unsigned>(x.size());
450  std::vector<unsigned> p_ = std::vector<unsigned>(row + 1, 0);
451  std::vector<unsigned> j_ = std::vector<unsigned>(nnz);
452  vec_basic x_ = vec_basic(nnz);
453 
454  for (unsigned n = 0; n < nnz; n++) {
455  p_[i[n]]++;
456  }
457 
458  // cumsum the nnz per row to get p
459  unsigned temp;
460  for (unsigned i = 0, cumsum = 0; i < row; i++) {
461  temp = p_[i];
462  p_[i] = cumsum;
463  cumsum += temp;
464  }
465  p_[row] = nnz;
466 
467  // write j, x into j_, x_
468  unsigned row_, dest_;
469  for (unsigned n = 0; n < nnz; n++) {
470  row_ = i[n];
471  dest_ = p_[row_];
472 
473  j_[dest_] = j[n];
474  x_[dest_] = x[n];
475 
476  p_[row_]++;
477  }
478 
479  for (unsigned i = 0, last = 0; i <= row; i++) {
480  std::swap(p_[i], last);
481  }
482 
483  // sort indices
484  csr_sort_indices(p_, j_, x_, row);
485  // Remove duplicates
486  csr_sum_duplicates(p_, j_, x_, row);
487 
488  CSRMatrix B
489  = CSRMatrix(row, col, std::move(p_), std::move(j_), std::move(x_));
490  return B;
491 }
492 
493 CSRMatrix CSRMatrix::jacobian(const vec_basic &exprs, const vec_sym &x,
494  bool diff_cache)
495 {
496  const unsigned nrows = static_cast<unsigned>(exprs.size());
497  const unsigned ncols = static_cast<unsigned>(x.size());
498  std::vector<unsigned> p(1, 0), j;
499  vec_basic elems;
500  p.reserve(nrows + 1);
501  j.reserve(nrows);
502  elems.reserve(nrows);
503  for (unsigned ri = 0; ri < nrows; ++ri) {
504  p.push_back(p.back());
505  for (unsigned ci = 0; ci < ncols; ++ci) {
506  auto elem = exprs[ri]->diff(x[ci], diff_cache);
507  if (!is_true(is_zero(*elem))) {
508  p.back()++;
509  j.push_back(ci);
510  elems.emplace_back(std::move(elem));
511  }
512  }
513  }
514  return CSRMatrix(nrows, ncols, std::move(p), std::move(j),
515  std::move(elems));
516 }
517 
518 CSRMatrix CSRMatrix::jacobian(const DenseMatrix &A, const DenseMatrix &x,
519  bool diff_cache)
520 {
521  SYMENGINE_ASSERT(A.col_ == 1);
522  SYMENGINE_ASSERT(x.col_ == 1);
523  vec_sym syms;
524  syms.reserve(x.row_);
525  for (const auto &dx : x.m_) {
526  if (!is_a<Symbol>(*dx)) {
527  throw SymEngineException("'x' must contain Symbols only");
528  }
529  syms.push_back(rcp_static_cast<const Symbol>(dx));
530  }
531  return CSRMatrix::jacobian(A.m_, syms, diff_cache);
532 }
533 
534 void csr_matmat_pass1(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
535 {
536  // method that uses O(n) temp storage
537  std::vector<unsigned> mask(A.col_, -1);
538  C.p_[0] = 0;
539 
540  unsigned nnz = 0;
541  for (unsigned i = 0; i < A.row_; i++) {
542  // npy_intp row_nnz = 0;
543  unsigned row_nnz = 0;
544 
545  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++) {
546  unsigned j = A.j_[jj];
547  for (unsigned kk = B.p_[j]; kk < B.p_[j + 1]; kk++) {
548  unsigned k = B.j_[kk];
549  if (mask[k] != i) {
550  mask[k] = i;
551  row_nnz++;
552  }
553  }
554  }
555 
556  unsigned next_nnz = nnz + row_nnz;
557 
558  // Addition overflow: http://www.cplusplus.com/articles/DE18T05o/
559  if (next_nnz < nnz) {
560  throw std::overflow_error("nnz of the result is too large");
561  }
562 
563  nnz = next_nnz;
564  C.p_[i + 1] = nnz;
565  }
566 }
567 
568 // Pass 2 computes CSR entries for matrix C = A*B using the
569 // row pointer Cp[] computed in Pass 1.
570 void csr_matmat_pass2(const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C)
571 {
572  std::vector<int> next(A.col_, -1);
573  vec_basic sums(A.col_, zero);
574 
575  unsigned nnz = 0;
576 
577  C.p_[0] = 0;
578 
579  for (unsigned i = 0; i < A.row_; i++) {
580  int head = -2;
581  unsigned length = 0;
582 
583  unsigned jj_start = A.p_[i];
584  unsigned jj_end = A.p_[i + 1];
585  for (unsigned jj = jj_start; jj < jj_end; jj++) {
586  unsigned j = A.j_[jj];
587  RCP<const Basic> v = A.x_[jj];
588 
589  unsigned kk_start = B.p_[j];
590  unsigned kk_end = B.p_[j + 1];
591  for (unsigned kk = kk_start; kk < kk_end; kk++) {
592  unsigned k = B.j_[kk];
593 
594  sums[k] = add(sums[k], mul(v, B.x_[kk]));
595 
596  if (next[k] == -1) {
597  next[k] = head;
598  head = k;
599  length++;
600  }
601  }
602  }
603 
604  for (unsigned jj = 0; jj < length; jj++) {
605 
606  if (!is_true(is_zero(*sums[head]))) {
607  C.j_[nnz] = head;
608  C.x_[nnz] = sums[head];
609  nnz++;
610  }
611 
612  unsigned temp = head;
613  head = next[head];
614 
615  next[temp] = -1; // clear arrays
616  sums[temp] = zero;
617  }
618 
619  C.p_[i + 1] = nnz;
620  }
621 }
622 
623 // Extract main diagonal of CSR matrix A
624 void csr_diagonal(const CSRMatrix &A, DenseMatrix &D)
625 {
626  unsigned N = std::min(A.row_, A.col_);
627 
628  SYMENGINE_ASSERT(D.nrows() == N and D.ncols() == 1);
629 
630  unsigned row_start;
631  unsigned row_end;
632  RCP<const Basic> diag;
633 
634  for (unsigned i = 0; i < N; i++) {
635  row_start = A.p_[i];
636  row_end = A.p_[i + 1];
637  diag = zero;
638  unsigned jj;
639 
640  while (row_start <= row_end) {
641  jj = (row_start + row_end) / 2;
642  if (A.j_[jj] == i) {
643  diag = A.x_[jj];
644  break;
645  } else if (A.j_[jj] < i) {
646  row_start = jj + 1;
647  } else {
648  SYMENGINE_ASSERT(jj > 0);
649  row_end = jj - 1;
650  }
651  }
652 
653  D.set(i, 0, diag);
654  }
655 }
656 
657 // Scale the rows of a CSR matrix *in place*
658 // A[i, :] *= X[i]
659 void csr_scale_rows(CSRMatrix &A, const DenseMatrix &X)
660 {
661  SYMENGINE_ASSERT(A.row_ == X.nrows() and X.ncols() == 1);
662 
663  for (unsigned i = 0; i < A.row_; i++) {
664  if (is_true(is_zero(*X.get(i, 0))))
665  throw SymEngineException("Scaling factor can't be zero");
666  for (unsigned jj = A.p_[i]; jj < A.p_[i + 1]; jj++)
667  A.x_[jj] = mul(A.x_[jj], X.get(i, 0));
668  }
669 }
670 
671 // Scale the columns of a CSR matrix *in place*
672 // A[:, i] *= X[i]
673 void csr_scale_columns(CSRMatrix &A, const DenseMatrix &X)
674 {
675  SYMENGINE_ASSERT(A.col_ == X.nrows() and X.ncols() == 1);
676 
677  const unsigned nnz = A.p_[A.row_];
678  unsigned i;
679 
680  for (i = 0; i < A.col_; i++) {
681  if (is_true(is_zero(*X.get(i, 0))))
682  throw SymEngineException("Scaling factor can't be zero");
683  }
684 
685  for (i = 0; i < nnz; i++)
686  A.x_[i] = mul(A.x_[i], X.get(A.j_[i], 0));
687 }
688 
689 // Compute C = A (binary_op) B for CSR matrices that are in the
690 // canonical CSR format. Matrix dimensions of A and B should be the
691 // same. C will be in canonical format as well.
692 void csr_binop_csr_canonical(
693  const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C,
694  RCP<const Basic> (&bin_op)(const RCP<const Basic> &,
695  const RCP<const Basic> &))
696 {
697  SYMENGINE_ASSERT(A.row_ == B.row_ and A.col_ == B.col_ and C.row_ == A.row_
698  and C.col_ == A.col_);
699 
700  // Method that works for canonical CSR matrices
701  C.p_[0] = 0;
702  C.j_.clear();
703  C.x_.clear();
704  unsigned nnz = 0;
705  unsigned A_pos, B_pos, A_end, B_end;
706 
707  for (unsigned i = 0; i < A.row_; i++) {
708  A_pos = A.p_[i];
709  B_pos = B.p_[i];
710  A_end = A.p_[i + 1];
711  B_end = B.p_[i + 1];
712 
713  // while not finished with either row
714  while (A_pos < A_end and B_pos < B_end) {
715  unsigned A_j = A.j_[A_pos];
716  unsigned B_j = B.j_[B_pos];
717 
718  if (A_j == B_j) {
719  RCP<const Basic> result = bin_op(A.x_[A_pos], B.x_[B_pos]);
720  if (!is_true(is_zero(*result))) {
721  C.j_.push_back(A_j);
722  C.x_.push_back(result);
723  nnz++;
724  }
725  A_pos++;
726  B_pos++;
727  } else if (A_j < B_j) {
728  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
729  if (!is_true(is_zero(*result))) {
730  C.j_.push_back(A_j);
731  C.x_.push_back(result);
732  nnz++;
733  }
734  A_pos++;
735  } else {
736  // B_j < A_j
737  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
738  if (!is_true(is_zero(*result))) {
739  C.j_.push_back(B_j);
740  C.x_.push_back(result);
741  nnz++;
742  }
743  B_pos++;
744  }
745  }
746 
747  // tail
748  while (A_pos < A_end) {
749  RCP<const Basic> result = bin_op(A.x_[A_pos], zero);
750  if (!is_true(is_zero(*result))) {
751  C.j_.push_back(A.j_[A_pos]);
752  C.x_.push_back(result);
753  nnz++;
754  }
755  A_pos++;
756  }
757  while (B_pos < B_end) {
758  RCP<const Basic> result = bin_op(zero, B.x_[B_pos]);
759  if (!is_true(is_zero(*result))) {
760  C.j_.push_back(B.j_[B_pos]);
761  C.x_.push_back(result);
762  nnz++;
763  }
764  B_pos++;
765  }
766 
767  C.p_[i + 1] = nnz;
768  }
769 
770  // It's enough to check for duplicates as the column indices
771  // remain sorted after the above operations
772  if (CSRMatrix::csr_has_duplicates(C.p_, C.j_, A.row_))
773  CSRMatrix::csr_sum_duplicates(C.p_, C.j_, C.x_, A.row_);
774 }
775 
776 } // namespace SymEngine
Classes and functions relating to the binary operation of addition.
RCP< const Basic > add(const RCP< const Basic > &a, const RCP< const Basic > &b)
Adds two objects (safely).
Definition: add.cpp:425
Main namespace for SymEngine package.
Definition: add.cpp:19
RCP< const Basic > mul(const RCP< const Basic > &a, const RCP< const Basic > &b)
Multiplication.
Definition: mul.cpp:352
tribool is_zero(const Basic &b, const Assumptions *assumptions=nullptr)
Check if a number is zero.
bool neq(const Basic &a, const Basic &b)
Checks inequality for a and b
Definition: basic-inl.h:29
RCP< const Basic > conjugate(const RCP< const Basic > &arg)
Canonicalize Conjugate.
Definition: functions.cpp:149