diff --git a/Base/include/Image.h b/Base/include/Image.h index 7e252a8..a8c8fb3 100644 --- a/Base/include/Image.h +++ b/Base/include/Image.h @@ -106,6 +106,7 @@ namespace FasTC { } double ComputePSNR(Image *other); + double ComputeMSSIM(Image *other) const; // Function to allow derived classes to populate the pixel array. // This may involve decompressing a compressed image or otherwise diff --git a/Base/src/Image.cpp b/Base/src/Image.cpp index cf3d1e8..3607216 100644 --- a/Base/src/Image.cpp +++ b/Base/src/Image.cpp @@ -218,6 +218,165 @@ double Image::ComputePSNR(Image *other) { return 10 * log10(maxi/mse); } +template +static Image FilterValid(const Image &img, uint32 size, double sigma) { + assert(size % 2); + Image gaussian(size, size); + Image tmp = img; + tmp.Filter(gaussian); + + Image out(tmp.GetWidth() - size + 1, tmp.GetHeight() - size + 1); + uint32 halfSz = size >> 1; + for(uint32 j = halfSz; j < img.GetHeight()-halfSz; j++) { + for(uint32 i = halfSz; i < img.GetWidth()-halfSz; i++) { + out(i-halfSz, j-halfSz) = tmp(i, j); + } + } + + return out; +} + +template +double Image::ComputeMSSIM(Image *other) const { + if(!other) { + return -1.0; + } + + if(GetWidth() != other->GetWidth() || + GetHeight() != other->GetHeight()) { + return -1.0; + } + + double C1 = (0.01 * 255.0 * 0.01 * 255.0); + double C2 = (0.03 * 255.0 * 0.03 * 255.0); + + Image img1(GetWidth(), GetHeight()); + Image img2(GetWidth(), GetHeight()); + + ConvertTo(img1); + other->ConvertTo(img2); + + for(uint32 j = 0; j < GetHeight(); j++) { + for(uint32 i = 0; i < GetWidth(); i++) { + img1(i, j) = 255.0f * static_cast(img1(i, j)); + img2(i, j) = 255.0f * static_cast(img2(i, j)); + } + } + + /* Matlab code taken from + http://www.cns.nyu.edu/lcv/ssim/ssim_index.m + + C1 = (K(1)*L)^2; + C2 = (K(2)*L)^2; + window = window/sum(sum(window)); + img1 = double(img1); + img2 = double(img2); + + mu1 = filter2(window, img1, 'valid'); + mu2 = filter2(window, img2, 'valid'); + mu1_sq = mu1.*mu1; + mu2_sq = mu2.*mu2; + mu1_mu2 = mu1.*mu2; + sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; + sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; + sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; + + ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./ + ((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); + */ + + const uint32 filterSz = 11; + const double filterSigma = 1.5; + + Image mu1 = FilterValid(img1, filterSz, filterSigma); + Image mu2 = FilterValid(img2, filterSz, filterSigma); + + assert(mu1.GetHeight() == mu2.GetHeight()); + assert(mu1.GetWidth() == mu2.GetWidth()); + + Image mu1_sq(mu1); + Image mu2_sq(mu2); + Image mu1_mu2(mu1); + Image sigma1_sq(img1); + Image sigma2_sq(img2); + Image sigma12(img1); + + uint32 w = ::std::max(img1.GetWidth(), mu1.GetWidth()); + uint32 h = ::std::max(img1.GetHeight(), mu1.GetHeight()); + for(uint32 j = 0; j < h; j++) { + for(uint32 i = 0; i < w; i++) { + if(i < mu1.GetWidth() && j < mu1.GetHeight()) { + double m1 = static_cast(mu1(i, j)); + double m2 = static_cast(mu2(i, j)); + + mu1_sq(i, j) = static_cast(m1 * m1); + mu2_sq(i, j) = static_cast(m2 * m2); + mu1_mu2(i, j) = static_cast(m1 * m2); + } + + if(i < img1.GetWidth() && j < img1.GetHeight()) { + double i1 = static_cast(img1(i, j)); + double i2 = static_cast(img2(i, j)); + + sigma1_sq(i, j) = static_cast(i1 * i1); + sigma2_sq(i, j) = static_cast(i2 * i2); + sigma12(i, j) = static_cast(i1 * i2); + } + } + } + + sigma1_sq = FilterValid(sigma1_sq, filterSz, filterSigma); + sigma2_sq = FilterValid(sigma2_sq, filterSz, filterSigma); + sigma12 = FilterValid(sigma12, filterSz, filterSigma); + + assert(sigma1_sq.GetWidth() == mu1.GetWidth()); + assert(sigma1_sq.GetHeight() == mu1.GetHeight()); + + assert(sigma2_sq.GetWidth() == mu1.GetWidth()); + assert(sigma2_sq.GetHeight() == mu1.GetHeight()); + + assert(sigma12.GetWidth() == mu1.GetWidth()); + assert(sigma12.GetHeight() == mu1.GetHeight()); + + w = mu1_sq.GetWidth(); + h = mu2_sq.GetHeight(); + + for(uint32 j = 0; j < h; j++) { + for(uint32 i = 0; i < w; i++) { + double m1sq = static_cast(mu1_sq(i, j)); + double m2sq = static_cast(mu2_sq(i, j)); + double m1m2 = static_cast(mu1_mu2(i, j)); + + double s1sq = static_cast(sigma1_sq(i, j)); + double s2sq = static_cast(sigma2_sq(i, j)); + double s1s2 = static_cast(sigma12(i, j)); + + sigma1_sq(i, j) = static_cast(s1sq - m1sq); + sigma2_sq(i, j) = static_cast(s2sq - m2sq); + sigma12(i, j) = static_cast(s1s2 - m1m2); + } + } + + double mssim = 0.0; + for(uint32 j = 0; j < h; j++) { + for(uint32 i = 0; i < w; i++) { + double m1sq = static_cast(mu1_sq(i, j)); + double m2sq = static_cast(mu2_sq(i, j)); + double m1m2 = static_cast(mu1_mu2(i, j)); + + double s1sq = static_cast(sigma1_sq(i, j)); + double s2sq = static_cast(sigma2_sq(i, j)); + double s1s2 = static_cast(sigma12(i, j)); + + mssim += + ((2.0 * m1m2 + C1) * (2.0 * s1s2 + C2)) / + ((m1sq + m2sq + C1) * (s1sq + s2sq + C2)); + } + } + + return mssim / static_cast(w * h); +} + // !FIXME! These won't work for non-RGBA8 data. template void Image::ConvertToBlockStreamOrder() { diff --git a/Base/test/TestImage.cpp b/Base/test/TestImage.cpp index 1e0c012..baf5ed1 100644 --- a/Base/test/TestImage.cpp +++ b/Base/test/TestImage.cpp @@ -149,3 +149,20 @@ TEST(Image, Filter) { } } } + +TEST(Image, ComputeMSSIM) { + + const uint32 w = 16; + const uint32 h = 16; + + FasTC::Image img(w, h); + for(uint32 j = 0; j < h; j++) { + for(uint32 i = 0; i < w; i++) { + img(i, j) = + (static_cast(i) * static_cast(j)) / + (static_cast(w) * static_cast(h)); + } + } + + EXPECT_EQ(img.ComputeMSSIM(&img), 1.0); +}