From 62cce58c2ff9abd34c103daf2ae61d3aebe427c4 Mon Sep 17 00:00:00 2001 From: Pavel Krajcevski Date: Thu, 20 Feb 2014 14:49:35 -0500 Subject: [PATCH] Fix some of the vector multiplication and divide routines. In general, we want the scalar division of vectors and matrices to have the matrix come first and the scalar come second. It doesn't make sense to divide a scalar by a vector or to divide a matrix by a vector, so these should now produce errors at compile time. Also, make sure to add additional types that can be multiplied together using the * operator. If we multiply two vectors together, that's a dot product. The size restrictions should be enforced at compile time by the template parameters for VectorBase::Dot In this way, we can support vector/matrix multiplication by retaining the * operator as well. --- Base/include/VectorBase.h | 168 ++++++++++++++++++++++++-------------- Base/test/TestVector.cpp | 3 + 2 files changed, 111 insertions(+), 60 deletions(-) diff --git a/Base/include/VectorBase.h b/Base/include/VectorBase.h index 596df02..46bc194 100644 --- a/Base/include/VectorBase.h +++ b/Base/include/VectorBase.h @@ -31,6 +31,12 @@ namespace FasTC { + enum EVectorType { + eVectorType_Scalar, + eVectorType_Vector, + eVectorType_Matrix + }; + template class VectorBase { protected: @@ -146,62 +152,29 @@ namespace FasTC { template class VectorTraits { public: - static const bool IsVector = false; + static const EVectorType kVectorType = eVectorType_Scalar; }; template class VectorTraits > { public: - static const bool IsVector = true; + static const EVectorType kVectorType = eVectorType_Vector; }; - #define REGISTER_VECTOR_TYPE(TYPE) \ - template<> \ - class VectorTraits< TYPE > { \ - public: \ - static const bool IsVector = true; \ + #define REGISTER_VECTOR_TYPE(TYPE) \ + template<> \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Vector; \ } - #define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \ - template \ - class VectorTraits< TYPE > { \ - public: \ - static const bool IsVector = true; \ + #define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \ + template \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Vector; \ } - template - class VectorSwitch { - private: - const TypeOne &m_A; - const TypeTwo &m_B; - public: - typedef TypeOne VectorType; - typedef TypeTwo ScalarType; - - VectorSwitch(const TypeOne &a, const TypeTwo &b) - : m_A(a), m_B(b) { } - - const VectorType &GetVector() { return m_A; } - const ScalarType &GetScalar() { return m_B; } - }; - - template - class VectorSwitch { - private: - const TypeOne &m_A; - const TypeTwo &m_B; - - public: - typedef TypeTwo VectorType; - typedef TypeOne ScalarType; - - VectorSwitch(const TypeOne &a, const TypeTwo &b) - : m_A(a), m_B(b) { } - - const VectorType &GetVector() { return m_B; } - const ScalarType &GetScalar() { return m_A; } - }; - template static inline VectorType ScalarMultiply(const VectorType &v, const ScalarType &s) { VectorType a(v); @@ -210,13 +183,97 @@ namespace FasTC { return a; } + template< + EVectorType kVectorTypeOne, + EVectorType kVectorTypeTwo, + typename TypeOne, + typename TypeTwo> + class MultSwitch { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + public: + typedef TypeOne ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return m_A * m_B; } + }; + + template + class MultSwitch< + eVectorType_Scalar, + eVectorType_Vector, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef TypeTwo ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return ScalarMultiply(m_B, m_A); } + }; + + template + class MultSwitch< + eVectorType_Vector, + eVectorType_Scalar, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef TypeOne ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return ScalarMultiply(m_A, m_B); } + }; + + template + class MultSwitch< + eVectorType_Vector, + eVectorType_Vector, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef typename TypeOne::ScalarType ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return m_A.Dot(m_B); } + }; + template static inline - typename VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo >::VectorType + typename MultSwitch< + VectorTraits::kVectorType, + VectorTraits::kVectorType, + TypeOne, TypeTwo + >::ResultType operator*(const TypeOne &v1, const TypeTwo &v2) { - typedef VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo > VSwitch; - VSwitch s(v1, v2); - return ScalarMultiply(s.GetVector(), s.GetScalar()); + typedef MultSwitch< + VectorTraits::kVectorType, + VectorTraits::kVectorType, + TypeOne, TypeTwo + > VSwitch; + return VSwitch(v1, v2).GetMultiplication(); + } + + template + static inline VectorType &operator*=(VectorType &v, const ScalarType &s) { + return v = v * s; } template @@ -228,17 +285,8 @@ namespace FasTC { } template - static inline - typename VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo >::VectorType - operator/(const TypeOne &v1, const TypeTwo &v2) { - typedef VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo > VSwitch; - VSwitch s(v1, v2); - return ScalarDivide(s.GetVector(), s.GetScalar()); - } - - template - static inline VectorType &operator*=(VectorType &v, const ScalarType &s) { - return v = ScalarMultiply(v, s); + static inline TypeOne operator/(const TypeOne &v1, const TypeTwo &v2) { + return ScalarDivide(v1, v2); } template diff --git a/Base/test/TestVector.cpp b/Base/test/TestVector.cpp index e6637e1..f2f2f53 100644 --- a/Base/test/TestVector.cpp +++ b/Base/test/TestVector.cpp @@ -135,6 +135,9 @@ TEST(VectorBase, DotProduct) { EXPECT_EQ(v5i.Dot(v5u), 10); EXPECT_EQ(v5u.Dot(v5i), 10); + + EXPECT_EQ(v5i * v5u, 10); + EXPECT_EQ(v5u * v5i, 10); } TEST(VectorBase, Length) {