mirror of
				https://github.com/yuzu-emu/FasTC.git
				synced 2025-11-04 06:14:50 +00:00 
			
		
		
		
	Add matrix multiplication infrastructure
This commit is contained in:
		
							parent
							
								
									05eeb09f36
								
							
						
					
					
						commit
						8b9e8cd9b5
					
				| 
						 | 
					@ -30,11 +30,17 @@
 | 
				
			||||||
namespace FasTC {
 | 
					namespace FasTC {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  template <typename T, const int nRows, const int nCols>
 | 
					  template <typename T, const int nRows, const int nCols>
 | 
				
			||||||
  class MatrixBase : public VectorBase<T, nRows * nCols> {
 | 
					  class MatrixBase {
 | 
				
			||||||
   private:
 | 
					   protected:
 | 
				
			||||||
    typedef VectorBase<T, nRows * nCols> Base;
 | 
					
 | 
				
			||||||
 | 
					    // Vector representation
 | 
				
			||||||
 | 
					    T mat[nRows * nCols];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   public:
 | 
					   public:
 | 
				
			||||||
    static const int Size = Base::Size;
 | 
					    typedef T ScalarType;
 | 
				
			||||||
 | 
					    static const int kNumRows = nRows;
 | 
				
			||||||
 | 
					    static const int kNumCols = nCols;
 | 
				
			||||||
 | 
					    static const int Size = kNumCols * kNumRows;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Constructors
 | 
					    // Constructors
 | 
				
			||||||
    MatrixBase() { }
 | 
					    MatrixBase() { }
 | 
				
			||||||
| 
						 | 
					@ -45,16 +51,16 @@ namespace FasTC {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Accessors
 | 
					    // Accessors
 | 
				
			||||||
    T &operator()(int idx) { return Base::operator()(idx); }
 | 
					    T &operator()(int idx) { return mat[idx]; }
 | 
				
			||||||
    T &operator[](int idx) { return Base::operator[](idx); }
 | 
					    T &operator[](int idx) { return mat[idx]; }
 | 
				
			||||||
    const T &operator()(int idx) const { return Base::operator()(idx); }
 | 
					    const T &operator()(int idx) const { return mat[idx]; }
 | 
				
			||||||
    const T &operator[](int idx) const { return Base::operator[](idx); }
 | 
					    const T &operator[](int idx) const { return mat[idx]; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
 | 
					    T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
 | 
				
			||||||
    const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
 | 
					    const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Allow casts to the respective array representation...
 | 
					    // Allow casts to the respective array representation...
 | 
				
			||||||
    operator const T *() const { return this->vec; }
 | 
					    operator const T *() const { return this->mat; }
 | 
				
			||||||
    MatrixBase<T, nRows, nCols> &operator=(const T *v) {
 | 
					    MatrixBase<T, nRows, nCols> &operator=(const T *v) {
 | 
				
			||||||
      for(int i = 0; i < Size; i++)
 | 
					      for(int i = 0; i < Size; i++)
 | 
				
			||||||
        (*this)[i] = v[i];
 | 
					        (*this)[i] = v[i];
 | 
				
			||||||
| 
						 | 
					@ -66,7 +72,7 @@ namespace FasTC {
 | 
				
			||||||
    operator MatrixBase<_T, nRows, nCols>() const { 
 | 
					    operator MatrixBase<_T, nRows, nCols>() const { 
 | 
				
			||||||
      MatrixBase<_T, nRows, nCols> ret;
 | 
					      MatrixBase<_T, nRows, nCols> ret;
 | 
				
			||||||
      for(int i = 0; i < Size; i++) {
 | 
					      for(int i = 0; i < Size; i++) {
 | 
				
			||||||
        ret[i] = static_cast<_T>(this->vec[i]);
 | 
					        ret[i] = static_cast<_T>(mat[i]);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      return ret;
 | 
					      return ret;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -87,8 +93,20 @@ namespace FasTC {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Vector multiplication -- treat vectors as Nx1 matrices...
 | 
					    // Vector multiplication -- treat vectors as Nx1 matrices...
 | 
				
			||||||
    template<typename _T>
 | 
					    template<typename _T>
 | 
				
			||||||
    VectorBase<T, nCols> operator*(const VectorBase<_T, nCols> &v) {
 | 
					    VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
 | 
				
			||||||
      VectorBase<T, nCols> result;
 | 
					      VectorBase<T, nCols> result;
 | 
				
			||||||
 | 
					      for(int j = 0; j < nCols; j++) {
 | 
				
			||||||
 | 
					        result(j) = 0;
 | 
				
			||||||
 | 
					        for(int r = 0; r < nRows; r++) {
 | 
				
			||||||
 | 
					          result(j) += (*this)(r, j) * v(r);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      return result;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<typename _T>
 | 
				
			||||||
 | 
					    VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
 | 
				
			||||||
 | 
					      VectorBase<T, nRows> result;
 | 
				
			||||||
      for(int r = 0; r < nRows; r++) {
 | 
					      for(int r = 0; r < nRows; r++) {
 | 
				
			||||||
        result(r) = 0;
 | 
					        result(r) = 0;
 | 
				
			||||||
        for(int j = 0; j < nCols; j++) {
 | 
					        for(int j = 0; j < nCols; j++) {
 | 
				
			||||||
| 
						 | 
					@ -111,14 +129,88 @@ namespace FasTC {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Double dot product
 | 
					    // Double dot product
 | 
				
			||||||
    template<typename _T>
 | 
					    template<typename _T>
 | 
				
			||||||
    T DDot(const MatrixBase<_T, nRows, nCols> &m) {
 | 
					    T DDot(const MatrixBase<_T, nRows, nCols> &m) const {
 | 
				
			||||||
      T result = 0;
 | 
					      T result = 0;
 | 
				
			||||||
      for(int i = 0; i < Size; i++) {
 | 
					      for(int i = 0; i < Size; i++) {
 | 
				
			||||||
        result += (*this)[i] * m[i];
 | 
					        result += (*this)[i] * m[i];
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      return result;
 | 
					      return result;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  template<typename T, const int N, const int M>
 | 
				
			||||||
 | 
					  class VectorTraits<MatrixBase<T, N, M> > {
 | 
				
			||||||
 | 
					   public:
 | 
				
			||||||
 | 
					    static const EVectorType kVectorType = eVectorType_Matrix;
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  #define REGISTER_MATRIX_TYPE(TYPE)                           \
 | 
				
			||||||
 | 
					  template<>                                                   \
 | 
				
			||||||
 | 
					  class VectorTraits< TYPE > {                                 \
 | 
				
			||||||
 | 
					  public:                                                      \
 | 
				
			||||||
 | 
					    static const EVectorType kVectorType = eVectorType_Matrix; \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  #define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE)              \
 | 
				
			||||||
 | 
					  template<typename T>                                         \
 | 
				
			||||||
 | 
					  class VectorTraits< TYPE <T> > {                             \
 | 
				
			||||||
 | 
					  public:                                                      \
 | 
				
			||||||
 | 
					    static const EVectorType kVectorType = eVectorType_Matrix; \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Define matrix multiplication for * operator
 | 
				
			||||||
 | 
					  template<typename TypeOne, typename TypeTwo>
 | 
				
			||||||
 | 
					  class MultSwitch<
 | 
				
			||||||
 | 
					    eVectorType_Matrix,
 | 
				
			||||||
 | 
					    eVectorType_Vector,
 | 
				
			||||||
 | 
					    TypeOne, TypeTwo> {
 | 
				
			||||||
 | 
					   private:
 | 
				
			||||||
 | 
					    const TypeOne &m_A;
 | 
				
			||||||
 | 
					    const TypeTwo &m_B;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   public:
 | 
				
			||||||
 | 
					    typedef VectorBase<typename TypeTwo::ScalarType, TypeOne::kNumRows> ResultType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
				
			||||||
 | 
					      : m_A(a), m_B(b) { }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  template<typename TypeOne, typename TypeTwo>
 | 
				
			||||||
 | 
					  class MultSwitch<
 | 
				
			||||||
 | 
					    eVectorType_Vector,
 | 
				
			||||||
 | 
					    eVectorType_Matrix,
 | 
				
			||||||
 | 
					    TypeOne, TypeTwo> {
 | 
				
			||||||
 | 
					   private:
 | 
				
			||||||
 | 
					    const TypeOne &m_A;
 | 
				
			||||||
 | 
					    const TypeTwo &m_B;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   public:
 | 
				
			||||||
 | 
					    typedef VectorBase<typename TypeOne::ScalarType, TypeTwo::kNumCols> ResultType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
				
			||||||
 | 
					      : m_A(a), m_B(b) { }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  template<typename TypeOne, typename TypeTwo>
 | 
				
			||||||
 | 
					  class MultSwitch<
 | 
				
			||||||
 | 
					    eVectorType_Matrix,
 | 
				
			||||||
 | 
					    eVectorType_Matrix,
 | 
				
			||||||
 | 
					    TypeOne, TypeTwo> {
 | 
				
			||||||
 | 
					   private:
 | 
				
			||||||
 | 
					    const TypeOne &m_A;
 | 
				
			||||||
 | 
					    const TypeTwo &m_B;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   public:
 | 
				
			||||||
 | 
					    typedef MatrixBase<typename TypeOne::ScalarType, TypeOne::kNumRows, TypeTwo::kNumCols> ResultType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
				
			||||||
 | 
					      : m_A(a), m_B(b) { }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Outer product...
 | 
					  // Outer product...
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(MatrixBase, VectorMultiplication) {
 | 
					TEST(MatrixBase, VectorMultiplication) {
 | 
				
			||||||
  // Stub
 | 
					
 | 
				
			||||||
  EXPECT_EQ(0, 1);
 | 
					  FasTC::MatrixBase<int, 3, 5> a;
 | 
				
			||||||
 | 
					  a(0, 0) = -1;  a(0, 1) =  2;  a(0, 2) = -4; a(0, 3) =  5; a(0, 4) = 0;
 | 
				
			||||||
 | 
					  a(1, 0) =  1;  a(1, 1) =  2;  a(1, 2) =  4; a(1, 3) =  6; a(1, 4) = 3;
 | 
				
			||||||
 | 
					  a(2, 0) = -1;  a(2, 1) = -2;  a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  FasTC::VectorBase<int, 5> v;
 | 
				
			||||||
 | 
					  for(int i = 0; i < 5; i++) v[i] = i + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  FasTC::VectorBase<int, 3> u = a * v;
 | 
				
			||||||
 | 
					  EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4));
 | 
				
			||||||
 | 
					  EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5));
 | 
				
			||||||
 | 
					  EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /////
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for(int i = 0; i < 3; i++) u[i] = i + 1;
 | 
				
			||||||
 | 
					  v = u * a;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3));
 | 
				
			||||||
 | 
					  EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3));
 | 
				
			||||||
 | 
					  EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3));
 | 
				
			||||||
 | 
					  EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3));
 | 
				
			||||||
 | 
					  EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(MatrixSquare, Constructors) {
 | 
					TEST(MatrixSquare, Constructors) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue