Add MSSIM metric for images based on matlab implementation.

This commit is contained in:
Pavel Krajcevski 2013-10-11 12:12:32 -04:00
parent 42c6f85642
commit 8d37d6eee5
3 changed files with 177 additions and 0 deletions

View file

@ -106,6 +106,7 @@ namespace FasTC {
} }
double ComputePSNR(Image<PixelType> *other); double ComputePSNR(Image<PixelType> *other);
double ComputeMSSIM(Image<PixelType> *other) const;
// Function to allow derived classes to populate the pixel array. // Function to allow derived classes to populate the pixel array.
// This may involve decompressing a compressed image or otherwise // This may involve decompressing a compressed image or otherwise

View file

@ -218,6 +218,165 @@ double Image<PixelType>::ComputePSNR(Image<PixelType> *other) {
return 10 * log10(maxi/mse); return 10 * log10(maxi/mse);
} }
template<typename PixelType>
static Image<PixelType> FilterValid(const Image<PixelType> &img, uint32 size, double sigma) {
assert(size % 2);
Image<IPixel> gaussian(size, size);
Image<PixelType> tmp = img;
tmp.Filter(gaussian);
Image<PixelType> out(tmp.GetWidth() - size + 1, tmp.GetHeight() - size + 1);
uint32 halfSz = size >> 1;
for(uint32 j = halfSz; j < img.GetHeight()-halfSz; j++) {
for(uint32 i = halfSz; i < img.GetWidth()-halfSz; i++) {
out(i-halfSz, j-halfSz) = tmp(i, j);
}
}
return out;
}
template<typename PixelType>
double Image<PixelType>::ComputeMSSIM(Image<PixelType> *other) const {
if(!other) {
return -1.0;
}
if(GetWidth() != other->GetWidth() ||
GetHeight() != other->GetHeight()) {
return -1.0;
}
double C1 = (0.01 * 255.0 * 0.01 * 255.0);
double C2 = (0.03 * 255.0 * 0.03 * 255.0);
Image<IPixel> img1(GetWidth(), GetHeight());
Image<IPixel> img2(GetWidth(), GetHeight());
ConvertTo(img1);
other->ConvertTo(img2);
for(uint32 j = 0; j < GetHeight(); j++) {
for(uint32 i = 0; i < GetWidth(); i++) {
img1(i, j) = 255.0f * static_cast<float>(img1(i, j));
img2(i, j) = 255.0f * static_cast<float>(img2(i, j));
}
}
/* Matlab code taken from
http://www.cns.nyu.edu/lcv/ssim/ssim_index.m
C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));
img1 = double(img1);
img2 = double(img2);
mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./
((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
*/
const uint32 filterSz = 11;
const double filterSigma = 1.5;
Image<IPixel> mu1 = FilterValid(img1, filterSz, filterSigma);
Image<IPixel> mu2 = FilterValid(img2, filterSz, filterSigma);
assert(mu1.GetHeight() == mu2.GetHeight());
assert(mu1.GetWidth() == mu2.GetWidth());
Image<IPixel> mu1_sq(mu1);
Image<IPixel> mu2_sq(mu2);
Image<IPixel> mu1_mu2(mu1);
Image<IPixel> sigma1_sq(img1);
Image<IPixel> sigma2_sq(img2);
Image<IPixel> sigma12(img1);
uint32 w = ::std::max(img1.GetWidth(), mu1.GetWidth());
uint32 h = ::std::max(img1.GetHeight(), mu1.GetHeight());
for(uint32 j = 0; j < h; j++) {
for(uint32 i = 0; i < w; i++) {
if(i < mu1.GetWidth() && j < mu1.GetHeight()) {
double m1 = static_cast<float>(mu1(i, j));
double m2 = static_cast<float>(mu2(i, j));
mu1_sq(i, j) = static_cast<float>(m1 * m1);
mu2_sq(i, j) = static_cast<float>(m2 * m2);
mu1_mu2(i, j) = static_cast<float>(m1 * m2);
}
if(i < img1.GetWidth() && j < img1.GetHeight()) {
double i1 = static_cast<float>(img1(i, j));
double i2 = static_cast<float>(img2(i, j));
sigma1_sq(i, j) = static_cast<float>(i1 * i1);
sigma2_sq(i, j) = static_cast<float>(i2 * i2);
sigma12(i, j) = static_cast<float>(i1 * i2);
}
}
}
sigma1_sq = FilterValid(sigma1_sq, filterSz, filterSigma);
sigma2_sq = FilterValid(sigma2_sq, filterSz, filterSigma);
sigma12 = FilterValid(sigma12, filterSz, filterSigma);
assert(sigma1_sq.GetWidth() == mu1.GetWidth());
assert(sigma1_sq.GetHeight() == mu1.GetHeight());
assert(sigma2_sq.GetWidth() == mu1.GetWidth());
assert(sigma2_sq.GetHeight() == mu1.GetHeight());
assert(sigma12.GetWidth() == mu1.GetWidth());
assert(sigma12.GetHeight() == mu1.GetHeight());
w = mu1_sq.GetWidth();
h = mu2_sq.GetHeight();
for(uint32 j = 0; j < h; j++) {
for(uint32 i = 0; i < w; i++) {
double m1sq = static_cast<float>(mu1_sq(i, j));
double m2sq = static_cast<float>(mu2_sq(i, j));
double m1m2 = static_cast<float>(mu1_mu2(i, j));
double s1sq = static_cast<float>(sigma1_sq(i, j));
double s2sq = static_cast<float>(sigma2_sq(i, j));
double s1s2 = static_cast<float>(sigma12(i, j));
sigma1_sq(i, j) = static_cast<float>(s1sq - m1sq);
sigma2_sq(i, j) = static_cast<float>(s2sq - m2sq);
sigma12(i, j) = static_cast<float>(s1s2 - m1m2);
}
}
double mssim = 0.0;
for(uint32 j = 0; j < h; j++) {
for(uint32 i = 0; i < w; i++) {
double m1sq = static_cast<float>(mu1_sq(i, j));
double m2sq = static_cast<float>(mu2_sq(i, j));
double m1m2 = static_cast<float>(mu1_mu2(i, j));
double s1sq = static_cast<float>(sigma1_sq(i, j));
double s2sq = static_cast<float>(sigma2_sq(i, j));
double s1s2 = static_cast<float>(sigma12(i, j));
mssim +=
((2.0 * m1m2 + C1) * (2.0 * s1s2 + C2)) /
((m1sq + m2sq + C1) * (s1sq + s2sq + C2));
}
}
return mssim / static_cast<double>(w * h);
}
// !FIXME! These won't work for non-RGBA8 data. // !FIXME! These won't work for non-RGBA8 data.
template<typename PixelType> template<typename PixelType>
void Image<PixelType>::ConvertToBlockStreamOrder() { void Image<PixelType>::ConvertToBlockStreamOrder() {

View file

@ -149,3 +149,20 @@ TEST(Image, Filter) {
} }
} }
} }
TEST(Image, ComputeMSSIM) {
const uint32 w = 16;
const uint32 h = 16;
FasTC::Image<FasTC::IPixel> img(w, h);
for(uint32 j = 0; j < h; j++) {
for(uint32 i = 0; i < w; i++) {
img(i, j) =
(static_cast<double>(i) * static_cast<double>(j)) /
(static_cast<double>(w) * static_cast<double>(h));
}
}
EXPECT_EQ(img.ComputeMSSIM(&img), 1.0);
}