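Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc promptly to benefit from the new features and better performance that OpenMMLab 2.0 brings. See the MMSelfSup 1.0.0rc release notes and code documentation for more information.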

Source code for mmselfsup.models.utils.position_embedding

# Copyright (c) OpenMMLab. All rights reserved.
import torch


def build_2d_sincos_position_embedding(patches_resolution,
                                       embed_dims,
                                       temperature=10000.,
                                       cls_token=False):
    """The function is to build position embedding for model to obtain the
    position information of the image patches."""

    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)

    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)

    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
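For illustration, here is a dependency-free sketch of the same computation in plain Python (standard-library `math` only), so the embedding layout can be inspected without PyTorch. The name `sincos_2d_position_embedding` and the list-of-rows return type are assumptions of this sketch, not part of MMSelfSup. It is meant to mirror the torch code above. Each axis gets `pos_dim = embed_dims // 4` frequencies `omega_d = 1 / temperature**(d / pos_dim)`. Every patch row concatenates `sin` and `cos` of the width coordinate times those frequencies, followed by `sin` and `cos` of the height coordinate. Because the torch code calls `torch.meshgrid(grid_w, grid_h)` with its default `'ij'` indexing, the flattened patch index runs over the width coordinate in the outer loop and the height coordinate in the inner loop. The sketch reproduces that order.

```python
import math


def sincos_2d_position_embedding(patches_resolution, embed_dims,
                                 temperature=10000.0, cls_token=False):
    """Hypothetical plain-Python sketch of a 2D sin-cos position embedding.

    Returns a list of rows, one per patch (plus a leading all-zero row when
    ``cls_token`` is True); each row holds ``embed_dims`` floats.
    """
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)
    h, w = patches_resolution
    assert embed_dims % 4 == 0, 'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    # One frequency per channel slot: omega_d = 1 / temperature**(d / pos_dim).
    omega = [1.0 / temperature ** (d / pos_dim) for d in range(pos_dim)]

    rows = []
    if cls_token:
        # The class token gets an all-zero embedding, prepended as row 0.
        rows.append([0.0] * embed_dims)

    # Width coordinate in the outer loop, height coordinate in the inner loop,
    # matching the flatten order of torch.meshgrid(grid_w, grid_h) ('ij').
    for x in range(w):
        for y in range(h):
            out_w = [x * o for o in omega]
            out_h = [y * o for o in omega]
            rows.append([math.sin(v) for v in out_w]
                        + [math.cos(v) for v in out_w]
                        + [math.sin(v) for v in out_h]
                        + [math.cos(v) for v in out_h])
    return rows


# Usage: a 2 x 3 (h x w) patch grid with 8 embedding channels.
pe = sincos_2d_position_embedding((2, 3), 8)
print(len(pe), len(pe[0]))
```

Applied to the PyTorch function above, `build_2d_sincos_position_embedding(14, 768)` returns a tensor of shape `(1, 196, 768)`. With `cls_token=True`, an all-zero row is prepended, giving `(1, 197, 768)`. The sketch returns the same rows as a nested list, without the leading batch dimension.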