Skip to content

Restructure the video/video_reader C++ codebase #3311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions torchvision/csrc/io/video/register.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "Video.h"
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
#include "video.h"

using namespace std;
using namespace ffmpeg;
#include <regex>

namespace vision {
namespace video {

namespace {

const size_t decoderTimeoutMs = 600000;
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
Expand Down Expand Up @@ -93,6 +92,8 @@ std::tuple<std::string, long> _parseStream(const std::string& streamString) {
return std::make_tuple(type_, index_);
}

} // namespace

void Video::_getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
Expand Down Expand Up @@ -159,7 +160,7 @@ Video::Video(std::string videoPath, std::string stream) {
Video::_getDecoderParams(
0, // video start
0, // headerOnly
get<0>(current_stream), // stream info - remove that
std::get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2
true // read all streams
);
Expand Down Expand Up @@ -209,9 +210,9 @@ Video::Video(std::string videoPath, std::string stream) {

succeeded = Video::setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (get<1>(current_stream) != -1) {
if (std::get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << get<1>(current_stream)
<< "Stream index set to " << std::get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
} // video
Expand All @@ -229,8 +230,8 @@ bool Video::setCurrentStream(std::string stream = "video") {
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
std::get<0>(current_stream), // stream
long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
Expand All @@ -253,8 +254,8 @@ void Video::Seek(double ts) {
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
std::get<0>(current_stream), // stream
long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
Expand Down Expand Up @@ -319,3 +320,15 @@ std::tuple<torch::Tensor, double> Video::Next() {

return std::make_tuple(outFrame, frame_pts_s);
}

// Static-initialisation hook that registers Video as the custom class
// "Video" under the "torchvision" namespace. The value of registerVideo is
// not used anywhere in the visible code; the variable exists so that the
// registration expression runs when this translation unit is loaded.
// Exposed surface: a (std::string, std::string) constructor plus five methods
// bound under snake_case names.
static auto registerVideo =
torch::class_<Video>("torchvision", "Video")
.def(torch::init<std::string, std::string>())
.def("get_current_stream", &Video::getCurrentStream)
.def("set_current_stream", &Video::setCurrentStream)
.def("get_metadata", &Video::getStreamMetadata)
.def("seek", &Video::Seek)
.def("next", &Video::Next);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from register.cpp


} // namespace video
} // namespace vision
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
#pragma once

#include <map>
#include <regex>
#include <string>
#include <vector>
#include <torch/types.h>

#include <ATen/ATen.h>
#include <c10/util/Logging.h>
#include <torch/script.h>

#include <exception>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
#include "../decoder/defs.h"
#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

using namespace ffmpeg;

namespace vision {
namespace video {

struct Video : torch::CustomClassHolder {
std::tuple<std::string, long> current_stream; // stream type, id
// global video metadata
Expand Down Expand Up @@ -58,3 +53,6 @@ struct Video : torch::CustomClassHolder {
DecoderParameters params;

}; // struct Video

} // namespace video
} // namespace vision
3 changes: 0 additions & 3 deletions torchvision/csrc/io/video_reader/VideoReader.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
#include "VideoReader.h"
#include <ATen/ATen.h>
#include "video_reader.h"

#include <Python.h>
#include <c10/util/Logging.h>
#include <exception>
#include "memory_buffer.h"
#include "sync_decoder.h"

using namespace std;
using namespace ffmpeg;
#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
Expand All @@ -18,8 +14,13 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
}
#endif

using namespace ffmpeg;

namespace vision {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follows the same pattern as vision::ops.

namespace video_reader {

namespace {

const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
Expand Down Expand Up @@ -417,95 +418,6 @@ torch::List<torch::Tensor> readVideo(
return result;
}

// In-memory entry point: forwards every decoding option unchanged to
// readVideo(), passing `false` as the leading flag and an empty string in the
// videoPath slot so that the supplied tensor is the only input.
// NOTE(review): the leading flag is presumably isReadFile, as in probeVideo;
// confirm against readVideo's signature, which is not visible here.
torch::List<torch::Tensor> readVideoFromMemory(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving out of the anonymous namespace.

torch::Tensor input_video,
double seekFrameMargin,
int64_t getPtsOnly,
int64_t readVideoStream,
int64_t width,
int64_t height,
int64_t minDimension,
int64_t maxDimension,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioSamples,
int64_t audioChannels,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
return readVideo(
false,
input_video,
"", // videoPath
seekFrameMargin,
getPtsOnly,
readVideoStream,
width,
height,
minDimension,
maxDimension,
videoStartPts,
videoEndPts,
videoTimeBaseNum,
videoTimeBaseDen,
readAudioStream,
audioSamples,
audioChannels,
audioStartPts,
audioEndPts,
audioTimeBaseNum,
audioTimeBaseDen);
}

// File-path entry point: hands every decoding option to readVideo() as-is,
// passing `true` as the leading flag and a zero-element placeholder tensor in
// the in-memory slot.
// NOTE(review): the leading flag is presumably isReadFile, as in probeVideo;
// confirm against readVideo's signature.
torch::List<torch::Tensor> readVideoFromFile(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // Shape {0}: only fills the tensor argument that readVideo() requires.
  torch::Tensor placeholderVideo = torch::ones({0});

  return readVideo(
      true, placeholderVideo, videoPath,
      // general options
      seekFrameMargin, getPtsOnly,
      // video stream options
      readVideoStream, width, height, minDimension, maxDimension,
      videoStartPts, videoEndPts, videoTimeBaseNum, videoTimeBaseDen,
      // audio stream options
      readAudioStream, audioSamples, audioChannels,
      audioStartPts, audioEndPts, audioTimeBaseNum, audioTimeBaseDen);
}

torch::List<torch::Tensor> probeVideo(
bool isReadFile,
const torch::Tensor& input_video,
Expand Down Expand Up @@ -650,20 +562,112 @@ torch::List<torch::Tensor> probeVideo(
return result;
}

torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
} // namespace

// Operator entry point for decoding a video held in a tensor. It is a pure
// forwarder: all options go to readVideo() unchanged, with `false` as the
// leading flag and an empty string in the videoPath slot.
// NOTE(review): the leading flag is presumably isReadFile, as in probeVideo;
// confirm against readVideo's signature, which is not visible here.
torch::List<torch::Tensor> read_video_from_memory(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"public" methods.

torch::Tensor input_video,
double seekFrameMargin,
int64_t getPtsOnly,
int64_t readVideoStream,
int64_t width,
int64_t height,
int64_t minDimension,
int64_t maxDimension,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioSamples,
int64_t audioChannels,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
return readVideo(
false,
input_video,
"", // videoPath
seekFrameMargin,
getPtsOnly,
readVideoStream,
width,
height,
minDimension,
maxDimension,
videoStartPts,
videoEndPts,
videoTimeBaseNum,
videoTimeBaseDen,
readAudioStream,
audioSamples,
audioChannels,
audioStartPts,
audioEndPts,
audioTimeBaseNum,
audioTimeBaseDen);
}

// Operator entry point for decoding a video identified by a file path.
// Delegates to readVideo() with `true` as the leading flag; the tensor slot is
// filled with a zero-element tensor because no in-memory buffer is supplied.
// NOTE(review): the leading flag is presumably isReadFile, as in probeVideo;
// confirm against readVideo's signature.
torch::List<torch::Tensor> read_video_from_file(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  torch::Tensor emptyBuffer = torch::ones({0}); // stand-in for the unused tensor input

  return readVideo(
      true,
      emptyBuffer,
      videoPath,
      seekFrameMargin,
      getPtsOnly,
      readVideoStream, // --- video stream ---
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream, // --- audio stream ---
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}

// Operator entry point for probing a video held in a tensor: calls
// probeVideo() with isReadFile = false and an empty string as the third
// argument.
torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
  const bool isReadFile = false; // the input is the tensor, not a file
  return probeVideo(
      isReadFile,
      input_video,
      ""); // no path is supplied in the in-memory case
}

// Probes a video identified by a file path: builds a zero-element placeholder
// tensor for probeVideo()'s tensor argument and calls it with
// isReadFile = true and the given path.
torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
torch::Tensor dummy_input_video = torch::ones({0}); // shape {0}; only satisfies the tensor parameter
return probeVideo(true, dummy_input_video, videoPath);
}

} // namespace video_reader

// Adds the read/probe entry points to the `video_reader` operator library.
// Each m.def call pairs an operator name (the string) with the C++ function
// that implements it.
TORCH_LIBRARY_FRAGMENT(video_reader, m) {
m.def("read_video_from_memory", video_reader::readVideoFromMemory);
m.def("read_video_from_file", video_reader::readVideoFromFile);
m.def("probe_video_from_memory", video_reader::probeVideoFromMemory);
m.def("probe_video_from_file", video_reader::probeVideoFromFile);
m.def("read_video_from_memory", read_video_from_memory);
m.def("read_video_from_file", read_video_from_file);
m.def("probe_video_from_memory", probe_video_from_memory);
m.def("probe_video_from_file", probe_video_from_file);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration of methods within the namespaces at the end of the file.

}

} // namespace video_reader
} // namespace vision
Loading