// matrix.h
#ifndef SYMENGINE_MATRIX_H
#define SYMENGINE_MATRIX_H

#include <ostream>
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>
#include <vector>

#include <symengine/basic.h>
#include <symengine/sets.h>

7namespace SymEngine
8{
9
10// Base class for matrices
12{
13public:
14 virtual ~MatrixBase(){};
15
16 bool is_square() const
17 {
18 return ncols() == nrows();
19 }
20
21 // Below methods should be implemented by the derived classes. If not
22 // applicable, raise an exception
23
24 // Get the # of rows and # of columns
25 virtual unsigned nrows() const = 0;
26 virtual unsigned ncols() const = 0;
27 virtual bool eq(const MatrixBase &other) const;
28
29 virtual tribool is_real(const Assumptions *assumptions = nullptr) const = 0;
30
31 // Get and set elements
32 virtual RCP<const Basic> get(unsigned i, unsigned j) const = 0;
33 virtual void set(unsigned i, unsigned j, const RCP<const Basic> &e) = 0;
34
35 // Print Matrix, very mundane version, should be overriden derived
36 // class if better printing is available
37 virtual std::string __str__() const;
38
39 virtual unsigned rank() const = 0;
40 virtual RCP<const Basic> det() const = 0;
41 virtual void inv(MatrixBase &result) const = 0;
42
43 // Matrix addition
44 virtual void add_matrix(const MatrixBase &other,
45 MatrixBase &result) const = 0;
46
47 // Matrix Multiplication
48 virtual void mul_matrix(const MatrixBase &other,
49 MatrixBase &result) const = 0;
50
51 // Matrix elementwise Multiplication
52 virtual void elementwise_mul_matrix(const MatrixBase &other,
53 MatrixBase &result) const = 0;
54
55 // Add a scalar
56 virtual void add_scalar(const RCP<const Basic> &k,
57 MatrixBase &result) const = 0;
58
59 // Multiply by a scalar
60 virtual void mul_scalar(const RCP<const Basic> &k,
61 MatrixBase &result) const = 0;
62
63 // Matrix conjugate
64 virtual void conjugate(MatrixBase &result) const = 0;
65
66 // Matrix transpose
67 virtual void transpose(MatrixBase &result) const = 0;
68
69 // Matrix conjugate transpose
70 virtual void conjugate_transpose(MatrixBase &result) const = 0;
71
72 // Extract out a submatrix
73 virtual void submatrix(MatrixBase &result, unsigned row_start,
74 unsigned col_start, unsigned row_end,
75 unsigned col_end, unsigned row_step = 1,
76 unsigned col_step = 1) const = 0;
77 // LU factorization
78 virtual void LU(MatrixBase &L, MatrixBase &U) const = 0;
79
80 // LDL factorization
81 virtual void LDL(MatrixBase &L, MatrixBase &D) const = 0;
82
83 // Fraction free LU factorization
84 virtual void FFLU(MatrixBase &LU) const = 0;
85
86 // Fraction free LDU factorization
87 virtual void FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const = 0;
88
89 // QR factorization
90 virtual void QR(MatrixBase &Q, MatrixBase &R) const = 0;
91
92 // Cholesky decomposition
93 virtual void cholesky(MatrixBase &L) const = 0;
94
95 // Solve Ax = b using LU factorization
96 virtual void LU_solve(const MatrixBase &b, MatrixBase &x) const = 0;
97};

// Permutation list filled by the pivoted elimination / LU routines and
// passed to permuteFwd.
// NOTE(review): this declaration was missing from the extracted listing; the
// element type is reconstructed and should be confirmed against permuteFwd
// and pivoted_LU in the implementation.
using permutelist = std::vector<std::pair<int, int>>;

// Forward declaration: DenseMatrix names CSRMatrix as a friend before
// CSRMatrix is defined.
class CSRMatrix;
102
103// ----------------------------- Dense Matrix --------------------------------//
105{
106public:
107 // Constructors
108 DenseMatrix();
109 DenseMatrix(const DenseMatrix &) = default;
110 DenseMatrix(unsigned row, unsigned col);
111 DenseMatrix(unsigned row, unsigned col, const vec_basic &l);
112 DenseMatrix(const vec_basic &column_elements);
113 DenseMatrix &operator=(const DenseMatrix &other) = default;
114 // Resize
115 void resize(unsigned i, unsigned j);
116
117 // Should implement all the virtual methods from MatrixBase
118 // and throw an exception if a method is not applicable.
119
120 // Get and set elements
121 RCP<const Basic> get(unsigned i, unsigned j) const override;
122 void set(unsigned i, unsigned j, const RCP<const Basic> &e) override;
123 virtual vec_basic as_vec_basic() const;
124
125 unsigned nrows() const override
126 {
127 return row_;
128 }
129 unsigned ncols() const override
130 {
131 return col_;
132 }
133
134 virtual bool is_lower() const;
135 virtual bool is_upper() const;
136 virtual tribool is_zero() const;
137 virtual tribool is_diagonal() const;
138 tribool is_real(const Assumptions *assumptions = nullptr) const override;
139 virtual tribool is_symmetric() const;
140 virtual tribool is_hermitian() const;
141 virtual tribool is_weakly_diagonally_dominant() const;
142 virtual tribool is_strictly_diagonally_dominant() const;
143 virtual tribool is_positive_definite() const;
144 virtual tribool is_negative_definite() const;
145
146 RCP<const Basic> trace() const;
147 unsigned rank() const override;
148 RCP<const Basic> det() const override;
149 void inv(MatrixBase &result) const override;
150
151 // Matrix addition
152 void add_matrix(const MatrixBase &other, MatrixBase &result) const override;
153
154 // Matrix multiplication
155 void mul_matrix(const MatrixBase &other, MatrixBase &result) const override;
156
157 // Matrix elementwise Multiplication
158 void elementwise_mul_matrix(const MatrixBase &other,
159 MatrixBase &result) const override;
160
161 // Add a scalar
162 void add_scalar(const RCP<const Basic> &k,
163 MatrixBase &result) const override;
164
165 // Multiply by a scalar
166 void mul_scalar(const RCP<const Basic> &k,
167 MatrixBase &result) const override;
168
169 // Matrix conjugate
170 void conjugate(MatrixBase &result) const override;
171
172 // Matrix transpose
173 void transpose(MatrixBase &result) const override;
174
175 // Matrix conjugate transpose
176 void conjugate_transpose(MatrixBase &result) const override;
177
178 // Extract out a submatrix
179 void submatrix(MatrixBase &result, unsigned row_start, unsigned col_start,
180 unsigned row_end, unsigned col_end, unsigned row_step = 1,
181 unsigned col_step = 1) const override;
182
183 // LU factorization
184 void LU(MatrixBase &L, MatrixBase &U) const override;
185
186 // LDL factorization
187 void LDL(MatrixBase &L, MatrixBase &D) const override;
188
189 // Solve Ax = b using LU factorization
190 void LU_solve(const MatrixBase &b, MatrixBase &x) const override;
191
192 // Fraction free LU factorization
193 void FFLU(MatrixBase &LU) const override;
194
195 // Fraction free LDU factorization
196 void FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const override;
197
198 // QR factorization
199 void QR(MatrixBase &Q, MatrixBase &R) const override;
200
201 // Cholesky decomposition
202 void cholesky(MatrixBase &L) const override;
203
204 // Return the Jacobian of the matrix
205 friend void jacobian(const DenseMatrix &A, const DenseMatrix &x,
206 DenseMatrix &result, bool diff_cache);
207 // Return the Jacobian of the matrix using sdiff
208 friend void sjacobian(const DenseMatrix &A, const DenseMatrix &x,
209 DenseMatrix &result, bool diff_cache);
210
211 // Differentiate the matrix element-wise
212 friend void diff(const DenseMatrix &A, const RCP<const Symbol> &x,
213 DenseMatrix &result, bool diff_cache);
214 // Differentiate the matrix element-wise using SymPy compatible diff
215 friend void sdiff(const DenseMatrix &A, const RCP<const Basic> &x,
216 DenseMatrix &result, bool diff_cache);
217
218 // Friend functions related to Matrix Operations
219 friend void add_dense_dense(const DenseMatrix &A, const DenseMatrix &B,
220 DenseMatrix &C);
221 friend void add_dense_scalar(const DenseMatrix &A,
222 const RCP<const Basic> &k, DenseMatrix &B);
223 friend void mul_dense_dense(const DenseMatrix &A, const DenseMatrix &B,
224 DenseMatrix &C);
225 friend void elementwise_mul_dense_dense(const DenseMatrix &A,
226 const DenseMatrix &B,
227 DenseMatrix &C);
228 friend void mul_dense_scalar(const DenseMatrix &A,
229 const RCP<const Basic> &k, DenseMatrix &C);
230 friend void conjugate_dense(const DenseMatrix &A, DenseMatrix &B);
231 friend void transpose_dense(const DenseMatrix &A, DenseMatrix &B);
232 friend void conjugate_transpose_dense(const DenseMatrix &A, DenseMatrix &B);
233 friend void submatrix_dense(const DenseMatrix &A, DenseMatrix &B,
234 unsigned row_start, unsigned col_start,
235 unsigned row_end, unsigned col_end,
236 unsigned row_step, unsigned col_step);
237 void row_join(const DenseMatrix &B);
238 void col_join(const DenseMatrix &B);
239 void row_insert(const DenseMatrix &B, unsigned pos);
240 void col_insert(const DenseMatrix &B, unsigned pos);
241 void row_del(unsigned k);
242 void col_del(unsigned k);
243
244 // Row operations
245 friend void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j);
246 friend void row_mul_scalar_dense(DenseMatrix &A, unsigned i,
247 RCP<const Basic> &c);
248 friend void row_add_row_dense(DenseMatrix &A, unsigned i, unsigned j,
249 RCP<const Basic> &c);
250 friend void permuteFwd(DenseMatrix &A, permutelist &pl);
251
252 // Column operations
253 friend void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j);
254
255 // Gaussian elimination
256 friend void pivoted_gaussian_elimination(const DenseMatrix &A,
257 DenseMatrix &B,
258 permutelist &pivotlist);
259 friend void fraction_free_gaussian_elimination(const DenseMatrix &A,
260 DenseMatrix &B);
261 friend void pivoted_fraction_free_gaussian_elimination(
262 const DenseMatrix &A, DenseMatrix &B, permutelist &pivotlist);
263 friend void pivoted_gauss_jordan_elimination(const DenseMatrix &A,
264 DenseMatrix &B,
265 permutelist &pivotlist);
266 friend void fraction_free_gauss_jordan_elimination(const DenseMatrix &A,
267 DenseMatrix &B);
268 friend void pivoted_fraction_free_gauss_jordan_elimination(
269 const DenseMatrix &A, DenseMatrix &B, permutelist &pivotlist);
270 friend unsigned pivot(DenseMatrix &B, unsigned r, unsigned c);
271
272 friend void reduced_row_echelon_form(const DenseMatrix &A, DenseMatrix &B,
273 vec_uint &pivot_cols,
274 bool normalize_last);
275
276 // Ax = b
277 friend void diagonal_solve(const DenseMatrix &A, const DenseMatrix &b,
278 DenseMatrix &x);
279 friend void back_substitution(const DenseMatrix &U, const DenseMatrix &b,
280 DenseMatrix &x);
281 friend void forward_substitution(const DenseMatrix &A, const DenseMatrix &b,
282 DenseMatrix &x);
283 friend void fraction_free_gaussian_elimination_solve(const DenseMatrix &A,
284 const DenseMatrix &b,
285 DenseMatrix &x);
286 friend void fraction_free_gauss_jordan_solve(const DenseMatrix &A,
287 const DenseMatrix &b,
288 DenseMatrix &x, bool pivot);
289
290 // Matrix Decomposition
291 friend void fraction_free_LU(const DenseMatrix &A, DenseMatrix &LU);
292 friend void LU(const DenseMatrix &A, DenseMatrix &L, DenseMatrix &U);
293 friend void pivoted_LU(const DenseMatrix &A, DenseMatrix &LU,
294 permutelist &pl);
295 friend void pivoted_LU(const DenseMatrix &A, DenseMatrix &L, DenseMatrix &U,
296 permutelist &pl);
297 friend void fraction_free_LDU(const DenseMatrix &A, DenseMatrix &L,
298 DenseMatrix &D, DenseMatrix &U);
299 friend void QR(const DenseMatrix &A, DenseMatrix &Q, DenseMatrix &R);
300 friend void LDL(const DenseMatrix &A, DenseMatrix &L, DenseMatrix &D);
301 friend void cholesky(const DenseMatrix &A, DenseMatrix &L);
302
303 // Matrix queries
304 friend bool is_symmetric_dense(const DenseMatrix &A);
305
306 // Determinant
307 friend RCP<const Basic> det_bareis(const DenseMatrix &A);
308 friend void berkowitz(const DenseMatrix &A,
310
311 // Inverse
312 friend void inverse_fraction_free_LU(const DenseMatrix &A, DenseMatrix &B);
313 friend void inverse_LU(const DenseMatrix &A, DenseMatrix &B);
314 friend void inverse_pivoted_LU(const DenseMatrix &A, DenseMatrix &B);
315 friend void inverse_gauss_jordan(const DenseMatrix &A, DenseMatrix &B);
316
317 // Vector-specific methods
318 friend void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C);
319 friend void cross(const DenseMatrix &A, const DenseMatrix &B,
320 DenseMatrix &C);
321
322 // NumPy-like functions
323 friend void eye(DenseMatrix &A, int k);
324 friend void diag(DenseMatrix &A, vec_basic &v, int k);
325 friend void ones(DenseMatrix &A);
326 friend void zeros(DenseMatrix &A);
327
328 friend CSRMatrix;
329
330private:
331 // Matrix elements are stored in row-major order
332 vec_basic m_;
333 // Stores the dimension of the Matrix
334 unsigned row_;
335 unsigned col_;
336
337 tribool shortcut_to_posdef() const;
338 tribool is_positive_definite_GE();
339};
340
341// ----------------------------- Sparse Matrices -----------------------------//
342class CSRMatrix : public MatrixBase
343{
344public:
345 CSRMatrix();
346 CSRMatrix(unsigned row, unsigned col);
347 CSRMatrix(unsigned row, unsigned col, const std::vector<unsigned> &p,
348 const std::vector<unsigned> &j, const vec_basic &x);
349 CSRMatrix(unsigned row, unsigned col, std::vector<unsigned> &&p,
351 CSRMatrix &operator=(CSRMatrix &&other);
352 CSRMatrix(const CSRMatrix &) = default;
354 as_vectors() const;
355
356 bool is_canonical() const;
357
358 bool eq(const MatrixBase &other) const override;
359
360 // Get and set elements
361 RCP<const Basic> get(unsigned i, unsigned j) const override;
362 void set(unsigned i, unsigned j, const RCP<const Basic> &e) override;
363
364 unsigned nrows() const override
365 {
366 return row_;
367 }
368 unsigned ncols() const override
369 {
370 return col_;
371 }
372
373 tribool is_real(const Assumptions *assumptions = nullptr) const override;
374 unsigned rank() const override;
375 RCP<const Basic> det() const override;
376 void inv(MatrixBase &result) const override;
377
378 // Matrix addition
379 void add_matrix(const MatrixBase &other, MatrixBase &result) const override;
380
381 // Matrix Multiplication
382 void mul_matrix(const MatrixBase &other, MatrixBase &result) const override;
383
384 // Matrix elementwise Multiplication
385 void elementwise_mul_matrix(const MatrixBase &other,
386 MatrixBase &result) const override;
387
388 // Add a scalar
389 void add_scalar(const RCP<const Basic> &k,
390 MatrixBase &result) const override;
391
392 // Multiply by a scalar
393 void mul_scalar(const RCP<const Basic> &k,
394 MatrixBase &result) const override;
395
396 // Matrix conjugate
397 void conjugate(MatrixBase &result) const override;
398
399 // Matrix transpose
400 void transpose(MatrixBase &result) const override;
401 CSRMatrix transpose(bool conjugate = false) const;
402
403 // Matrix conjugate transpose
404 void conjugate_transpose(MatrixBase &result) const override;
405
406 // Extract out a submatrix
407 void submatrix(MatrixBase &result, unsigned row_start, unsigned col_start,
408 unsigned row_end, unsigned col_end, unsigned row_step = 1,
409 unsigned col_step = 1) const override;
410
411 // LU factorization
412 void LU(MatrixBase &L, MatrixBase &U) const override;
413
414 // LDL factorization
415 void LDL(MatrixBase &L, MatrixBase &D) const override;
416
417 // Solve Ax = b using LU factorization
418 void LU_solve(const MatrixBase &b, MatrixBase &x) const override;
419
420 // Fraction free LU factorization
421 void FFLU(MatrixBase &LU) const override;
422
423 // Fraction free LDU factorization
424 void FFLDU(MatrixBase &L, MatrixBase &D, MatrixBase &U) const override;
425
426 // QR factorization
427 void QR(MatrixBase &Q, MatrixBase &R) const override;
428
429 // Cholesky decomposition
430 void cholesky(MatrixBase &L) const override;
431
432 static void csr_sum_duplicates(std::vector<unsigned> &p_,
434 unsigned row_);
435
436 static void csr_sort_indices(std::vector<unsigned> &p_,
438 unsigned row_);
439
440 static bool csr_has_sorted_indices(const std::vector<unsigned> &p_,
441 const std::vector<unsigned> &j_,
442 unsigned row_);
443
444 static bool csr_has_duplicates(const std::vector<unsigned> &p_,
445 const std::vector<unsigned> &j_,
446 unsigned row_);
447
448 static bool csr_has_canonical_format(const std::vector<unsigned> &p_,
449 const std::vector<unsigned> &j_,
450 unsigned row_);
451
452 static CSRMatrix from_coo(unsigned row, unsigned col,
453 const std::vector<unsigned> &i,
454 const std::vector<unsigned> &j,
455 const vec_basic &x);
456 static CSRMatrix jacobian(const vec_basic &exprs, const vec_sym &x,
457 bool diff_cache = true);
458 static CSRMatrix jacobian(const DenseMatrix &A, const DenseMatrix &x,
459 bool diff_cache = true);
460
461 friend void csr_matmat_pass1(const CSRMatrix &A, const CSRMatrix &B,
462 CSRMatrix &C);
463 friend void csr_matmat_pass2(const CSRMatrix &A, const CSRMatrix &B,
464 CSRMatrix &C);
465 friend void csr_diagonal(const CSRMatrix &A, DenseMatrix &D);
466 friend void csr_scale_rows(CSRMatrix &A, const DenseMatrix &X);
467 friend void csr_scale_columns(CSRMatrix &A, const DenseMatrix &X);
468
469 friend void csr_binop_csr_canonical(
470 const CSRMatrix &A, const CSRMatrix &B, CSRMatrix &C,
471 RCP<const Basic> (&bin_op)(const RCP<const Basic> &,
472 const RCP<const Basic> &));
473
474private:
477 vec_basic x_;
478 // Stores the dimension of the Matrix
479 unsigned row_;
480 unsigned col_;
481};
482
483// Return the Jacobian of the matrix
484void jacobian(const DenseMatrix &A, const DenseMatrix &x, DenseMatrix &result,
485 bool diff_cache = true);
486// Return the Jacobian of the matrix using sdiff
487void sjacobian(const DenseMatrix &A, const DenseMatrix &x, DenseMatrix &result,
488 bool diff_cache = true);
489
490// Differentiate all the elements
491void diff(const DenseMatrix &A, const RCP<const Symbol> &x, DenseMatrix &result,
492 bool diff_cache = true);
493// Differentiate all the elements using SymPy compatible diff
494void sdiff(const DenseMatrix &A, const RCP<const Basic> &x, DenseMatrix &result,
495 bool diff_cache = true);
496
497// Get submatrix from a DenseMatrix
498void submatrix_dense(const DenseMatrix &A, DenseMatrix &B, unsigned row_start,
499 unsigned col_start, unsigned row_end, unsigned col_end,
500 unsigned row_step = 1, unsigned col_step = 1);
501
502// Row operations
503void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j);
504void row_mul_scalar_dense(DenseMatrix &A, unsigned i, RCP<const Basic> &c);
505void row_add_row_dense(DenseMatrix &A, unsigned i, unsigned j,
506 RCP<const Basic> &c);
507
508// Column operations
509void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j);
510
511// Vector-specific methods
512void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C);
513void cross(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C);
514
515// Matrix Factorization
516void LU(const DenseMatrix &A, DenseMatrix &L, DenseMatrix &U);
517void LDL(const DenseMatrix &A, DenseMatrix &L, DenseMatrix &D);
518void QR(const DenseMatrix &A, DenseMatrix &Q, DenseMatrix &R);
519void cholesky(const DenseMatrix &A, DenseMatrix &L);
520
521// Inverse
522void inverse_fraction_free_LU(const DenseMatrix &A, DenseMatrix &B);
523
524void inverse_gauss_jordan(const DenseMatrix &A, DenseMatrix &B);
525
526// Solving Ax = b
527void fraction_free_LU_solve(const DenseMatrix &A, const DenseMatrix &b,
528 DenseMatrix &x);
529
530void fraction_free_gauss_jordan_solve(const DenseMatrix &A,
531 const DenseMatrix &b, DenseMatrix &x,
532 bool pivot = true);
533
534void LU_solve(const DenseMatrix &A, const DenseMatrix &b, DenseMatrix &x);
535void pivoted_LU_solve(const DenseMatrix &A, const DenseMatrix &b,
536 DenseMatrix &x);
537
538void LDL_solve(const DenseMatrix &A, const DenseMatrix &b, DenseMatrix &x);
539
540// Determinant
541RCP<const Basic> det_berkowitz(const DenseMatrix &A);
542
543// Characteristic polynomial: Only the coefficients of monomials in decreasing
544// order of monomial powers is returned, i.e. if `B = transpose([1, -2, 3])`
545// then the corresponding polynomial is `x**2 - 2x + 3`.
546void char_poly(const DenseMatrix &A, DenseMatrix &B);
547
548// returns a finiteset of eigenvalues of a matrix
549RCP<const Set> eigen_values(const DenseMatrix &A);
550
551// Mimic `eye` function in NumPy
552void eye(DenseMatrix &A, int k = 0);
553
554// Create diagonal matrices directly
555void diag(DenseMatrix &A, vec_basic &v, int k = 0);
556
557// Create a matrix filled with ones
558void ones(DenseMatrix &A);
559
560// Create a matrix filled with zeros
561void zeros(DenseMatrix &A);
562
563// Reduced row echelon form and returns the cols with pivots
564void reduced_row_echelon_form(const DenseMatrix &A, DenseMatrix &B,
565 vec_uint &pivot_cols,
566 bool normalize_last = false);
567
568// Returns true if `b` is exactly the type T.
569// Here T can be a DenseMatrix, CSRMatrix, etc.
570template <class T>
571inline bool is_a(const MatrixBase &b)
572{
573 return typeid(T) == typeid(b);
574}
575
576// Test two matrices for equality
577inline bool operator==(const SymEngine::MatrixBase &lhs,
578 const SymEngine::MatrixBase &rhs)
579{
580 return lhs.eq(rhs);
581}
582
583// Test two matrices for equality
584inline bool operator!=(const SymEngine::MatrixBase &lhs,
585 const SymEngine::MatrixBase &rhs)
586{
587 return not lhs.eq(rhs);
588}
589
590} // namespace SymEngine
591
592// Print Matrix
594 const SymEngine::MatrixBase &A)
595{
596 return out << A.__str__();
597}
598
599#endif
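// Cross-reference notes carried over from the generated documentation page:
//   basic.h - the base class for SymEngine.
//   namespace SymEngine - main namespace for the SymEngine package
//     (definition: add.cpp:19).
//   bool is_a(const Basic &b) - templatised version to check is_a type
//     (definition: basic-inl.h:36).
//   std::ostream &operator<<(std::ostream &out, const SymEngine::Basic &p)
//     - the << operator (definition: basic-inl.h:53).
//   RCP<const Basic> conjugate(const RCP<const Basic> &arg) - canonicalize
//     Conjugate (definition: functions.cpp:149).
//   T operator!=(T... args) - generic operator!= entry from the
//     cross-reference index.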