mirror of
https://github.com/yuzu-emu/FasTC.git
synced 2025-01-23 19:11:05 +00:00
Add matrix multiplication infrastructure
This commit is contained in:
parent
05eeb09f36
commit
8b9e8cd9b5
|
@ -30,11 +30,17 @@
|
||||||
namespace FasTC {
|
namespace FasTC {
|
||||||
|
|
||||||
template <typename T, const int nRows, const int nCols>
|
template <typename T, const int nRows, const int nCols>
|
||||||
class MatrixBase : public VectorBase<T, nRows * nCols> {
|
class MatrixBase {
|
||||||
private:
|
protected:
|
||||||
typedef VectorBase<T, nRows * nCols> Base;
|
|
||||||
|
// Vector representation
|
||||||
|
T mat[nRows * nCols];
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static const int Size = Base::Size;
|
typedef T ScalarType;
|
||||||
|
static const int kNumRows = nRows;
|
||||||
|
static const int kNumCols = nCols;
|
||||||
|
static const int Size = kNumCols * kNumRows;
|
||||||
|
|
||||||
// Constructors
|
// Constructors
|
||||||
MatrixBase() { }
|
MatrixBase() { }
|
||||||
|
@ -45,16 +51,16 @@ namespace FasTC {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessors
|
// Accessors
|
||||||
T &operator()(int idx) { return Base::operator()(idx); }
|
T &operator()(int idx) { return mat[idx]; }
|
||||||
T &operator[](int idx) { return Base::operator[](idx); }
|
T &operator[](int idx) { return mat[idx]; }
|
||||||
const T &operator()(int idx) const { return Base::operator()(idx); }
|
const T &operator()(int idx) const { return mat[idx]; }
|
||||||
const T &operator[](int idx) const { return Base::operator[](idx); }
|
const T &operator[](int idx) const { return mat[idx]; }
|
||||||
|
|
||||||
T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
|
T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
|
||||||
const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
|
const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
|
||||||
|
|
||||||
// Allow casts to the respective array representation...
|
// Allow casts to the respective array representation...
|
||||||
operator const T *() const { return this->vec; }
|
operator const T *() const { return this->mat; }
|
||||||
MatrixBase<T, nRows, nCols> &operator=(const T *v) {
|
MatrixBase<T, nRows, nCols> &operator=(const T *v) {
|
||||||
for(int i = 0; i < Size; i++)
|
for(int i = 0; i < Size; i++)
|
||||||
(*this)[i] = v[i];
|
(*this)[i] = v[i];
|
||||||
|
@ -66,7 +72,7 @@ namespace FasTC {
|
||||||
operator MatrixBase<_T, nRows, nCols>() const {
|
operator MatrixBase<_T, nRows, nCols>() const {
|
||||||
MatrixBase<_T, nRows, nCols> ret;
|
MatrixBase<_T, nRows, nCols> ret;
|
||||||
for(int i = 0; i < Size; i++) {
|
for(int i = 0; i < Size; i++) {
|
||||||
ret[i] = static_cast<_T>(this->vec[i]);
|
ret[i] = static_cast<_T>(mat[i]);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -87,8 +93,20 @@ namespace FasTC {
|
||||||
|
|
||||||
// Vector multiplication -- treat vectors as Nx1 matrices...
|
// Vector multiplication -- treat vectors as Nx1 matrices...
|
||||||
template<typename _T>
|
template<typename _T>
|
||||||
VectorBase<T, nCols> operator*(const VectorBase<_T, nCols> &v) {
|
VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
|
||||||
VectorBase<T, nCols> result;
|
VectorBase<T, nCols> result;
|
||||||
|
for(int j = 0; j < nCols; j++) {
|
||||||
|
result(j) = 0;
|
||||||
|
for(int r = 0; r < nRows; r++) {
|
||||||
|
result(j) += (*this)(r, j) * v(r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename _T>
|
||||||
|
VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
|
||||||
|
VectorBase<T, nRows> result;
|
||||||
for(int r = 0; r < nRows; r++) {
|
for(int r = 0; r < nRows; r++) {
|
||||||
result(r) = 0;
|
result(r) = 0;
|
||||||
for(int j = 0; j < nCols; j++) {
|
for(int j = 0; j < nCols; j++) {
|
||||||
|
@ -111,14 +129,88 @@ namespace FasTC {
|
||||||
|
|
||||||
// Double dot product
|
// Double dot product
|
||||||
template<typename _T>
|
template<typename _T>
|
||||||
T DDot(const MatrixBase<_T, nRows, nCols> &m) {
|
T DDot(const MatrixBase<_T, nRows, nCols> &m) const {
|
||||||
T result = 0;
|
T result = 0;
|
||||||
for(int i = 0; i < Size; i++) {
|
for(int i = 0; i < Size; i++) {
|
||||||
result += (*this)[i] * m[i];
|
result += (*this)[i] * m[i];
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T, const int N, const int M>
|
||||||
|
class VectorTraits<MatrixBase<T, N, M> > {
|
||||||
|
public:
|
||||||
|
static const EVectorType kVectorType = eVectorType_Matrix;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_MATRIX_TYPE(TYPE) \
|
||||||
|
template<> \
|
||||||
|
class VectorTraits< TYPE > { \
|
||||||
|
public: \
|
||||||
|
static const EVectorType kVectorType = eVectorType_Matrix; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE) \
|
||||||
|
template<typename T> \
|
||||||
|
class VectorTraits< TYPE <T> > { \
|
||||||
|
public: \
|
||||||
|
static const EVectorType kVectorType = eVectorType_Matrix; \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define matrix multiplication for * operator
|
||||||
|
template<typename TypeOne, typename TypeTwo>
|
||||||
|
class MultSwitch<
|
||||||
|
eVectorType_Matrix,
|
||||||
|
eVectorType_Vector,
|
||||||
|
TypeOne, TypeTwo> {
|
||||||
|
private:
|
||||||
|
const TypeOne &m_A;
|
||||||
|
const TypeTwo &m_B;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef VectorBase<typename TypeTwo::ScalarType, TypeOne::kNumRows> ResultType;
|
||||||
|
|
||||||
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
|
ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename TypeOne, typename TypeTwo>
|
||||||
|
class MultSwitch<
|
||||||
|
eVectorType_Vector,
|
||||||
|
eVectorType_Matrix,
|
||||||
|
TypeOne, TypeTwo> {
|
||||||
|
private:
|
||||||
|
const TypeOne &m_A;
|
||||||
|
const TypeTwo &m_B;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef VectorBase<typename TypeOne::ScalarType, TypeTwo::kNumCols> ResultType;
|
||||||
|
|
||||||
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
|
ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename TypeOne, typename TypeTwo>
|
||||||
|
class MultSwitch<
|
||||||
|
eVectorType_Matrix,
|
||||||
|
eVectorType_Matrix,
|
||||||
|
TypeOne, TypeTwo> {
|
||||||
|
private:
|
||||||
|
const TypeOne &m_A;
|
||||||
|
const TypeTwo &m_B;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef MatrixBase<typename TypeOne::ScalarType, TypeOne::kNumRows, TypeTwo::kNumCols> ResultType;
|
||||||
|
|
||||||
|
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||||
|
: m_A(a), m_B(b) { }
|
||||||
|
|
||||||
|
ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Outer product...
|
// Outer product...
|
||||||
|
|
|
@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MatrixBase, VectorMultiplication) {
|
TEST(MatrixBase, VectorMultiplication) {
|
||||||
// Stub
|
|
||||||
EXPECT_EQ(0, 1);
|
FasTC::MatrixBase<int, 3, 5> a;
|
||||||
|
a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0;
|
||||||
|
a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3;
|
||||||
|
a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
|
||||||
|
|
||||||
|
FasTC::VectorBase<int, 5> v;
|
||||||
|
for(int i = 0; i < 5; i++) v[i] = i + 1;
|
||||||
|
|
||||||
|
FasTC::VectorBase<int, 3> u = a * v;
|
||||||
|
EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4));
|
||||||
|
EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5));
|
||||||
|
EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5));
|
||||||
|
|
||||||
|
/////
|
||||||
|
|
||||||
|
for(int i = 0; i < 3; i++) u[i] = i + 1;
|
||||||
|
v = u * a;
|
||||||
|
|
||||||
|
EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3));
|
||||||
|
EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3));
|
||||||
|
EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3));
|
||||||
|
EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3));
|
||||||
|
EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MatrixSquare, Constructors) {
|
TEST(MatrixSquare, Constructors) {
|
||||||
|
|
Loading…
Reference in a new issue