mirror of
				https://github.com/yuzu-emu/FasTC.git
				synced 2025-11-04 11:34:50 +00:00 
			
		
		
		
	Add matrix multiplication infrastructure
This commit is contained in:
		
							parent
							
								
									05eeb09f36
								
							
						
					
					
						commit
						8b9e8cd9b5
					
				| 
						 | 
				
			
			@ -30,11 +30,17 @@
 | 
			
		|||
namespace FasTC {
 | 
			
		||||
 | 
			
		||||
  template <typename T, const int nRows, const int nCols>
 | 
			
		||||
  class MatrixBase : public VectorBase<T, nRows * nCols> {
 | 
			
		||||
   private:
 | 
			
		||||
    typedef VectorBase<T, nRows * nCols> Base;
 | 
			
		||||
  class MatrixBase {
 | 
			
		||||
   protected:
 | 
			
		||||
 | 
			
		||||
    // Vector representation
 | 
			
		||||
    T mat[nRows * nCols];
 | 
			
		||||
 | 
			
		||||
   public:
 | 
			
		||||
    static const int Size = Base::Size;
 | 
			
		||||
    typedef T ScalarType;
 | 
			
		||||
    static const int kNumRows = nRows;
 | 
			
		||||
    static const int kNumCols = nCols;
 | 
			
		||||
    static const int Size = kNumCols * kNumRows;
 | 
			
		||||
 | 
			
		||||
    // Constructors
 | 
			
		||||
    MatrixBase() { }
 | 
			
		||||
| 
						 | 
				
			
			@ -45,16 +51,16 @@ namespace FasTC {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    // Accessors
 | 
			
		||||
    T &operator()(int idx) { return Base::operator()(idx); }
 | 
			
		||||
    T &operator[](int idx) { return Base::operator[](idx); }
 | 
			
		||||
    const T &operator()(int idx) const { return Base::operator()(idx); }
 | 
			
		||||
    const T &operator[](int idx) const { return Base::operator[](idx); }
 | 
			
		||||
    T &operator()(int idx) { return mat[idx]; }
 | 
			
		||||
    T &operator[](int idx) { return mat[idx]; }
 | 
			
		||||
    const T &operator()(int idx) const { return mat[idx]; }
 | 
			
		||||
    const T &operator[](int idx) const { return mat[idx]; }
 | 
			
		||||
 | 
			
		||||
    T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
 | 
			
		||||
    const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
 | 
			
		||||
 | 
			
		||||
    // Allow casts to the respective array representation...
 | 
			
		||||
    operator const T *() const { return this->vec; }
 | 
			
		||||
    operator const T *() const { return this->mat; }
 | 
			
		||||
    MatrixBase<T, nRows, nCols> &operator=(const T *v) {
 | 
			
		||||
      for(int i = 0; i < Size; i++)
 | 
			
		||||
        (*this)[i] = v[i];
 | 
			
		||||
| 
						 | 
				
			
			@ -66,7 +72,7 @@ namespace FasTC {
 | 
			
		|||
    operator MatrixBase<_T, nRows, nCols>() const { 
 | 
			
		||||
      MatrixBase<_T, nRows, nCols> ret;
 | 
			
		||||
      for(int i = 0; i < Size; i++) {
 | 
			
		||||
        ret[i] = static_cast<_T>(this->vec[i]);
 | 
			
		||||
        ret[i] = static_cast<_T>(mat[i]);
 | 
			
		||||
      }
 | 
			
		||||
      return ret;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -87,8 +93,20 @@ namespace FasTC {
 | 
			
		|||
 | 
			
		||||
    // Vector multiplication -- treat vectors as Nx1 matrices...
 | 
			
		||||
    template<typename _T>
 | 
			
		||||
    VectorBase<T, nCols> operator*(const VectorBase<_T, nCols> &v) {
 | 
			
		||||
    VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
 | 
			
		||||
      VectorBase<T, nCols> result;
 | 
			
		||||
      for(int j = 0; j < nCols; j++) {
 | 
			
		||||
        result(j) = 0;
 | 
			
		||||
        for(int r = 0; r < nRows; r++) {
 | 
			
		||||
          result(j) += (*this)(r, j) * v(r);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      return result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template<typename _T>
 | 
			
		||||
    VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
 | 
			
		||||
      VectorBase<T, nRows> result;
 | 
			
		||||
      for(int r = 0; r < nRows; r++) {
 | 
			
		||||
        result(r) = 0;
 | 
			
		||||
        for(int j = 0; j < nCols; j++) {
 | 
			
		||||
| 
						 | 
				
			
			@ -111,14 +129,88 @@ namespace FasTC {
 | 
			
		|||
 | 
			
		||||
    // Double dot product
 | 
			
		||||
    template<typename _T>
 | 
			
		||||
    T DDot(const MatrixBase<_T, nRows, nCols> &m) {
 | 
			
		||||
    T DDot(const MatrixBase<_T, nRows, nCols> &m) const {
 | 
			
		||||
      T result = 0;
 | 
			
		||||
      for(int i = 0; i < Size; i++) {
 | 
			
		||||
        result += (*this)[i] * m[i];
 | 
			
		||||
      }
 | 
			
		||||
      return result;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  template<typename T, const int N, const int M>
 | 
			
		||||
  class VectorTraits<MatrixBase<T, N, M> > {
 | 
			
		||||
   public:
 | 
			
		||||
    static const EVectorType kVectorType = eVectorType_Matrix;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  #define REGISTER_MATRIX_TYPE(TYPE)                           \
 | 
			
		||||
  template<>                                                   \
 | 
			
		||||
  class VectorTraits< TYPE > {                                 \
 | 
			
		||||
  public:                                                      \
 | 
			
		||||
    static const EVectorType kVectorType = eVectorType_Matrix; \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  #define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE)              \
 | 
			
		||||
  template<typename T>                                         \
 | 
			
		||||
  class VectorTraits< TYPE <T> > {                             \
 | 
			
		||||
  public:                                                      \
 | 
			
		||||
    static const EVectorType kVectorType = eVectorType_Matrix; \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Define matrix multiplication for * operator
 | 
			
		||||
  template<typename TypeOne, typename TypeTwo>
 | 
			
		||||
  class MultSwitch<
 | 
			
		||||
    eVectorType_Matrix,
 | 
			
		||||
    eVectorType_Vector,
 | 
			
		||||
    TypeOne, TypeTwo> {
 | 
			
		||||
   private:
 | 
			
		||||
    const TypeOne &m_A;
 | 
			
		||||
    const TypeTwo &m_B;
 | 
			
		||||
 | 
			
		||||
   public:
 | 
			
		||||
    typedef VectorBase<typename TypeTwo::ScalarType, TypeOne::kNumRows> ResultType;
 | 
			
		||||
 | 
			
		||||
    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
			
		||||
      : m_A(a), m_B(b) { }
 | 
			
		||||
 | 
			
		||||
    ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  template<typename TypeOne, typename TypeTwo>
 | 
			
		||||
  class MultSwitch<
 | 
			
		||||
    eVectorType_Vector,
 | 
			
		||||
    eVectorType_Matrix,
 | 
			
		||||
    TypeOne, TypeTwo> {
 | 
			
		||||
   private:
 | 
			
		||||
    const TypeOne &m_A;
 | 
			
		||||
    const TypeTwo &m_B;
 | 
			
		||||
 | 
			
		||||
   public:
 | 
			
		||||
    typedef VectorBase<typename TypeOne::ScalarType, TypeTwo::kNumCols> ResultType;
 | 
			
		||||
 | 
			
		||||
    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
			
		||||
      : m_A(a), m_B(b) { }
 | 
			
		||||
 | 
			
		||||
    ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  template<typename TypeOne, typename TypeTwo>
 | 
			
		||||
  class MultSwitch<
 | 
			
		||||
    eVectorType_Matrix,
 | 
			
		||||
    eVectorType_Matrix,
 | 
			
		||||
    TypeOne, TypeTwo> {
 | 
			
		||||
   private:
 | 
			
		||||
    const TypeOne &m_A;
 | 
			
		||||
    const TypeTwo &m_B;
 | 
			
		||||
 | 
			
		||||
   public:
 | 
			
		||||
    typedef MatrixBase<typename TypeOne::ScalarType, TypeOne::kNumRows, TypeTwo::kNumCols> ResultType;
 | 
			
		||||
 | 
			
		||||
    MultSwitch(const TypeOne &a, const TypeTwo &b)
 | 
			
		||||
      : m_A(a), m_B(b) { }
 | 
			
		||||
 | 
			
		||||
    ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Outer product...
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
TEST(MatrixBase, VectorMultiplication) {
 | 
			
		||||
  // Stub
 | 
			
		||||
  EXPECT_EQ(0, 1);
 | 
			
		||||
 | 
			
		||||
  FasTC::MatrixBase<int, 3, 5> a;
 | 
			
		||||
  a(0, 0) = -1;  a(0, 1) =  2;  a(0, 2) = -4; a(0, 3) =  5; a(0, 4) = 0;
 | 
			
		||||
  a(1, 0) =  1;  a(1, 1) =  2;  a(1, 2) =  4; a(1, 3) =  6; a(1, 4) = 3;
 | 
			
		||||
  a(2, 0) = -1;  a(2, 1) = -2;  a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
 | 
			
		||||
 | 
			
		||||
  FasTC::VectorBase<int, 5> v;
 | 
			
		||||
  for(int i = 0; i < 5; i++) v[i] = i + 1;
 | 
			
		||||
 | 
			
		||||
  FasTC::VectorBase<int, 3> u = a * v;
 | 
			
		||||
  EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4));
 | 
			
		||||
  EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5));
 | 
			
		||||
  EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5));
 | 
			
		||||
 | 
			
		||||
  /////
 | 
			
		||||
 | 
			
		||||
  for(int i = 0; i < 3; i++) u[i] = i + 1;
 | 
			
		||||
  v = u * a;
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3));
 | 
			
		||||
  EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3));
 | 
			
		||||
  EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3));
 | 
			
		||||
  EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3));
 | 
			
		||||
  EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MatrixSquare, Constructors) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue