# How to Pretrain with Custom Dataset
In this tutorial, we provide a practical example and some tips on how to train on your own dataset.
In MMPretrain, we support the `CustomDataset` (similar to `ImageFolder` in `torchvision`), which can read the images within the specified folder directly. You only need to prepare the path information of the custom dataset and edit the config.
## Step-1: Prepare your dataset
Prepare your dataset following Prepare Dataset. The root folder of the dataset can be like `data/custom_dataset/`.

Here, we assume you want to do unsupervised training, and use the sub-folder format of `CustomDataset` to organize your dataset as:
```text
data/custom_dataset/
├── sample1.png
├── sample2.png
├── sample3.png
├── sample4.png
└── ...
```
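If you want to verify that this folder can be read as expected before training, you can instantiate the dataset directly. Below is a minimal sketch, assuming MMPretrain is installed and your data sits at the path above:

```python
# A quick sanity check (a sketch): read the flat folder with CustomDataset,
# without any annotation file or labels.
from mmpretrain.datasets import CustomDataset

dataset = CustomDataset(
    data_root='data/custom_dataset/',
    with_label=False,  # unsupervised training, no labels needed
)
print(len(dataset))              # number of images found
print(dataset.get_data_info(0))  # info dict of the first sample, incl. its path
```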
## Step-2: Choose one config as template
Here, we would like to use `configs/mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py` as the example. We first copy this config file to the same folder and rename it as `mae_vit-base-p16_8xb512-amp-coslr-300e_custom.py`.
> **Tip:** As a convention, the last field of the config name is the dataset, e.g., `in1k` for the ImageNet dataset, or `coco` for the COCO dataset.
The content of this config is:
```python
_base_ = [
    '../_base_/models/mae_vit-base-p16.py',
    '../_base_/datasets/imagenet_bs512_mae.py',
    '../_base_/default_runtime.py',
]

# optimizer wrapper
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',
    optimizer=dict(
        type='AdamW',
        lr=1.5e-4 * 4096 / 256,
        betas=(0.9, 0.95),
        weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys={
            'ln': dict(decay_mult=0.0),
            'bias': dict(decay_mult=0.0),
            'pos_embed': dict(decay_mult=0.),
            'mask_token': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.0001,
        by_epoch=True,
        begin=0,
        end=40,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=260,
        by_epoch=True,
        begin=40,
        end=300,
        convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
    # only keeps the latest 3 checkpoints
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)

# auto resume
resume = True

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)
```
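Note that the dataset settings in this template still point to ImageNet. To train on the flat folder from Step-1, you can override the `train_dataloader` in the copied config, following the same override pattern as the COCO example below. A minimal sketch (the field values assume the flat layout above):

```python
# >>>>>>>>>>>>>>> Override dataset settings here >>>>>>>>>>>>>>>>>>>
train_dataloader = dict(
    dataset=dict(
        type='CustomDataset',
        data_root='data/custom_dataset/',
        ann_file='',       # no annotation file; scan the folder directly
        data_prefix='',    # images are directly under `data_root`
        with_label=False,  # unsupervised pretraining uses no labels
    )
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
```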
## Another example: Train MAE on COCO Dataset
> **Note:** You need to install MMDetection to use `mmdet.CocoDataset`. Please follow this documentation to install it.
Following the aforementioned idea, we also present an example of how to train MAE on the COCO dataset. The edited file will be like this:
```python
_base_ = [
    '../_base_/models/mae_vit-base-p16.py',
    '../_base_/datasets/imagenet_bs512_mae.py',
    '../_base_/default_runtime.py',
]

# >>>>>>>>>>>>>>> Override dataset settings here >>>>>>>>>>>>>>>>>>>
train_dataloader = dict(
    dataset=dict(
        type='mmdet.CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/instances_train2017.json',  # Only for loading images, and the labels won't be used.
        data_prefix=dict(img='train2017/'),
    )
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

# optimizer wrapper
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',
    optimizer=dict(
        type='AdamW',
        lr=1.5e-4 * 4096 / 256,
        betas=(0.9, 0.95),
        weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys={
            'ln': dict(decay_mult=0.0),
            'bias': dict(decay_mult=0.0),
            'pos_embed': dict(decay_mult=0.),
            'mask_token': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.0001,
        by_epoch=True,
        begin=0,
        end=40,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=260,
        by_epoch=True,
        begin=40,
        end=300,
        convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
    # only keeps the latest 3 checkpoints
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)

# auto resume
resume = True

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)
```
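Once the edited config is in place, training can be started as usual with MMPretrain's training script, or programmatically through MMEngine's `Runner`. Below is a minimal sketch of the programmatic route (the `work_dir` value is an assumption; any output directory works):

```python
# A minimal sketch of launching training from Python via MMEngine.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/mae/mae_vit-base-p16_8xb512-amp-coslr-300e_custom.py')
cfg.work_dir = 'work_dirs/mae_custom'  # Runner requires an output directory

runner = Runner.from_cfg(cfg)
runner.train()
```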