mmpretrain.models.utils.resize_pos_embed
- mmpretrain.models.utils.resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic', num_extra_tokens=1)
Resize the position embedding (pos_embed) weights to a new spatial resolution.
- Parameters:
pos_embed (torch.Tensor) – Position embedding weights with shape [1, L, C].
src_shape (tuple) – The resolution of the downsampled original training image, in format (H, W).
dst_shape (tuple) – The resolution of the downsampled new training image, in format (H, W).
mode (str) – Interpolation algorithm used for resizing. Choose one from ‘nearest’, ‘linear’, ‘bilinear’, ‘bicubic’ and ‘trilinear’. Defaults to ‘bicubic’.
num_extra_tokens (int) – The number of extra tokens, such as cls_token. Defaults to 1.
- Returns:
The resized pos_embed of shape [1, L_new, C], where L_new = dst_H * dst_W + num_extra_tokens.
- Return type:
torch.Tensor
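For illustration, here is a minimal sketch of how a resize like this typically works, assuming PyTorch tensors: keep the extra tokens (such as cls_token) unchanged, reshape the remaining tokens into an (H, W) grid, interpolate that grid to dst_shape with torch.nn.functional.interpolate, and flatten it back into a token sequence. The name resize_pos_embed_sketch is hypothetical; this is not the mmpretrain source, and in practice you would call mmpretrain.models.utils.resize_pos_embed directly.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed: torch.Tensor,
                            src_shape: tuple,
                            dst_shape: tuple,
                            mode: str = 'bicubic',
                            num_extra_tokens: int = 1) -> torch.Tensor:
    """Hypothetical re-implementation, for illustration only."""
    if tuple(src_shape) == tuple(dst_shape):
        # Nothing to resize: return the input unchanged.
        return pos_embed
    assert pos_embed.ndim == 3, 'pos_embed must have shape [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, 'L must equal H * W + num_extra_tokens'

    # Extra tokens (e.g. cls_token) have no spatial position, so keep them as-is.
    extra_tokens = pos_embed[:, :num_extra_tokens]

    # Reshape the patch tokens [1, H*W, C] into an image-like grid [1, C, H, W].
    grid = pos_embed[:, num_extra_tokens:].reshape(1, src_h, src_w, C)
    grid = grid.permute(0, 3, 1, 2)

    # align_corners may only be passed to the linear-family modes in PyTorch.
    kwargs = {} if mode in ('nearest', 'area') else {'align_corners': False}
    # Interpolate in float32 as a precaution; not every dtype supports bicubic.
    resized = F.interpolate(grid.float(), size=tuple(dst_shape), mode=mode, **kwargs)

    # Flatten back to a token sequence [1, H_new*W_new, C] and restore the dtype.
    resized = resized.flatten(2).transpose(1, 2).to(pos_embed.dtype)
    return torch.cat((extra_tokens, resized), dim=1)


# Example: a 14x14 patch grid plus one cls_token, resized to a 24x24 grid.
pos_embed = torch.randn(1, 14 * 14 + 1, 768)
new_embed = resize_pos_embed_sketch(pos_embed, (14, 14), (24, 24))
print(new_embed.shape)  # torch.Size([1, 577, 768])
```

The output length follows L_new = 24 * 24 + 1 = 577. Keeping the extra tokens out of the interpolation matters because they do not correspond to any grid cell; interpolating them together with the patch tokens would mix unrelated embeddings.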