# Training Strong Models
Omnidata can be used to train models for different vision tasks. Here, we provide the code for training our depth and surface normal estimation models. You can train them with the commands below.
## Depth Estimation
We train DPT-based models on Omnidata using three different losses: the scale- and shift-invariant loss and the scale-invariant gradient matching term introduced in MiDaS, and a virtual normal loss (Yin et al., 2019).
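The training objective is a weighted combination of these terms. As a sketch, with illustrative weights $\alpha$ and $\lambda$ (these symbols are our shorthand here, not the paper's notation, and the actual weights are set by the training configuration):

$$\mathcal{L}_{\text{depth}} = \mathcal{L}_{\text{ssi}} + \alpha\,\mathcal{L}_{\text{grad}} + \lambda\,\mathcal{L}_{\text{vn}}$$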
```bash
python train_depth.py --config_file config/depth.yml --experiment_name rgb2depth --val_check_interval 3000 --limit_val_batches 100 --max_epochs 10
```
### MiDaS Implementation
We provide an implementation of the MiDaS loss, specifically the `ssimae` (scale- and shift-invariant MAE) loss and the scale-invariant gradient matching term, in `losses/midas_loss.py`. The video below shows the output of our MiDaS reimplementation (a DPT trained on the Omnidata starter dataset) compared to the original DPT with MiDaS, trained on a mix of 10 depth datasets that contain both real images and depth sensor readings. The resampled data from Omnidata does not seem to hurt training, since the reimplemented version better captures the 3D shape (quantitative comparisons of depth estimation are in the paper).
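For intuition, here is a minimal sketch of the scale- and shift-invariant MAE: each prediction is first aligned to the ground truth with a per-image scale and shift solved in closed form by least squares over the valid pixels, and the MAE is then computed on the aligned prediction. This illustrates the idea only; the function name and flattened `(B, N)` shapes are assumptions, and the actual code lives in `losses/midas_loss.py`:

```python
import torch

def ssi_mae(prediction, target, mask):
    """Scale- and shift-invariant MAE sketch. Inputs: (B, N) flattened maps."""
    m = mask.float()
    n = m.sum(1).clamp(min=1)
    # Closed-form least squares for per-image scale s and shift t,
    # minimizing sum_i m_i * (s * pred_i + t - target_i)^2.
    a00 = (m * prediction * prediction).sum(1)
    a01 = (m * prediction).sum(1)
    b0 = (m * prediction * target).sum(1)
    b1 = (m * target).sum(1)
    det = (a00 * n - a01 * a01).clamp(min=1e-6)
    s = (b0 * n - a01 * b1) / det
    t = (b1 * a00 - a01 * b0) / det
    aligned = s.unsqueeze(1) * prediction + t.unsqueeze(1)
    # Mean absolute error over valid pixels only
    return ((aligned - target).abs() * m).sum(1) / n
```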
The MiDaS loss is useful for training depth estimation models on mixed datasets with different depth ranges and scales, such as ours. Example usage is shown below:
```python
from losses.midas_loss import MidasLoss

midas_loss = MidasLoss(alpha=0.1)
total_loss, ssi_mae_loss, reg_loss = midas_loss(depth_prediction, depth_gt, mask)
```
`alpha` specifies the weight of the gradient matching term in the total loss, and `mask` indicates the valid pixels of the image.
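For a self-contained smoke test, you can call the loss on dummy tensors along these lines (the `(B, 1, H, W)` shapes are an assumption for illustration, not a documented requirement):

```python
import torch
from losses.midas_loss import MidasLoss

midas = MidasLoss(alpha=0.1)
depth_prediction = torch.rand(4, 1, 384, 384)  # hypothetical network output
depth_gt = torch.rand(4, 1, 384, 384)          # ground-truth depth
mask = torch.ones(4, 1, 384, 384)              # 1 = valid pixel, 0 = invalid
total_loss, ssi_mae_loss, reg_loss = midas(depth_prediction, depth_gt, mask)
```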
## Surface Normal Estimation
We train a UNet architecture (6 down / 6 up) for surface normal estimation using an L1 loss and a cosine angular loss, sketched below.
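A minimal sketch of the two terms, assuming `(B, 3, H, W)` normal maps and a `(B, 1, H, W)` binary validity mask (the function name and shapes are illustrative, not the repository's API):

```python
import torch
import torch.nn.functional as F

def normal_losses(pred, gt, mask):
    valid = mask.float()
    n = valid.sum().clamp(min=1)
    # L1 loss on the normal vectors, averaged over valid pixels and channels
    l1 = (torch.abs(pred - gt) * valid).sum() / (3 * n)
    # Cosine angular loss: 1 - cos(angle) between predicted and GT normals;
    # cosine_similarity normalizes internally, so only directions are compared.
    cos = F.cosine_similarity(pred, gt, dim=1, eps=1e-6).unsqueeze(1)
    angular = ((1.0 - cos) * valid).sum() / n
    return l1, angular
```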
```bash
python train_normal.py --config_file config/normal.yml --experiment_name rgb2normal --val_check_interval 3000 --limit_val_batches 100 --max_epochs 10
```
Here are some results (compared to X-Task Consistency):
## 3D Depth-of-Field Augmentation
Mid-level cues can be used for data augmentation in addition to serving as training targets. The availability of full scene geometry in our dataset makes it possible to use image refocusing as a 3D data augmentation. You can find an implementation of this augmentation in `data/refocus_augmentation.py`, and you can run it on some sample images from our dataset with the following command:
```bash
python demo_refocus.py --input_path assets/demo_refocus/ --output_path assets/demo_refocus
```
This will refocus the RGB images by blurring them according to each image's `depth_euclidean`. You can control the augmentation with the following flags: `--num_quantiles` (number of quantiles to use in the blur stack), `--min_aperture` (smallest aperture to use), and `--max_aperture` (largest aperture to use). The aperture size is sampled log-uniformly from the range between the min and max aperture.
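Log-uniform sampling means the aperture is drawn uniformly in log-space and then exponentiated, so small and large apertures are covered evenly across scales. A one-function sketch of that sampling step (an illustration, not the repository's implementation):

```python
import numpy as np

def sample_aperture(min_aperture, max_aperture, rng=None):
    # Uniform in log-space, then exponentiate back: log-uniform sampling.
    if rng is None:
        rng = np.random.default_rng()
    return float(np.exp(rng.uniform(np.log(min_aperture), np.log(max_aperture))))
```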
| Shallow Focus | Mid Focus | Far Focus |
|---|---|---|
## Citation
If you find the code or models useful, please cite our paper:
```bibtex
@inproceedings{eftekhar2021omnidata,
  title={Omnidata: A Scalable Pipeline for Making Multi-Task Mid-Level Vision Datasets From 3D Scans},
  author={Eftekhar, Ainaz and Sax, Alexander and Malik, Jitendra and Zamir, Amir},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={10786--10796},
  year={2021}
}
```