From f566a78dd87b82b6b2114e7437ae214dc307bcd3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 3 Dec 2021 18:09:27 +0000 Subject: [PATCH 1/8] Add training reference for optical flow models --- references/optical_flow/train.py | 331 ++++++++++++++++++++++++++ references/optical_flow/utils.py | 390 +++++++++++++++++++++++++++++++ 2 files changed, 721 insertions(+) create mode 100644 references/optical_flow/train.py create mode 100644 references/optical_flow/utils.py diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py new file mode 100644 index 00000000000..3bd2fa327b0 --- /dev/null +++ b/references/optical_flow/train.py @@ -0,0 +1,331 @@ +import argparse +import warnings +from pathlib import Path + +import torch +import utils +from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval +from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K +from torchvision.models.optical_flow import raft_large, raft_small + + +def get_train_dataset(stage, dataset_root): + if stage == "chairs": + transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True) + return FlyingChairs(root=dataset_root, split="train", transforms=transforms) + elif stage == "things": + transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True) + return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms) + elif stage == "sintel_SKH": # S + K + H as from paper + crop_size = (368, 768) + transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True) + + things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms) + sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms) + + kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True) + kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms) + + hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True) + hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms) + + # As future improvement, we could probably be using a distributed sampler here + # The distribution is S(.71), T(.135), K(.135), H(.02) + return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean + elif stage == "kitti": + transforms = OpticalFlowPresetTrain( + # resize and crop params + crop_size=(288, 960), + min_scale=-0.2, + max_scale=0.4, + stretch_prob=0, + # flip params + do_flip=False, + # jitter params + brightness=0.3, + contrast=0.3, + saturation=0.3, + hue=0.3 / 3.14, + asymmetric_jitter_prob=0, + ) + return KittiFlow(root=dataset_root, split="train", transforms=transforms) + else: + raise ValueError(f"Unknown stage {stage}") + + +@torch.no_grad() +def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None): + """Helper function to compute various metrics (epe, etc.) for a model on a given dataset. + + We process as many samples as possible with ddp, and process the rest on a single worker. + """ + model.eval() + + sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) + val_loader = torch.utils.data.DataLoader( + val_dataset, + sampler=sampler, + batch_size=batch_size or args.batch_size, + pin_memory=True, + num_workers=args.num_workers, + ) + + num_flow_updates = num_flow_updates or args.num_flow_updates + + def inner_loop(blob): + if blob[0].dim() == 3: + # input is not batched so we add an extra dim for consistency + blob = [x[None, :, :, :] if x is not None else None for x in blob] + + image1, image2, flow_gt = blob[:3] + valid_flow_mask = None if len(blob) == 3 else blob[-1] + + image1, image2 = image1.cuda(), image2.cuda() + + padder = utils.InputPadder(image1.shape, mode=padder_mode) + image1, image2 = padder.pad(image1, image2) + + flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates) + flow_pred = flow_predictions[-1] + + flow_pred = padder.unpad(flow_pred).cpu() + + epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt() + + logger.meters["epe"].update(epe.mean().item(), n=epe.numel()) + for distance in (1, 3, 5): + logger.meters[f"{distance}px"].update((epe < distance).float().mean().item(), n=epe.numel()) + + relative_epe = epe / (flow_gt ** 2).sum(dim=1).sqrt() + if valid_flow_mask is not None: + epe, relative_epe = epe[valid_flow_mask], relative_epe[valid_flow_mask] + bad_predictions = ((epe > 3) & (relative_epe > 0.05)).float() + + # note the n=1 for per_image_epe: we compute an average over averages. We first average within each image and + # then average over the images. This is in contrast with the other epe computation, where we + # average only once over all the pixels of all images. + logger.meters["per_image_epe"].update(epe.mean().item(), n=1) # f1-epe in paper + logger.meters["f1"].update(bad_predictions.mean().item(), n=bad_predictions.numel()) # f1-all in paper + + logger = utils.MetricLogger() + for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"): + logger.add_meter(meter_name, fmt="{global_avg:.4f}") + + num_processed_samples = 0 + for blob in logger.log_every(val_loader, header=header, print_freq=None): + inner_loop(blob) + num_processed_samples += blob[0].shape[0] # batch size + + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + print( + f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. " + "Going to process the remaining samples individually, if any." + ) + + if args.rank == 0: # we only need to process the rest on a single worker + for i in range(num_processed_samples, len(val_dataset)): + inner_loop(val_dataset[i]) + + logger.synchronize_between_processes() + print(header, logger) + + +def validate(model, args): + val_datasets = args.val_dataset or [] + for name in val_datasets: + if name == "kitti": + # Kitti has different image sizes so we need to individually pad them, we can't batch. + # see comment in InputPadder + if args.batch_size != 1 and args.rank == 0: + warnings.warn( + f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1." + ) + + val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval()) + _validate( + model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1 + ) + elif name == "sintel": + for pass_name in ("clean", "final"): + val_dataset = Sintel( + root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval() + ) + _validate( + model, + args, + val_dataset, + num_flow_updates=32, + padder_mode="sintel", + header=f"Sintel val {pass_name}", + ) + else: + warnings.warn(f"Can't validate on {val_dataset}, skipping.") + + +def main(args): + utils.setup_ddp(args) + + model = raft_small() if args.small else raft_large() + model = model.to(args.local_rank) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) + + if args.resume is not None: + d = torch.load(args.resume, map_location="cpu") + if args.map_orig_to_ours: + d = utils.map_orig_to_ours(d) + model.load_state_dict(d, strict=True) + + if args.train_dataset is None: + # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + validate(model, args) + return + + print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + + torch.backends.cudnn.benchmark = True + + model.train() + if args.freeze_batch_norm: + utils.freeze_batch_norm(model.module) + + train_dataset = get_train_dataset(args.train_dataset, args.dataset_root) + + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True) + train_loader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.batch_size, + pin_memory=True, + num_workers=args.num_workers, + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps) + + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=args.lr, + total_steps=args.num_steps + 100, + pct_start=0.05, + cycle_momentum=False, + anneal_strategy="linear", + ) + + logger = utils.MetricLogger() + + done = False + current_epoch = current_step = 0 + while not done: + sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs + print(f"EPOCH {current_epoch}") + + for data_blob in logger.log_every(train_loader): + + optimizer.zero_grad() + + image1, image2, flow, valid_flow_mask = (x.cuda() for x in data_blob) + flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates) + + loss, metrics = utils.sequence_loss(flow_predictions, flow, valid_flow_mask, args.gamma) + loss.backward() + + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) + + optimizer.step() + scheduler.step() + + logger.update(**metrics) + + current_step += 1 + + if current_step == args.num_steps: + done = True + break + + # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0 + print(f"Epoch {current_epoch} done. ", logger) + + current_epoch += 1 + + if args.rank == 0: + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth") + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth") + + if current_epoch % args.val_freq == 0 or done: + validate(model, args) + model.train() + if args.freeze_batch_norm: + utils.freeze_batch_norm(model.module) + + +def get_args_parser(add_help=True): + parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.") + parser.add_argument( + "--name", + default="raft", + type=str, + help="The name of the experiment - determines the name of the files where weights are saved.", + ) + parser.add_argument( + "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored." + ) + parser.add_argument( + "--resume", + type=str, + help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.", + ) + + parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.") + + parser.add_argument( + "--train-dataset", + type=str, + help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).", + ) + parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.") + parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs") + # TODO: eventually, it might be preferable to support epochs instead of num_steps. + # Keeping it this way for now to reproduce results more easily. + parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.") + parser.add_argument("--batch-size", type=int, default=6) + + parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer") + parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer") + parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer") + + parser.add_argument( + "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode." + ) + + parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.") + + parser.add_argument( + "--num_flow_updates", + type=int, + default=12, + help="number of updates (or 'iters') in the update operator of the model.", + ) + + parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.") + + parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training") + + # TODO: remove + parser.add_argument("--map-orig-to-ours", action="store_true") + + # TODO: remove the default + _DATASET_ROOT = "/data/home/nicolashug/cluster/work/downloads" + parser.add_argument( + "--dataset-root", + default=_DATASET_ROOT, + help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.", + ) + + return parser + + +if __name__ == "__main__": + args = get_args_parser().parse_args() + Path(args.output_dir).mkdir(exist_ok=True) + main(args) diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py new file mode 100644 index 00000000000..708fe68ce3f --- /dev/null +++ b/references/optical_flow/utils.py @@ -0,0 +1,390 @@ +import datetime +import os +import time +from collections import defaultdict +from collections import deque + +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"): + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, **kwargs): + self.meters[name] = SmoothedValue(**kwargs) + + def log_every(self, iterable, print_freq=5, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if print_freq is not None and i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400): + """Loss function defined over sequence of flow predictions""" + + if gamma > 1: + raise ValueError(f"Gamma should be < 1, got {gamma}.") + + # exlude invalid pixels and extremely large diplacements + norm_2 = torch.sum(flow_gt ** 2, dim=1).sqrt() + valid_flow_mask = valid_flow_mask & (norm_2 < max_flow) + + flow_loss = 0 + num_predictions = len(flow_preds) + for i, flow_pred in enumerate(flow_preds): + weight = gamma ** (num_predictions - i - 1) + abs_diff = (flow_pred - flow_gt).abs() + flow_loss += weight * (abs_diff * valid_flow_mask[:, None, :, :]).mean() + + last_pred = flow_preds[-1] + epe = ((last_pred - flow_gt) ** 2).sum(dim=1).sqrt() + epe = epe[valid_flow_mask] + + metrics = { + "flow_loss": flow_loss, + "epe": epe.mean().item(), + "1px": (epe < 1).float().mean().item(), + "3px": (epe < 3).float().mean().item(), + "5px": (epe < 5).float().mean().item(), + } + + return flow_loss, metrics + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + # TODO: Ideally, this should be part of the eval transforms preset, instead + # of being part of the validation code. It's not obvious what a good + # solution would be, because we need to unpad the predicted flows according + # to the input images' size, and in some datasets (Kitti) images can have + # variable sizes. + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def _redefine_print(is_main): + """disables printing when not in main process""" + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def setup_ddp(args): + # Set the local_rank, rank, and world_size values as args fields + # This is done differently depending on how we're running the script. We + # currently support either torchrun or the custom run_with_submitit.py + # If you're confused (like I was), this might help a bit + # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2 + + if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")): + # if we're here, the script was called with torchrun. Otherwise + # these args will be set already by the run_with_submitit script + args.local_rank = int(os.environ["LOCAL_RANK"]) + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + + elif "gpu" in args: + # if we're here, the script was called by run_with_submitit.py + args.local_rank = args.gpu + else: + raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.") + + _redefine_print(is_main=(args.rank == 0)) + + torch.cuda.set_device(args.local_rank) + dist.init_process_group( + backend="nccl", + rank=args.rank, + world_size=args.world_size, + init_method=args.dist_url, + ) + + +def reduce_across_processes(val): + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t + + +def freeze_batch_norm(model): + for m in model.modules(): + if isinstance(m, torch.nn.BatchNorm2d): + m.eval() + + +def map_orig_to_ours(orig, mine=None): + # TODO: remove + d = {} + used_s_orig = set() + used_s_mine = set() + + def assert_and_add(s_orig, s_mine): + # print(s_orig, s_mine) + # print(orig[s_orig].shape, mine[s_mine].shape) + + assert s_orig not in used_s_orig + assert s_mine not in used_s_mine + + if mine is not None: + assert s_mine in mine + assert s_orig in orig + if mine is not None: + assert orig[s_orig].shape == mine[s_mine].shape + d["module." + s_mine] = orig[s_orig] + used_s_orig.add(s_orig) + used_s_mine.add(s_mine) + + for encoder_orig, encoder_mine in ( + ("fnet", "feature_encoder"), + ("cnet", "context_encoder"), + ): + for attr in ("bias", "weight"): + s_orig = f"module.{encoder_orig}.conv1.{attr}" + s_mine = f"{encoder_mine}.convnormrelu.0.{attr}" + assert_and_add(s_orig, s_mine) + + s_orig = f"module.{encoder_orig}.conv2.{attr}" + s_mine = f"{encoder_mine}.conv.{attr}" + assert_and_add(s_orig, s_mine) + + for layer in (1, 2, 3): + for block in (0, 1): + for conv in (1, 2): + s_orig = f"module.{encoder_orig}.layer{layer}.{block}.conv{conv}.{attr}" + s_mine = f"{encoder_mine}.layer{layer}.{block}.convnormrelu{conv}.0.{attr}" + assert_and_add(s_orig, s_mine) + + for layer in (2, 3): + s_orig = f"module.{encoder_orig}.layer{layer}.0.downsample.0.{attr}" + s_mine = f"{encoder_mine}.layer{layer}.0.downsample.0.{attr}" + assert_and_add(s_orig, s_mine) + + encoder_orig, encoder_mine = "cnet", "context_encoder" + for attr in ( + "bias", + "weight", + "running_mean", + "running_var", + "num_batches_tracked", + ): + s_orig = f"module.{encoder_orig}.norm1.{attr}" + s_mine = f"{encoder_mine}.convnormrelu.1.{attr}" + assert_and_add(s_orig, s_mine) + for layer in (1, 2, 3): + for block in (0, 1): + for norm in (1, 2): + s_orig = f"module.{encoder_orig}.layer{layer}.{block}.norm{norm}.{attr}" + s_mine = f"{encoder_mine}.layer{layer}.{block}.convnormrelu{norm}.1.{attr}" + assert_and_add(s_orig, s_mine) + for layer in (2, 3): + s_orig = f"module.{encoder_orig}.layer{layer}.0.downsample.1.{attr}" + s_mine = f"{encoder_mine}.layer{layer}.0.downsample.1.{attr}" + assert_and_add(s_orig, s_mine) + + corr_orig, corr_mine = ( + "module.update_block.encoder.", + "update_block.motion_encoder.", + ) + for attr in ("bias", "weight"): + for i in (1, 2): + s_orig = f"{corr_orig}convc{i}.{attr}" + s_mine = f"{corr_mine}convcorr{i}.0.{attr}" + assert_and_add(s_orig, s_mine) + s_orig = f"{corr_orig}convf{i}.{attr}" + s_mine = f"{corr_mine}convflow{i}.0.{attr}" + assert_and_add(s_orig, s_mine) + s_orig = f"{corr_orig}conv.{attr}" + s_mine = f"{corr_mine}conv.0.{attr}" + assert_and_add(s_orig, s_mine) + + rec_orig, rec_mine = "module.update_block.gru", "update_block.recurrent_block" + for attr in ("bias", "weight"): + for i in (1, 2): + for conv in ("convz", "convr", "convq"): + s_orig = f"{rec_orig}.{conv}{i}.{attr}" + s_mine = f"{rec_mine}.convgru{i}.{conv}.{attr}" + assert_and_add(s_orig, s_mine) + + flow_orig, flow_mine = "module.update_block.flow_head", "update_block.flow_head" + for attr in ("bias", "weight"): + for i in (1, 2): + s_orig = f"{flow_orig}.conv{i}.{attr}" + s_mine = f"{flow_mine}.conv{i}.{attr}" + assert_and_add(s_orig, s_mine) + for s_orig, s_mine in zip( + ( + "module.update_block.mask.0.weight", + "module.update_block.mask.0.bias", + "module.update_block.mask.2.weight", + "module.update_block.mask.2.bias", + ), + ( + "mask_predictor.convrelu.0.weight", + "mask_predictor.convrelu.0.bias", + "mask_predictor.conv.weight", + "mask_predictor.conv.bias", + ), + ): + assert_and_add(s_orig, s_mine) + + if mine is not None: + print(len(d), len(orig), len(mine)) + assert not (set(mine.keys()) - set(d.keys())) + return d From 4707883d7374df67d7a09c3237cbf5e1714e78b4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 4 Dec 2021 14:53:45 +0000 Subject: [PATCH 2/8] f1 computation: show percentage --- references/optical_flow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 3bd2fa327b0..4091dfad1c9 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -105,7 +105,7 @@ def inner_loop(blob): # then average over the images. This is in contrast with the other epe computation, where we # average only once over all the pixels of all images. logger.meters["per_image_epe"].update(epe.mean().item(), n=1) # f1-epe in paper - logger.meters["f1"].update(bad_predictions.mean().item(), n=bad_predictions.numel()) # f1-all in paper + logger.meters["f1"].update(bad_predictions.mean().item() * 100, n=bad_predictions.numel()) # f1-all in paper logger = utils.MetricLogger() for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"): From e49844f91e99baef79c4250d880b090dbe2f97e2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 17:42:17 +0000 Subject: [PATCH 3/8] Unify epe and metrics computations --- references/optical_flow/train.py | 38 ++++++++++++++------------------ references/optical_flow/utils.py | 35 ++++++++++++++++++----------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 4091dfad1c9..d0e86bab0ad 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -59,13 +59,15 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b We process as many samples as possible with ddp, and process the rest on a single worker. """ + batch_size = batch_size or args.batch_size + model.eval() sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) val_loader = torch.utils.data.DataLoader( val_dataset, sampler=sampler, - batch_size=batch_size or args.batch_size, + batch_size=batch_size, pin_memory=True, num_workers=args.num_workers, ) @@ -87,25 +89,16 @@ def inner_loop(blob): flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates) flow_pred = flow_predictions[-1] - flow_pred = padder.unpad(flow_pred).cpu() - epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt() - - logger.meters["epe"].update(epe.mean().item(), n=epe.numel()) - for distance in (1, 3, 5): - logger.meters[f"{distance}px"].update((epe < distance).float().mean().item(), n=epe.numel()) - - relative_epe = epe / (flow_gt ** 2).sum(dim=1).sqrt() - if valid_flow_mask is not None: - epe, relative_epe = epe[valid_flow_mask], relative_epe[valid_flow_mask] - bad_predictions = ((epe > 3) & (relative_epe > 0.05)).float() + metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask) - # note the n=1 for per_image_epe: we compute an average over averages. We first average within each image and - # then average over the images. This is in contrast with the other epe computation, where we - # average only once over all the pixels of all images. - logger.meters["per_image_epe"].update(epe.mean().item(), n=1) # f1-epe in paper - logger.meters["f1"].update(bad_predictions.mean().item() * 100, n=bad_predictions.numel()) # f1-all in paper + # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper). + # per-pixel epe: average epe of all pixels of all images + # per-image epe: average epe on each image independently, then average over images + for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper + logger.meters[name].update(metrics[name], n=num_pixels_tot) + logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size) logger = utils.MetricLogger() for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"): @@ -224,10 +217,15 @@ def main(args): optimizer.zero_grad() - image1, image2, flow, valid_flow_mask = (x.cuda() for x in data_blob) + image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob) flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates) - loss, metrics = utils.sequence_loss(flow_predictions, flow, valid_flow_mask, args.gamma) + loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma) + metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask) + + metrics.pop("f1") + logger.update(loss=loss, **metrics) + loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) @@ -235,8 +233,6 @@ def main(args): optimizer.step() scheduler.step() - logger.update(**metrics) - current_step += 1 if current_step == args.num_steps: diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py index 708fe68ce3f..34d3d494102 100644 --- a/references/optical_flow/utils.py +++ b/references/optical_flow/utils.py @@ -152,6 +152,27 @@ def log_every(self, iterable, print_freq=5, header=None): print(f"{header} Total time: {total_time_str}") +def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None): + + epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt() + flow_norm = (flow_gt ** 2).sum(dim=1).sqrt() + + if valid_flow_mask is not None: + epe = epe[valid_flow_mask] + flow_norm = flow_norm[valid_flow_mask] + + relative_epe = epe / flow_norm + + metrics = { + "epe": epe.mean().item(), + "1px": (epe < 1).float().mean().item(), + "3px": (epe < 3).float().mean().item(), + "5px": (epe < 5).float().mean().item(), + "f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100, + } + return metrics, epe.numel() + + def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400): """Loss function defined over sequence of flow predictions""" @@ -169,19 +190,7 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400) abs_diff = (flow_pred - flow_gt).abs() flow_loss += weight * (abs_diff * valid_flow_mask[:, None, :, :]).mean() - last_pred = flow_preds[-1] - epe = ((last_pred - flow_gt) ** 2).sum(dim=1).sqrt() - epe = epe[valid_flow_mask] - - metrics = { - "flow_loss": flow_loss, - "epe": epe.mean().item(), - "1px": (epe < 1).float().mean().item(), - "3px": (epe < 3).float().mean().item(), - "5px": (epe < 5).float().mean().item(), - } - - return flow_loss, metrics + return flow_loss class InputPadder: From 23967161c216bf724b1dbacc09deef9822b3b2bd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 18:17:48 +0000 Subject: [PATCH 4/8] avoid for loop in sequence_loss --- references/optical_flow/utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py index 34d3d494102..2428477ad84 100644 --- a/references/optical_flow/utils.py +++ b/references/optical_flow/utils.py @@ -180,15 +180,19 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400) raise ValueError(f"Gamma should be < 1, got {gamma}.") # exlude invalid pixels and extremely large diplacements - norm_2 = torch.sum(flow_gt ** 2, dim=1).sqrt() - valid_flow_mask = valid_flow_mask & (norm_2 < max_flow) - - flow_loss = 0 - num_predictions = len(flow_preds) - for i, flow_pred in enumerate(flow_preds): - weight = gamma ** (num_predictions - i - 1) - abs_diff = (flow_pred - flow_gt).abs() - flow_loss += weight * (abs_diff * valid_flow_mask[:, None, :, :]).mean() + flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt() + valid_flow_mask = valid_flow_mask & (flow_norm < max_flow) + + valid_flow_mask = valid_flow_mask[:, None, :, :] + + flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W) + + abs_diff = (flow_preds - flow_gt).abs() + abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4)) + + num_predictions = flow_preds.shape[0] + weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device) + flow_loss = (abs_diff * weights).sum() return flow_loss From 2783402f640ed5c19e44d8984c7f43204069dd5c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 18:19:07 +0000 Subject: [PATCH 5/8] remove old code --- references/optical_flow/train.py | 3 - references/optical_flow/utils.py | 121 ------------------------------- 2 files changed, 124 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index d0e86bab0ad..e293cade40d 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -310,11 +310,8 @@ def get_args_parser(add_help=True): # TODO: remove parser.add_argument("--map-orig-to-ours", action="store_true") - # TODO: remove the default - _DATASET_ROOT = "/data/home/nicolashug/cluster/work/downloads" parser.add_argument( "--dataset-root", - default=_DATASET_ROOT, help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.", ) diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py index 2428477ad84..e3643a91663 100644 --- a/references/optical_flow/utils.py +++ b/references/optical_flow/utils.py @@ -280,124 +280,3 @@ def freeze_batch_norm(model): for m in model.modules(): if isinstance(m, torch.nn.BatchNorm2d): m.eval() - - -def map_orig_to_ours(orig, mine=None): - # TODO: remove - d = {} - used_s_orig = set() - used_s_mine = set() - - def assert_and_add(s_orig, s_mine): - # print(s_orig, s_mine) - # print(orig[s_orig].shape, mine[s_mine].shape) - - assert s_orig not in used_s_orig - assert s_mine not in used_s_mine - - if mine is not None: - assert s_mine in mine - assert s_orig in orig - if mine is not None: - assert orig[s_orig].shape == mine[s_mine].shape - d["module." + s_mine] = orig[s_orig] - used_s_orig.add(s_orig) - used_s_mine.add(s_mine) - - for encoder_orig, encoder_mine in ( - ("fnet", "feature_encoder"), - ("cnet", "context_encoder"), - ): - for attr in ("bias", "weight"): - s_orig = f"module.{encoder_orig}.conv1.{attr}" - s_mine = f"{encoder_mine}.convnormrelu.0.{attr}" - assert_and_add(s_orig, s_mine) - - s_orig = f"module.{encoder_orig}.conv2.{attr}" - s_mine = f"{encoder_mine}.conv.{attr}" - assert_and_add(s_orig, s_mine) - - for layer in (1, 2, 3): - for block in (0, 1): - for conv in (1, 2): - s_orig = f"module.{encoder_orig}.layer{layer}.{block}.conv{conv}.{attr}" - s_mine = f"{encoder_mine}.layer{layer}.{block}.convnormrelu{conv}.0.{attr}" - assert_and_add(s_orig, s_mine) - - for layer in (2, 3): - s_orig = f"module.{encoder_orig}.layer{layer}.0.downsample.0.{attr}" - s_mine = f"{encoder_mine}.layer{layer}.0.downsample.0.{attr}" - assert_and_add(s_orig, s_mine) - - encoder_orig, encoder_mine = "cnet", "context_encoder" - for attr in ( - "bias", - "weight", - "running_mean", - "running_var", - "num_batches_tracked", - ): - s_orig = f"module.{encoder_orig}.norm1.{attr}" - s_mine = f"{encoder_mine}.convnormrelu.1.{attr}" - assert_and_add(s_orig, s_mine) - for layer in (1, 2, 3): - for block in (0, 1): - for norm in (1, 2): - s_orig = f"module.{encoder_orig}.layer{layer}.{block}.norm{norm}.{attr}" - s_mine = f"{encoder_mine}.layer{layer}.{block}.convnormrelu{norm}.1.{attr}" - assert_and_add(s_orig, s_mine) - for layer in (2, 3): - s_orig = f"module.{encoder_orig}.layer{layer}.0.downsample.1.{attr}" - s_mine = f"{encoder_mine}.layer{layer}.0.downsample.1.{attr}" - assert_and_add(s_orig, s_mine) - - corr_orig, corr_mine = ( - "module.update_block.encoder.", - "update_block.motion_encoder.", - ) - for attr in ("bias", "weight"): - for i in (1, 2): - s_orig = f"{corr_orig}convc{i}.{attr}" - s_mine = f"{corr_mine}convcorr{i}.0.{attr}" - assert_and_add(s_orig, s_mine) - s_orig = f"{corr_orig}convf{i}.{attr}" - s_mine = f"{corr_mine}convflow{i}.0.{attr}" - assert_and_add(s_orig, s_mine) - s_orig = f"{corr_orig}conv.{attr}" - s_mine = f"{corr_mine}conv.0.{attr}" - assert_and_add(s_orig, s_mine) - - rec_orig, rec_mine = "module.update_block.gru", "update_block.recurrent_block" - for attr in ("bias", "weight"): - for i in (1, 2): - for conv in ("convz", "convr", "convq"): - s_orig = f"{rec_orig}.{conv}{i}.{attr}" - s_mine = f"{rec_mine}.convgru{i}.{conv}.{attr}" - assert_and_add(s_orig, s_mine) - - flow_orig, flow_mine = "module.update_block.flow_head", "update_block.flow_head" - for attr in ("bias", "weight"): - for i in (1, 2): - s_orig = f"{flow_orig}.conv{i}.{attr}" - s_mine = f"{flow_mine}.conv{i}.{attr}" - assert_and_add(s_orig, s_mine) - for s_orig, s_mine in zip( - ( - "module.update_block.mask.0.weight", - "module.update_block.mask.0.bias", - "module.update_block.mask.2.weight", - "module.update_block.mask.2.bias", - ), - ( - "mask_predictor.convrelu.0.weight", - "mask_predictor.convrelu.0.bias", - "mask_predictor.conv.weight", - "mask_predictor.conv.bias", - ), - ): - assert_and_add(s_orig, s_mine) - - if mine is not None: - print(len(d), len(orig), len(mine)) - assert not (set(mine.keys()) - set(d.keys())) - return d From 1f7159f5102de73c5e40093f798ca48a7457bf1f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 18:35:31 +0000 Subject: [PATCH 6/8] create separate train_one_epoch function --- references/optical_flow/train.py | 66 +++++++++++++++++++------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index e293cade40d..1e274b3cdd6 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -155,6 +155,35 @@ def validate(model, args): warnings.warn(f"Can't validate on {val_dataset}, skipping.") +def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args): + for data_blob in logger.log_every(train_loader): + + optimizer.zero_grad() + + image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob) + flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates) + + loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma) + metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask) + + metrics.pop("f1") + logger.update(loss=loss, **metrics) + + loss.backward() + + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) + + optimizer.step() + scheduler.step() + + current_step += 1 + + if current_step == args.num_steps: + return True, current_step + + return False, current_step + + def main(args): utils.setup_ddp(args) @@ -210,34 +239,18 @@ def main(args): done = False current_epoch = current_step = 0 while not done: - sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs print(f"EPOCH {current_epoch}") - for data_blob in logger.log_every(train_loader): - - optimizer.zero_grad() - - image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob) - flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates) - - loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma) - metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask) - - metrics.pop("f1") - logger.update(loss=loss, **metrics) - - loss.backward() - - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) - - optimizer.step() - scheduler.step() - - current_step += 1 - - if current_step == args.num_steps: - done = True - break + sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs + done, current_step = train_one_epoch( + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_loader=train_loader, + logger=logger, + current_step=current_step, + args=args, + ) # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0 print(f"Epoch {current_epoch} done. ", logger) @@ -313,6 +326,7 @@ def get_args_parser(add_help=True): parser.add_argument( "--dataset-root", help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.", + required=True, ) return parser From be1d9f319f853264dce8e754e54e656c0c259cb0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 18:38:56 +0000 Subject: [PATCH 7/8] Added TODO for the last remaining comment --- references/optical_flow/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 1e274b3cdd6..2a7c5f3f755 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -258,6 +258,7 @@ def main(args): current_epoch += 1 if args.rank == 0: + # TODO: Also save the optimizer and scheduler torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth") torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth") From 2ef1af53b5f38bf31f63bc4c2d62e0753f4190d9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 19:10:06 +0000 Subject: [PATCH 8/8] remove old param --- references/optical_flow/train.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 2a7c5f3f755..eaf03fbe4f3 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -193,8 +193,6 @@ def main(args): if args.resume is not None: d = torch.load(args.resume, map_location="cpu") - if args.map_orig_to_ours: - d = utils.map_orig_to_ours(d) model.load_state_dict(d, strict=True) if args.train_dataset is None: @@ -321,9 +319,6 @@ def get_args_parser(add_help=True): parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training") - # TODO: remove - parser.add_argument("--map-orig-to-ours", action="store_true") - parser.add_argument( "--dataset-root", help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",