Fix some of the vector multiplication and divide routines.

In general, we want the scalar division of vectors and matrices to
have the matrix come first and the scalar come second. It doesn't make
sense to divide a scalar by a vector or to divide a matrix by a vector,
so these should now produce errors at compile time.

Also, make sure to add additional types that can be multiplied together
using the * operator. If we multiply two vectors together, that's a dot
product. The size restrictions should be enforced at compile time by the
template parameters for VectorBase<T, N>::Dot

In this way, we can support vector/matrix multiplication by retaining the
* operator as well.
This commit is contained in:
Pavel Krajcevski 2014-02-20 14:49:35 -05:00
parent 2d7ee21fb7
commit 62cce58c2f
2 changed files with 111 additions and 60 deletions

View file

@ -31,6 +31,12 @@
namespace FasTC { namespace FasTC {
enum EVectorType {
eVectorType_Scalar,
eVectorType_Vector,
eVectorType_Matrix
};
template <typename T, const int N> template <typename T, const int N>
class VectorBase { class VectorBase {
protected: protected:
@ -146,62 +152,29 @@ namespace FasTC {
template<typename T> template<typename T>
class VectorTraits { class VectorTraits {
public: public:
static const bool IsVector = false; static const EVectorType kVectorType = eVectorType_Scalar;
}; };
template<typename T, const int N> template<typename T, const int N>
class VectorTraits<VectorBase<T, N> > { class VectorTraits<VectorBase<T, N> > {
public: public:
static const bool IsVector = true; static const EVectorType kVectorType = eVectorType_Vector;
}; };
#define REGISTER_VECTOR_TYPE(TYPE) \ #define REGISTER_VECTOR_TYPE(TYPE) \
template<> \ template<> \
class VectorTraits< TYPE > { \ class VectorTraits< TYPE > { \
public: \ public: \
static const bool IsVector = true; \ static const EVectorType kVectorType = eVectorType_Vector; \
} }
#define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \ #define REGISTER_ONE_TEMPLATE_VECTOR_TYPE(TYPE) \
template<typename T> \ template<typename T> \
class VectorTraits< TYPE <T> > { \ class VectorTraits< TYPE <T> > { \
public: \ public: \
static const bool IsVector = true; \ static const EVectorType kVectorType = eVectorType_Vector; \
} }
template<bool condition, typename TypeOne, typename TypeTwo>
class VectorSwitch {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef TypeOne VectorType;
typedef TypeTwo ScalarType;
VectorSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
const VectorType &GetVector() { return m_A; }
const ScalarType &GetScalar() { return m_B; }
};
template<typename TypeOne, typename TypeTwo>
class VectorSwitch<false, TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef TypeTwo VectorType;
typedef TypeOne ScalarType;
VectorSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
const VectorType &GetVector() { return m_B; }
const ScalarType &GetScalar() { return m_A; }
};
template<typename VectorType, typename ScalarType> template<typename VectorType, typename ScalarType>
static inline VectorType ScalarMultiply(const VectorType &v, const ScalarType &s) { static inline VectorType ScalarMultiply(const VectorType &v, const ScalarType &s) {
VectorType a(v); VectorType a(v);
@ -210,13 +183,97 @@ namespace FasTC {
return a; return a;
} }
template<
EVectorType kVectorTypeOne,
EVectorType kVectorTypeTwo,
typename TypeOne,
typename TypeTwo>
class MultSwitch {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef TypeOne ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() { return m_A * m_B; }
};
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Scalar,
eVectorType_Vector,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef TypeTwo ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() { return ScalarMultiply(m_B, m_A); }
};
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Vector,
eVectorType_Scalar,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef TypeOne ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() { return ScalarMultiply(m_A, m_B); }
};
template<typename TypeOne, typename TypeTwo>
class MultSwitch<
eVectorType_Vector,
eVectorType_Vector,
TypeOne, TypeTwo> {
private:
const TypeOne &m_A;
const TypeTwo &m_B;
public:
typedef typename TypeOne::ScalarType ResultType;
MultSwitch(const TypeOne &a, const TypeTwo &b)
: m_A(a), m_B(b) { }
ResultType GetMultiplication() { return m_A.Dot(m_B); }
};
template<typename TypeOne, typename TypeTwo> template<typename TypeOne, typename TypeTwo>
static inline static inline
typename VectorSwitch< VectorTraits<TypeOne>::IsVector, TypeOne, TypeTwo >::VectorType typename MultSwitch<
VectorTraits<TypeOne>::kVectorType,
VectorTraits<TypeTwo>::kVectorType,
TypeOne, TypeTwo
>::ResultType
operator*(const TypeOne &v1, const TypeTwo &v2) { operator*(const TypeOne &v1, const TypeTwo &v2) {
typedef VectorSwitch< VectorTraits<TypeOne>::IsVector, TypeOne, TypeTwo > VSwitch; typedef MultSwitch<
VSwitch s(v1, v2); VectorTraits<TypeOne>::kVectorType,
return ScalarMultiply(s.GetVector(), s.GetScalar()); VectorTraits<TypeTwo>::kVectorType,
TypeOne, TypeTwo
> VSwitch;
return VSwitch(v1, v2).GetMultiplication();
}
template<typename VectorType, typename ScalarType>
static inline VectorType &operator*=(VectorType &v, const ScalarType &s) {
return v = v * s;
} }
template<typename VectorType, typename ScalarType> template<typename VectorType, typename ScalarType>
@ -228,17 +285,8 @@ namespace FasTC {
} }
template<typename TypeOne, typename TypeTwo> template<typename TypeOne, typename TypeTwo>
static inline static inline TypeOne operator/(const TypeOne &v1, const TypeTwo &v2) {
typename VectorSwitch< VectorTraits<TypeOne>::IsVector, TypeOne, TypeTwo >::VectorType return ScalarDivide(v1, v2);
operator/(const TypeOne &v1, const TypeTwo &v2) {
typedef VectorSwitch< VectorTraits<TypeOne>::IsVector, TypeOne, TypeTwo > VSwitch;
VSwitch s(v1, v2);
return ScalarDivide(s.GetVector(), s.GetScalar());
}
template<typename VectorType, typename ScalarType>
static inline VectorType &operator*=(VectorType &v, const ScalarType &s) {
return v = ScalarMultiply(v, s);
} }
template<typename VectorType, typename ScalarType> template<typename VectorType, typename ScalarType>

View file

@ -135,6 +135,9 @@ TEST(VectorBase, DotProduct) {
EXPECT_EQ(v5i.Dot(v5u), 10); EXPECT_EQ(v5i.Dot(v5u), 10);
EXPECT_EQ(v5u.Dot(v5i), 10); EXPECT_EQ(v5u.Dot(v5i), 10);
EXPECT_EQ(v5i * v5u, 10);
EXPECT_EQ(v5u * v5i, 10);
} }
TEST(VectorBase, Length) { TEST(VectorBase, Length) {