DALLEEncoder
- class mmpretrain.models.selfsup.DALLEEncoder(group_count=4, n_hid=256, n_blk_per_group=2, input_channels=3, vocab_size=8192, device=device(type='cpu'), requires_grad=False, init_cfg=None)
DALL-E Encoder for feature extraction.
- Parameters:
group_count (int) – Number of groups in DALL-E encoder. Defaults to 4.
n_hid (int) – Dimension of hidden layers. Defaults to 256.
n_blk_per_group (int) – Number of blocks per group. Defaults to 2.
input_channels (int) – Number of channels in the input images. Defaults to 3.
vocab_size (int) – Vocabulary size, indicating the number of classes. Defaults to 8192.
device (torch.device) – Device on which the parameters are created. Defaults to torch.device('cpu').
requires_grad (bool) – Whether the parameters require gradients. Defaults to False.
init_cfg (Union[List[dict], dict], optional) – Config dict for weight initialization. Defaults to None.
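To make the constructor arguments concrete, here is a minimal pure-Python sketch that tabulates the layout they imply. It assumes the encoder follows the original OpenAI DALL-E design: a 7x7 input convolution, then groups of residual blocks whose width doubles from one group to the next (n_hid, 2*n_hid, 4*n_hid, and so on), a 2x2 max-pool after every group except the last, and a final 1x1 convolution to vocab_size channels. The residual scaling factor 1 / n_layers**2 is likewise taken from that reference design. `describe_dalle_encoder` and `GroupSpec` are hypothetical helpers, not part of mmpretrain, and extending the layout to an arbitrary group_count is an assumption (the reference implementation spells out four groups explicitly).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GroupSpec:
    """Hypothetical summary of one encoder group (not an mmpretrain class)."""
    name: str
    width: int       # output channels of every block in the group
    num_blocks: int  # residual blocks in the group
    pooled: bool     # whether a 2x2 max-pool follows the group


def describe_dalle_encoder(group_count: int = 4, n_hid: int = 256,
                           n_blk_per_group: int = 2) -> dict:
    """Summarize the *assumed* DALL-E encoder layout for the given arguments.

    Assumption: group g (1-based) has width 2**(g - 1) * n_hid and is
    followed by a 2x2 max-pool unless it is the last group.
    """
    groups: List[GroupSpec] = [
        GroupSpec(
            name=f"group_{g}",
            width=(2 ** (g - 1)) * n_hid,
            num_blocks=n_blk_per_group,
            pooled=g < group_count,
        )
        for g in range(1, group_count + 1)
    ]
    n_layers = group_count * n_blk_per_group
    return {
        "groups": groups,
        "n_layers": n_layers,
        # assumed: residual branches are scaled by 1 / n_layers**2
        "post_gain": 1 / n_layers ** 2,
        # each pooled group halves H and W
        "downsample": 2 ** sum(g.pooled for g in groups),
    }


if __name__ == "__main__":
    info = describe_dalle_encoder()
    for g in info["groups"]:
        print(g)
    print("n_layers:", info["n_layers"], "downsample:", info["downsample"])
```

With the defaults this yields group widths 256, 512, 1024 and 2048, n_layers = 8 and an overall spatial downsampling factor of 8, so under these assumptions a 112x112 image would produce a 14x14 output grid.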
- forward(x)
Forward function of the DALL-E encoder.
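The forward pass maps a (B, C, H, W) image batch to a (B, vocab_size, h, w) map of per-position class scores. The sketch below computes that output shape and turns a single score map into discrete token ids with an argmax over the vocabulary axis (for a batched PyTorch output this would be roughly `logits.argmax(dim=1)`). It assumes the 8x spatial downsampling of the original DALL-E layout (three 2x2 max-pools) and that tokens are read off by argmax; both helper names are hypothetical, and plain Python lists are used so the sketch runs without torch or mmpretrain.

```python
from typing import List, Sequence, Tuple


def dalle_encoder_output_shape(input_shape: Tuple[int, int, int, int],
                               vocab_size: int = 8192,
                               input_channels: int = 3,
                               downsample: int = 8) -> Tuple[int, int, int, int]:
    """Hypothetical helper: output shape for a (B, C, H, W) input.

    Assumes `downsample` = 8, i.e. three 2x2 max-pools; repeated floor
    division by 2 equals a single floor division by 8.
    """
    if len(input_shape) != 4:
        raise ValueError("expected a 4-D (B, C, H, W) shape")
    b, c, h, w = input_shape
    # sanity check added for this sketch
    if c != input_channels:
        raise ValueError(f"expected {input_channels} channels, got {c}")
    return b, vocab_size, h // downsample, w // downsample


def logits_to_tokens(logits: Sequence[Sequence[Sequence[float]]]) -> List[List[int]]:
    """Hypothetical helper: argmax over the vocab axis of one (vocab, h, w) map."""
    vocab = len(logits)
    h, w = len(logits[0]), len(logits[0][0])
    return [
        [max(range(vocab), key=lambda k: logits[k][i][j]) for j in range(w)]
        for i in range(h)
    ]


if __name__ == "__main__":
    print(dalle_encoder_output_shape((2, 3, 112, 112)))
    # toy map with vocab = 3, h = 1, w = 2
    print(logits_to_tokens([[[0.1, 0.9]], [[0.5, 0.2]], [[0.3, 0.4]]]))
```

Under these assumptions a (2, 3, 112, 112) batch gives (2, 8192, 14, 14) scores, and the toy map above yields the token grid [[1, 0]].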
- Parameters:
x (torch.Tensor) – The input images with shape (B, C, H, W).
- Returns:
The output with shape (B, vocab_size, h, w).
- Return type: