diff --git a/Base/include/VectorBase.h b/Base/include/VectorBase.h index 596df02..46bc194 100644 --- a/Base/include/VectorBase.h +++ b/Base/include/VectorBase.h @@ -31,6 +31,12 @@ namespace FasTC { + enum EVectorType { + eVectorType_Scalar, + eVectorType_Vector, + eVectorType_Matrix + }; + template class VectorBase { protected: @@ -146,62 +152,29 @@ namespace FasTC { template class VectorTraits { public: - static const bool IsVector = false; + static const EVectorType kVectorType = eVectorType_Scalar; }; template class VectorTraits > { public: - static const bool IsVector = true; + static const EVectorType kVectorType = eVectorType_Vector; }; - #define REGISTER_VECTOR_TYPE(TYPE) \ - template<> \ - class VectorTraits< TYPE > { \ - public: \ - static const bool IsVector = true; \ + #define REGISTER_VECTOR_TYPE(TYPE) \ + template<> \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Vector; \ } - #define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \ - template \ - class VectorTraits< TYPE > { \ - public: \ - static const bool IsVector = true; \ + #define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \ + template \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Vector; \ } - template - class VectorSwitch { - private: - const TypeOne &m_A; - const TypeTwo &m_B; - public: - typedef TypeOne VectorType; - typedef TypeTwo ScalarType; - - VectorSwitch(const TypeOne &a, const TypeTwo &b) - : m_A(a), m_B(b) { } - - const VectorType &GetVector() { return m_A; } - const ScalarType &GetScalar() { return m_B; } - }; - - template - class VectorSwitch { - private: - const TypeOne &m_A; - const TypeTwo &m_B; - - public: - typedef TypeTwo VectorType; - typedef TypeOne ScalarType; - - VectorSwitch(const TypeOne &a, const TypeTwo &b) - : m_A(a), m_B(b) { } - - const VectorType &GetVector() { return m_B; } - const ScalarType &GetScalar() { return m_A; } - }; - template static inline VectorType ScalarMultiply(const VectorType &v, const ScalarType &s) { VectorType a(v); @@ -210,13 +183,97 @@ namespace FasTC { return a; } + template< + EVectorType kVectorTypeOne, + EVectorType kVectorTypeTwo, + typename TypeOne, + typename TypeTwo> + class MultSwitch { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + public: + typedef TypeOne ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return m_A * m_B; } + }; + + template + class MultSwitch< + eVectorType_Scalar, + eVectorType_Vector, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef TypeTwo ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return ScalarMultiply(m_B, m_A); } + }; + + template + class MultSwitch< + eVectorType_Vector, + eVectorType_Scalar, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef TypeOne ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return ScalarMultiply(m_A, m_B); } + }; + + template + class MultSwitch< + eVectorType_Vector, + eVectorType_Vector, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef typename TypeOne::ScalarType ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() { return m_A.Dot(m_B); } + }; + template static inline - typename VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo >::VectorType + typename MultSwitch< + VectorTraits::kVectorType, + VectorTraits::kVectorType, + TypeOne, TypeTwo + >::ResultType operator*(const TypeOne &v1, const TypeTwo &v2) { - typedef VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo > VSwitch; - VSwitch s(v1, v2); - return ScalarMultiply(s.GetVector(), s.GetScalar()); + typedef MultSwitch< + VectorTraits::kVectorType, + VectorTraits::kVectorType, + TypeOne, TypeTwo + > VSwitch; + return VSwitch(v1, v2).GetMultiplication(); + } + + template + static inline VectorType &operator*=(VectorType &v, const ScalarType &s) { + return v = v * s; } template @@ -228,17 +285,8 @@ namespace FasTC { } template - static inline - typename VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo >::VectorType - operator/(const TypeOne &v1, const TypeTwo &v2) { - typedef VectorSwitch< VectorTraits::IsVector, TypeOne, TypeTwo > VSwitch; - VSwitch s(v1, v2); - return ScalarDivide(s.GetVector(), s.GetScalar()); - } - - template - static inline VectorType &operator*=(VectorType &v, const ScalarType &s) { - return v = ScalarMultiply(v, s); + static inline TypeOne operator/(const TypeOne &v1, const TypeTwo &v2) { + return ScalarDivide(v1, v2); } template diff --git a/Base/test/TestVector.cpp b/Base/test/TestVector.cpp index e6637e1..f2f2f53 100644 --- a/Base/test/TestVector.cpp +++ b/Base/test/TestVector.cpp @@ -135,6 +135,9 @@ TEST(VectorBase, DotProduct) { EXPECT_EQ(v5i.Dot(v5u), 10); EXPECT_EQ(v5u.Dot(v5i), 10); + + EXPECT_EQ(v5i * v5u, 10); + EXPECT_EQ(v5u * v5i, 10); } TEST(VectorBase, Length) {