Skip to content

Commit bd546fa

Browse files
authored
Merge pull request #3468 from cudawarped:cudacodec_fix_colour_conversion
`cudacodec::VideoReader`: fix nv12 to bgr/bgra/grey conversion
2 parents 853144e + d996792 commit bd546fa

File tree

6 files changed

+105
-44
lines changed

6 files changed

+105
-44
lines changed

modules/cudacodec/include/opencv2/cudacodec.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ enum DeinterlaceMode
310310
struct CV_EXPORTS_W_SIMPLE FormatInfo
311311
{
312312
CV_WRAP FormatInfo() : nBitDepthMinus8(-1), ulWidth(0), ulHeight(0), width(0), height(0), ulMaxWidth(0), ulMaxHeight(0), valid(false),
313-
fps(0), ulNumDecodeSurfaces(0) {};
313+
fps(0), ulNumDecodeSurfaces(0), videoFullRangeFlag(false) {};
314314

315315
CV_PROP_RW Codec codec;
316316
CV_PROP_RW ChromaFormat chromaFormat;
@@ -329,6 +329,7 @@ struct CV_EXPORTS_W_SIMPLE FormatInfo
329329
CV_PROP_RW cv::Size targetSz;//!< Post-processed size of the output frame.
330330
CV_PROP_RW cv::Rect srcRoi;//!< Region of interest decoded from video source.
331331
CV_PROP_RW cv::Rect targetRoi;//!< Region of interest in the output frame containing the decoded frame.
332+
CV_PROP_RW bool videoFullRangeFlag;//!< Output value indicating whether the black level, luma and chroma of the source are represented using the full range or the limited range (AKA TV or "analogue" range) of values, as defined in Annex E of the ITU-T Specification. Internally the conversion from NV12 to BGR follows ITU-R BT.709.
332333
};
333334

334335
/** @brief cv::cudacodec::VideoReader generic properties identifier.

modules/cudacodec/src/cuda/nv12_to_rgb.cu

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,20 @@ namespace
6666
{
6767
__constant__ float constHueColorSpaceMat[9] = {1.1644f, 0.0f, 1.596f, 1.1644f, -0.3918f, -0.813f, 1.1644f, 2.0172f, 0.0f};
6868

69+
template<bool fullRange>
6970
__device__ static void YUV2RGB(const uint* yuvi, float* red, float* green, float* blue)
7071
{
7172
float luma, chromaCb, chromaCr;
72-
73-
// Prepare for hue adjustment
74-
luma = (float)yuvi[0];
75-
chromaCb = (float)((int)yuvi[1] - 512.0f);
76-
chromaCr = (float)((int)yuvi[2] - 512.0f);
73+
if (fullRange) {
74+
luma = (float)(((int)yuvi[0] * 219.0f / 255.0f));
75+
chromaCb = (float)(((int)yuvi[1] - 512.0f) * 224.0f / 255.0f);
76+
chromaCr = (float)(((int)yuvi[2] - 512.0f) * 224.0f / 255.0f);
77+
}
78+
else {
79+
luma = (float)((int)yuvi[0] - 64.0f);
80+
chromaCb = (float)((int)yuvi[1] - 512.0f);
81+
chromaCr = (float)((int)yuvi[2] - 512.0f);
82+
}
7783

7884
// Convert YUV To RGB with hue adjustment
7985
*red = (luma * constHueColorSpaceMat[0]) +
@@ -112,6 +118,7 @@ namespace
112118
#define COLOR_COMPONENT_BIT_SIZE 10
113119
#define COLOR_COMPONENT_MASK 0x3FF
114120

121+
template<bool fullRange>
115122
__global__ void NV12_to_BGRA(const uchar* srcImage, size_t nSourcePitch,
116123
uint* dstImage, size_t nDestPitch,
117124
uint width, uint height)
@@ -135,31 +142,11 @@ namespace
135142

136143
const int y_chroma = y >> 1;
137144

138-
if (y & 1) // odd scanline ?
139-
{
140-
uint chromaCb = srcImage[chromaOffset + y_chroma * nSourcePitch + x ];
141-
uint chromaCr = srcImage[chromaOffset + y_chroma * nSourcePitch + x + 1];
142-
143-
if (y_chroma < ((height >> 1) - 1)) // interpolate chroma vertically
144-
{
145-
chromaCb = (chromaCb + srcImage[chromaOffset + (y_chroma + 1) * nSourcePitch + x ] + 1) >> 1;
146-
chromaCr = (chromaCr + srcImage[chromaOffset + (y_chroma + 1) * nSourcePitch + x + 1] + 1) >> 1;
147-
}
145+
yuv101010Pel[0] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x ] << ( COLOR_COMPONENT_BIT_SIZE + 2));
146+
yuv101010Pel[0] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x + 1] << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
148147

149-
yuv101010Pel[0] |= (chromaCb << ( COLOR_COMPONENT_BIT_SIZE + 2));
150-
yuv101010Pel[0] |= (chromaCr << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
151-
152-
yuv101010Pel[1] |= (chromaCb << ( COLOR_COMPONENT_BIT_SIZE + 2));
153-
yuv101010Pel[1] |= (chromaCr << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
154-
}
155-
else
156-
{
157-
yuv101010Pel[0] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x ] << ( COLOR_COMPONENT_BIT_SIZE + 2));
158-
yuv101010Pel[0] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x + 1] << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
159-
160-
yuv101010Pel[1] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x ] << ( COLOR_COMPONENT_BIT_SIZE + 2));
161-
yuv101010Pel[1] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x + 1] << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
162-
}
148+
yuv101010Pel[1] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x ] << ( COLOR_COMPONENT_BIT_SIZE + 2));
149+
yuv101010Pel[1] |= ((uint)srcImage[chromaOffset + y_chroma * nSourcePitch + x + 1] << ((COLOR_COMPONENT_BIT_SIZE << 1) + 2));
163150

164151
// this step performs the color conversion
165152
uint yuvi[6];
@@ -174,8 +161,8 @@ namespace
174161
yuvi[5] = ((yuv101010Pel[1] >> (COLOR_COMPONENT_BIT_SIZE << 1)) & COLOR_COMPONENT_MASK);
175162

176163
// YUV to RGB conversion
177-
YUV2RGB(&yuvi[0], &red[0], &green[0], &blue[0]);
178-
YUV2RGB(&yuvi[3], &red[1], &green[1], &blue[1]);
164+
YUV2RGB<fullRange>(&yuvi[0], &red[0], &green[0], &blue[0]);
165+
YUV2RGB<fullRange>(&yuvi[3], &red[1], &green[1], &blue[1]);
179166

180167
// Clamp the results to RGBA
181168

@@ -186,13 +173,15 @@ namespace
186173
}
187174
}
188175

189-
void nv12ToBgra(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, cudaStream_t stream)
176+
void nv12ToBgra(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, const bool videoFullRangeFlag, cudaStream_t stream)
190177
{
191178
outFrame.create(height, width, CV_8UC4);
192179
dim3 block(32, 8);
193180
dim3 grid(divUp(width, 2 * block.x), divUp(height, block.y));
194-
NV12_to_BGRA<< <grid, block, 0, stream >> > (decodedFrame.ptr<uchar>(), decodedFrame.step,
195-
outFrame.ptr<uint>(), outFrame.step, width, height);
181+
if (videoFullRangeFlag)
182+
NV12_to_BGRA<true> << <grid, block, 0, stream >> > (decodedFrame.ptr<uchar>(), decodedFrame.step, outFrame.ptr<uint>(), outFrame.step, width, height);
183+
else
184+
NV12_to_BGRA<false> << <grid, block, 0, stream >> > (decodedFrame.ptr<uchar>(), decodedFrame.step, outFrame.ptr<uint>(), outFrame.step, width, height);
196185
CV_CUDEV_SAFE_CALL(cudaGetLastError());
197186
if (stream == 0)
198187
CV_CUDEV_SAFE_CALL(cudaDeviceSynchronize());

modules/cudacodec/src/precomp.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
#include "frame_queue.hpp"
8383
#include "video_decoder.hpp"
8484
#include "video_parser.hpp"
85+
#include <opencv2/cudaarithm.hpp>
8586
#endif
8687
#if defined(HAVE_NVCUVENC)
8788
#include <fstream>

modules/cudacodec/src/video_parser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ int CUDAAPI cv::cudacodec::detail::VideoParser::HandleVideoSequence(void* userDa
115115
format->min_num_decode_surfaces != thiz->videoDecoder_->nDecodeSurfaces())
116116
{
117117
FormatInfo newFormat;
118+
newFormat.videoFullRangeFlag = format->video_signal_description.video_full_range_flag;
118119
newFormat.codec = static_cast<Codec>(format->codec);
119120
newFormat.chromaFormat = static_cast<ChromaFormat>(format->chroma_format);
120121
newFormat.nBitDepthMinus8 = format->bit_depth_luma_minus8;

modules/cudacodec/src/video_reader.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,54 @@ Ptr<VideoReader> cv::cudacodec::createVideoReader(const Ptr<RawVideoSource>&, co
5353

5454
#else // HAVE_NVCUVID
5555

56-
void nv12ToBgra(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, cudaStream_t stream);
56+
void nv12ToBgra(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, const bool videoFullRangeFlag, cudaStream_t stream);
5757
bool ValidColorFormat(const ColorFormat colorFormat);
5858

59-
void videoDecPostProcessFrame(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, const ColorFormat colorFormat,
59+
void cvtFromNv12(const GpuMat& decodedFrame, GpuMat& outFrame, int width, int height, const ColorFormat colorFormat, const bool videoFullRangeFlag,
6060
Stream stream)
6161
{
62+
CV_Assert(decodedFrame.cols == width && decodedFrame.rows == height * 1.5f);
6263
if (colorFormat == ColorFormat::BGRA) {
63-
nv12ToBgra(decodedFrame, outFrame, width, height, StreamAccessor::getStream(stream));
64+
nv12ToBgra(decodedFrame, outFrame, width, height, videoFullRangeFlag, StreamAccessor::getStream(stream));
6465
}
6566
else if (colorFormat == ColorFormat::BGR) {
6667
outFrame.create(height, width, CV_8UC3);
6768
Npp8u* pSrc[2] = { decodedFrame.data, &decodedFrame.data[decodedFrame.step * height] };
6869
NppiSize oSizeROI = { width,height };
70+
#if (CUDART_VERSION < 9200)
71+
CV_Error(Error::StsUnsupportedFormat, "ColorFormat::BGR is not supported until CUDA 9.2, use default ColorFormat::BGRA.");
72+
#elif (CUDART_VERSION < 10100)
73+
cv::cuda::NppStreamHandler h(stream);
74+
if (videoFullRangeFlag)
75+
nppSafeCall(nppiNV12ToBGR_709HDTV_8u_P2C3R(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI));
76+
else {
77+
CV_LOG_DEBUG(NULL, "Color reproduction may be inaccurate due CUDA version <= 11.0, for better results upgrade CUDA runtime or try ColorFormat::BGRA.");
78+
nppSafeCall(nppiNV12ToBGR_8u_P2C3R(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI));
79+
}
80+
#elif (CUDART_VERSION >= 10100)
6981
NppStreamContext nppStreamCtx;
7082
nppSafeCall(nppGetStreamContext(&nppStreamCtx));
7183
nppStreamCtx.hStream = StreamAccessor::getStream(stream);
72-
nppSafeCall(nppiNV12ToBGR_8u_P2C3R_Ctx(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI, nppStreamCtx));
84+
if (videoFullRangeFlag)
85+
nppSafeCall(nppiNV12ToBGR_709HDTV_8u_P2C3R_Ctx(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI, nppStreamCtx));
86+
else {
87+
#if (CUDART_VERSION < 11000)
88+
CV_LOG_DEBUG(NULL, "Color reproduction may be inaccurate due CUDA version <= 11.0, for better results upgrade CUDA runtime or try ColorFormat::BGRA.");
89+
nppSafeCall(nppiNV12ToBGR_8u_P2C3R_Ctx(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI, nppStreamCtx));
90+
#else
91+
nppSafeCall(nppiNV12ToBGR_709CSC_8u_P2C3R_Ctx(pSrc, decodedFrame.step, outFrame.data, outFrame.step, oSizeROI, nppStreamCtx));
92+
#endif
93+
}
94+
#endif
7395
}
7496
else if (colorFormat == ColorFormat::GRAY) {
7597
outFrame.create(height, width, CV_8UC1);
76-
cudaMemcpy2DAsync(outFrame.ptr(), outFrame.step, decodedFrame.ptr(), decodedFrame.step, width, height, cudaMemcpyDeviceToDevice, StreamAccessor::getStream(stream));
98+
if(videoFullRangeFlag)
99+
cudaSafeCall(cudaMemcpy2DAsync(outFrame.ptr(), outFrame.step, decodedFrame.ptr(), decodedFrame.step, width, height, cudaMemcpyDeviceToDevice, StreamAccessor::getStream(stream)));
100+
else {
101+
cv::cuda::subtract(decodedFrame(Rect(0,0,width,height)), 16, outFrame, noArray(), CV_8U, stream);
102+
cv::cuda::multiply(outFrame, 255.0f / 219.0f, outFrame, 1.0, CV_8U, stream);
103+
}
77104
}
78105
else if (colorFormat == ColorFormat::NV_NV12) {
79106
decodedFrame.copyTo(outFrame, stream);
@@ -222,9 +249,7 @@ namespace
222249
// map decoded video frame to CUDA surface
223250
GpuMat decodedFrame = videoDecoder_->mapFrame(frameInfo.first.picture_index, frameInfo.second);
224251

225-
// perform post processing on the CUDA surface (performs colors space conversion and post processing)
226-
// comment this out if we include the line of code seen above
227-
videoDecPostProcessFrame(decodedFrame, frame, videoDecoder_->targetWidth(), videoDecoder_->targetHeight(), colorFormat, stream);
252+
cvtFromNv12(decodedFrame, frame, videoDecoder_->targetWidth(), videoDecoder_->targetHeight(), colorFormat, videoDecoder_->format().videoFullRangeFlag, stream);
228253

229254
// unmap video frame
230255
// unmapFrame() synchronizes with the VideoDecode API (ensures the frame has finished decoding)

modules/cudacodec/test/test_video.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ PARAM_TEST_CASE(Video, cv::cuda::DeviceInfo, std::string)
6262
{
6363
};
6464

65+
typedef tuple<std::string, bool> color_conversion_params_t;
66+
PARAM_TEST_CASE(ColorConversion, cv::cuda::DeviceInfo, cv::cudacodec::ColorFormat, color_conversion_params_t)
67+
{
68+
};
69+
6570
PARAM_TEST_CASE(VideoReadRaw, cv::cuda::DeviceInfo, std::string)
6671
{
6772
};
@@ -220,7 +225,7 @@ CUDA_TEST_P(Scaling, Reader)
220225
cv::cuda::resize(frameOr(srcRoiOut), frameGs, targetRoiOut.size(), 0, 0, INTER_AREA);
221226
// assert on mean absolute error due to different resize algorithms
222227
const double mae = cv::cuda::norm(frameGs, frame(targetRoiOut), NORM_L1)/frameGs.size().area();
223-
ASSERT_LT(mae, 2.35);
228+
ASSERT_LT(mae, 2.75);
224229
}
225230

226231
CUDA_TEST_P(Video, Reader)
@@ -265,6 +270,33 @@ CUDA_TEST_P(Video, Reader)
265270
}
266271
}
267272

273+
CUDA_TEST_P(ColorConversion, Reader)
274+
{
275+
cv::cuda::setDevice(GET_PARAM(0).deviceID());
276+
const cv::cudacodec::ColorFormat colorFormat = GET_PARAM(1);
277+
const std::string inputFile = std::string(cvtest::TS::ptr()->get_data_path()) + "../" + get<0>(GET_PARAM(2));
278+
const bool videoFullRangeFlag = get<1>(GET_PARAM(2));
279+
cv::Ptr<cv::cudacodec::VideoReader> reader = cv::cudacodec::createVideoReader(inputFile);
280+
reader->set(colorFormat);
281+
cv::VideoCapture cap(inputFile);
282+
283+
cv::cuda::GpuMat frame;
284+
Mat frameHost, frameHostGs, frameFromDevice;
285+
for (int i = 0; i < 10; i++)
286+
{
287+
reader->nextFrame(frame);
288+
frame.download(frameFromDevice);
289+
cap.read(frameHost);
290+
const cv::cudacodec::FormatInfo fmt = reader->format();
291+
ASSERT_TRUE(fmt.valid && fmt.videoFullRangeFlag == videoFullRangeFlag);
292+
if (colorFormat == cv::cudacodec::ColorFormat::BGRA)
293+
cv::cvtColor(frameHost, frameHostGs, COLOR_BGR2BGRA);
294+
else
295+
frameHostGs = frameHost;
296+
EXPECT_MAT_NEAR(frameHostGs, frameFromDevice, 2.0);
297+
}
298+
}
299+
268300
CUDA_TEST_P(VideoReadRaw, Reader)
269301
{
270302
cv::cuda::setDevice(GET_PARAM(0).deviceID());
@@ -672,6 +704,18 @@ INSTANTIATE_TEST_CASE_P(CUDA_Codec, Video, testing::Combine(
672704
ALL_DEVICES,
673705
testing::Values(VIDEO_SRC_R)));
674706

707+
const color_conversion_params_t color_conversion_params[] =
708+
{
709+
color_conversion_params_t("highgui/video/big_buck_bunny.h264", false),
710+
color_conversion_params_t("highgui/video/big_buck_bunny_full_color_range.h264", true),
711+
};
712+
713+
#define VIDEO_COLOR_OUTPUTS cv::cudacodec::ColorFormat::BGRA, cv::cudacodec::ColorFormat::BGRA
714+
INSTANTIATE_TEST_CASE_P(CUDA_Codec, ColorConversion, testing::Combine(
715+
ALL_DEVICES,
716+
testing::Values(VIDEO_COLOR_OUTPUTS),
717+
testing::ValuesIn(color_conversion_params)));
718+
675719
#define VIDEO_SRC_RW "highgui/video/big_buck_bunny.h264", "highgui/video/big_buck_bunny.h265"
676720
INSTANTIATE_TEST_CASE_P(CUDA_Codec, VideoReadRaw, testing::Combine(
677721
ALL_DEVICES,

0 commit comments

Comments
 (0)