
mmpretrain.models.utils.resize_pos_embed

mmpretrain.models.utils.resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic', num_extra_tokens=1)[source]

Resize pos_embed weights from the source resolution (src_shape) to the destination resolution (dst_shape).

Parameters:
  • pos_embed (torch.Tensor) – Position embedding weights with shape [1, L, C].

  • src_shape (tuple) – The resolution of the downsampled original training image, in format (H, W).

  • dst_shape (tuple) – The resolution of the downsampled new training image, in format (H, W).

  • mode (str) – Algorithm used for upsampling. Choose one from ‘nearest’, ‘linear’, ‘bilinear’, ‘bicubic’ and ‘trilinear’. Defaults to ‘bicubic’.

  • num_extra_tokens (int) – The number of extra tokens, such as cls_token. Defaults to 1.
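The parameters above tie the sequence length L to src_shape. The sketch below assumes ViT-style token layout: the extra tokens come first and are followed by one token per cell of an H×W patch grid, so L = num_extra_tokens + H * W. The helpers `patch_grid` and `num_pos_tokens` are hypothetical, and the 16-pixel patch size and the 224/384 image sizes are illustrative assumptions, not values taken from this API.

```python
def patch_grid(image_size, patch_size):
    """Hypothetical helper: downsampled (H, W) grid for a given image size."""
    return (image_size[0] // patch_size, image_size[1] // patch_size)


def num_pos_tokens(grid_shape, num_extra_tokens=1):
    """Hypothetical helper: L = num_extra_tokens + H * W (assumed layout)."""
    h, w = grid_shape
    return num_extra_tokens + h * w


src_shape = patch_grid((224, 224), 16)  # (14, 14)
dst_shape = patch_grid((384, 384), 16)  # (24, 24)
print(num_pos_tokens(src_shape))  # 197 = 1 cls_token + 14 * 14
print(num_pos_tokens(dst_shape))  # 577 = 1 cls_token + 24 * 24
```

Under this assumption, resizing from src_shape (14, 14) to dst_shape (24, 24) turns a [1, 197, C] embedding into a [1, 577, C] one.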

Returns:

The resized pos_embed of shape [1, L_new, C].

Return type:

torch.Tensor
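To make the behaviour described above concrete, here is a minimal sketch of how such a resize can be written with PyTorch, assuming the extra tokens come first and the remaining tokens form an H×W grid. `resize_pos_embed_sketch` is a hypothetical illustration, not the mmpretrain source. It leaves the extra tokens untouched and interpolates only the patch grid with `torch.nn.functional.interpolate`. Because it reshapes the grid into a 4-D tensor, this sketch only covers the 2-D modes ('nearest', 'bilinear', 'bicubic').

```python
import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, src_shape, dst_shape, mode='bicubic', num_extra_tokens=1):
    """Hypothetical re-implementation for illustration; not the library code."""
    if tuple(src_shape) == tuple(dst_shape):
        return pos_embed  # nothing to resize
    assert pos_embed.ndim == 3, 'pos_embed must have shape [1, L, C]'
    _, num_tokens, channels = pos_embed.shape
    src_h, src_w = src_shape
    assert num_tokens == num_extra_tokens + src_h * src_w, 'L must equal extra tokens + H * W'

    # Extra tokens (e.g. cls_token) have no spatial position, so keep them as-is.
    extra_tokens = pos_embed[:, :num_extra_tokens]
    # Lay the patch embeddings out as an image: [1, H*W, C] -> [1, C, H, W].
    grid = pos_embed[:, num_extra_tokens:].reshape(1, src_h, src_w, channels).permute(0, 3, 1, 2)

    # align_corners may only be passed for the bilinear/bicubic modes.
    align_corners = False if mode in ('bilinear', 'bicubic') else None
    # Interpolate in float32, then cast back to the input dtype below.
    resized = F.interpolate(grid.float(), size=tuple(dst_shape), mode=mode, align_corners=align_corners)

    # [1, C, H', W'] -> [1, H'*W', C], then put the extra tokens back in front.
    resized = resized.flatten(2).transpose(1, 2).to(pos_embed.dtype)
    return torch.cat((extra_tokens, resized), dim=1)


pos_embed = torch.randn(1, 1 + 14 * 14, 768)
new_pos_embed = resize_pos_embed_sketch(pos_embed, (14, 14), (24, 24))
print(new_pos_embed.shape)  # torch.Size([1, 577, 768])
```

Splitting off the extra tokens before interpolating matters: they are not part of the spatial grid, so interpolating them together with the patch tokens would mix unrelated embeddings.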