Add matrix multiplication infrastructure

This commit is contained in:
Pavel Krajcevski 2014-02-21 16:18:00 -05:00
parent 05eeb09f36
commit 8b9e8cd9b5
2 changed files with 131 additions and 17 deletions

View file

@ -30,11 +30,17 @@
namespace FasTC { namespace FasTC {
template <typename T, const int nRows, const int nCols> template <typename T, const int nRows, const int nCols>
class MatrixBase : public VectorBase<T, nRows * nCols> { class MatrixBase {
private: protected:
typedef VectorBase<T, nRows * nCols> Base;
// Vector representation
T mat[nRows * nCols];
public: public:
static const int Size = Base::Size; typedef T ScalarType;
static const int kNumRows = nRows;
static const int kNumCols = nCols;
static const int Size = kNumCols * kNumRows;
// Constructors // Constructors
MatrixBase() { } MatrixBase() { }
@ -45,16 +51,16 @@ namespace FasTC {
} }
// Accessors // Accessors
T &operator()(int idx) { return Base::operator()(idx); } T &operator()(int idx) { return mat[idx]; }
T &operator[](int idx) { return Base::operator[](idx); } T &operator[](int idx) { return mat[idx]; }
const T &operator()(int idx) const { return Base::operator()(idx); } const T &operator()(int idx) const { return mat[idx]; }
const T &operator[](int idx) const { return Base::operator[](idx); } const T &operator[](int idx) const { return mat[idx]; }
T &operator()(int r, int c) { return (*this)[r * nCols + c]; } T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; } const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
// Allow casts to the respective array representation... // Allow casts to the respective array representation...
operator const T *() const { return this->vec; } operator const T *() const { return this->mat; }
MatrixBase<T, nRows, nCols> &operator=(const T *v) { MatrixBase<T, nRows, nCols> &operator=(const T *v) {
for(int i = 0; i < Size; i++) for(int i = 0; i < Size; i++)
(*this)[i] = v[i]; (*this)[i] = v[i];
@ -66,7 +72,7 @@ namespace FasTC {
operator MatrixBase<_T, nRows, nCols>() const { operator MatrixBase<_T, nRows, nCols>() const {
MatrixBase<_T, nRows, nCols> ret; MatrixBase<_T, nRows, nCols> ret;
for(int i = 0; i < Size; i++) { for(int i = 0; i < Size; i++) {
ret[i] = static_cast<_T>(this->vec[i]); ret[i] = static_cast<_T>(mat[i]);
} }
return ret; return ret;
} }
@ -87,8 +93,20 @@ namespace FasTC {
// Vector multiplication -- treat vectors as Nx1 matrices... // Vector multiplication -- treat vectors as Nx1 matrices...
template<typename _T> template<typename _T>
VectorBase<T, nCols> operator*(const VectorBase<_T, nCols> &v) { VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
VectorBase<T, nCols> result; VectorBase<T, nCols> result;
for(int j = 0; j < nCols; j++) {
result(j) = 0;
for(int r = 0; r < nRows; r++) {
result(j) += (*this)(r, j) * v(r);
}
}
return result;
}
template<typename _T>
VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
VectorBase<T, nRows> result;
for(int r = 0; r < nRows; r++) { for(int r = 0; r < nRows; r++) {
result(r) = 0; result(r) = 0;
for(int j = 0; j < nCols; j++) { for(int j = 0; j < nCols; j++) {
@ -111,14 +129,88 @@ namespace FasTC {
// Double dot product // Double dot product
template<typename _T> template<typename _T>
T DDot(const MatrixBase<_T, nRows, nCols> &m) { T DDot(const MatrixBase<_T, nRows, nCols> &m) const {
T result = 0; T result = 0;
for(int i = 0; i < Size; i++) { for(int i = 0; i < Size; i++) {
result += (*this)[i] * m[i]; result += (*this)[i] * m[i];
} }
return result; return result;
} }
};
template<typename T, const int N, const int M>
class VectorTraits<MatrixBase<T, N, M> > {
public:
static const EVectorType kVectorType = eVectorType_Matrix;
};
#define REGISTER_MATRIX_TYPE(TYPE) \
template<> \
class VectorTraits< TYPE > { \
public: \
static const EVectorType kVectorType = eVectorType_Matrix; \
}
#define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE) \
template<typename T> \
class VectorTraits< TYPE <T> > { \
public: \
static const EVectorType kVectorType = eVectorType_Matrix; \
}
// Define matrix multiplication for * operator
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Matrix,
eVectorType_Vector,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef VectorBase<typename TypeTwo::ScalarType, TypeOne::kNumRows> ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
};
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Vector,
eVectorType_Matrix,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef VectorBase<typename TypeOne::ScalarType, TypeTwo::kNumCols> ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
};
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Matrix,
eVectorType_Matrix,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef MatrixBase<typename TypeOne::ScalarType, TypeOne::kNumRows, TypeTwo::kNumCols> ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
}; };
// Outer product... // Outer product...

View file

@ -158,9 +158,9 @@ TEST(MatrixBase, MatrixMultiplication) {
TEST(MatrixBase, Transposition) { TEST(MatrixBase, Transposition) {
FasTC::MatrixBase<int, 3, 5> a; FasTC::MatrixBase<int, 3, 5> a;
a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0; a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0;
a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3; a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3;
a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5; a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
FasTC::MatrixBase<int, 5, 3> b = a.Transpose(); FasTC::MatrixBase<int, 5, 3> b = a.Transpose();
@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) {
} }
TEST(MatrixBase, VectorMultiplication) { TEST(MatrixBase, VectorMultiplication) {
// Stub
EXPECT_EQ(0, 1); FasTC::MatrixBase<int, 3, 5> a;
a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0;
a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3;
a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
FasTC::VectorBase<int, 5> v;
for(int i = 0; i < 5; i++) v[i] = i + 1;
FasTC::VectorBase<int, 3> u = a * v;
EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4));
EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5));
EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5));
/////
for(int i = 0; i < 3; i++) u[i] = i + 1;
v = u * a;
EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3));
EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3));
EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3));
EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3));
EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3));
} }
TEST(MatrixSquare, Constructors) { TEST(MatrixSquare, Constructors) {