Shortcuts

Note

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.models.utils.dall_e

# Copyright (c)
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
# Copied from BEiT
import math
from collections import OrderedDict
from functools import partial

import attr
import torch
import torch.nn as nn
import torch.nn.functional as F


@attr.s(eq=False)
class Conv2d(nn.Module):
    """Square 2D convolution with size-preserving padding.

    The weight ``w`` has shape ``(n_out, n_in, kw, kw)`` and is initialised
    from ``N(0, 1 / (n_in * kw ** 2))``; the bias ``b`` has shape
    ``(n_out,)`` and starts at zero.

    Attributes:
        n_in: Number of input channels (>= 1).
        n_out: Number of output channels (>= 1).
        kw: Kernel size. Must be a positive odd integer so that padding of
            ``(kw - 1) // 2`` keeps the spatial size unchanged.
        use_float16: When True and the weights live on a CUDA device, the
            convolution is computed in float16; otherwise in float32.
        device: Device on which ``w`` and ``b`` are created.
        requires_grad: Whether ``w`` and ``b`` are trainable parameters.

    Raises:
        ValueError: If ``n_in`` or ``n_out`` is < 1, or ``kw`` is not a
            positive odd integer.
    """
    # Validation happens in ``__attrs_post_init__``: attrs ignores the
    # return value of a validator, so boolean lambdas would never reject
    # anything.
    n_in: int = attr.ib()
    n_out: int = attr.ib()
    kw: int = attr.ib()

    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        if self.n_in < 1 or self.n_out < 1:
            raise ValueError('n_in and n_out must be >= 1, got '
                             f'{self.n_in} and {self.n_out}')
        if self.kw < 1 or self.kw % 2 == 0:
            raise ValueError(
                f'kw must be a positive odd integer, got {self.kw}')

        # Initialise plain tensors first: an in-place ``normal_`` on a leaf
        # tensor that already requires grad raises a RuntimeError.
        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw),
                        dtype=torch.float32,
                        device=self.device)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2))

        b = torch.zeros((self.n_out, ),
                        dtype=torch.float32,
                        device=self.device)

        # ``nn.Parameter`` defaults to ``requires_grad=True``; pass the flag
        # explicitly so ``requires_grad=False`` is actually honoured.
        self.w = nn.Parameter(w, requires_grad=self.requires_grad)
        self.b = nn.Parameter(b, requires_grad=self.requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with ``w``/``b`` using size-preserving padding.

        On CUDA with ``use_float16`` enabled, ``x`` and the parameters are
        cast to float16; otherwise ``x`` is cast to float32.
        """
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)


@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    """Bottleneck residual block of the DALL-E encoder.

    Computes ``id_path(x) + post_gain * res_path(x)``:

    * ``id_path`` is a 1x1 ``Conv2d`` when ``n_in != n_out`` and
      ``nn.Identity`` otherwise.
    * ``res_path`` is four ReLU/conv pairs (``relu_k``/``conv_k``) with
      kernel sizes 3, 3, 3, 1 and a hidden width of ``n_out // 4``.

    Attributes:
        n_in: Number of input channels.
        n_out: Number of output channels (intended to be a multiple of 4).
        n_layers: Residual-block count of the parent network; the residual
            branch is scaled by ``1 / n_layers ** 2``.
        device: Forwarded unchanged to every ``Conv2d`` in the block.
        requires_grad: Forwarded unchanged to every ``Conv2d`` in the block.
    """
    # NOTE(review): attrs discards a validator's return value, so these
    # boolean lambdas never reject any value.
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        hidden = self.n_out // 4
        self.n_hid = hidden
        self.post_gain = 1 / (self.n_layers**2)

        def conv(c_in: int, c_out: int, kernel: int) -> nn.Module:
            return Conv2d(
                c_in,
                c_out,
                kernel,
                device=self.device,
                requires_grad=self.requires_grad)

        # The shortcut needs a projection only when the width changes.
        if self.n_in == self.n_out:
            self.id_path = nn.Identity()
        else:
            self.id_path = conv(self.n_in, self.n_out, 1)

        specs = (
            (self.n_in, hidden, 3),
            (hidden, hidden, 3),
            (hidden, hidden, 3),
            (hidden, self.n_out, 1),
        )
        stages = OrderedDict()
        for idx, (c_in, c_out, kernel) in enumerate(specs, start=1):
            stages[f'relu_{idx}'] = nn.ReLU()
            stages[f'conv_{idx}'] = conv(c_in, c_out, kernel)
        self.res_path = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut plus the gain-scaled residual branch."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual


@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    """DALL-E encoder (copied from BEiT).

    The network is built as ``self.blocks``, which runs these modules in
    order:

    * ``input``: a 7x7 conv to ``n_hid`` channels.
    * ``group_1`` .. ``group_4``: four groups of ``n_blk_per_group``
      ``EncoderBlock`` s with widths ``1, 2, 4, 8 * n_hid``. Groups 1-3
      end in a 2x2 max-pool.
    * ``output``: a ReLU plus a 1x1 conv producing ``vocab_size``
      channels.

    Attributes:
        group_count: Class constant (not an init argument), used only to
            compute the ``n_layers`` passed to every ``EncoderBlock``.
        n_hid: Base channel width.
        n_blk_per_group: Number of ``EncoderBlock`` s per group.
        input_channels: Number of channels ``forward`` expects.
        vocab_size: Number of output channels.
        device: Device on which all parameters are created.
        requires_grad: Forwarded to every ``Conv2d`` / ``EncoderBlock``.
        use_mixed_precision: If False, every ``Conv2d`` is switched to
            float32 (``use_float16 = False``). If True (default), the
            per-conv settings are left as built: the output conv is always
            float32.
    """
    group_count: int = 4
    # NOTE(review): attrs ignores a validator's return value, so these
    # boolean lambdas never reject anything. They are left as-is so that
    # existing (possibly small, test-sized) configurations keep building.
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(
        default=2, validator=lambda i, a, x: x >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(
        default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers = self.group_count * self.n_blk_per_group
        make_conv = partial(
            Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk = partial(
            EncoderBlock,
            n_layers=n_layers,
            device=self.device,
            requires_grad=self.requires_grad)

        def make_group(n_in: int, n_out: int, pool: bool) -> nn.Sequential:
            # The first block adapts the width; later blocks keep n_out.
            layers = [(f'block_{i + 1}',
                       make_blk(n_in if i == 0 else n_out, n_out))
                      for i in blk_range]
            if pool:
                layers.append(('pool', nn.MaxPool2d(kernel_size=2)))
            return nn.Sequential(OrderedDict(layers))

        hid = self.n_hid
        # Submodules are created in forward order, so parameters are
        # initialised (and the RNG consumed) in that same order.
        self.blocks = nn.Sequential(
            OrderedDict([
                ('input', make_conv(self.input_channels, 1 * hid, 7)),
                ('group_1', make_group(1 * hid, 1 * hid, pool=True)),
                ('group_2', make_group(1 * hid, 2 * hid, pool=True)),
                ('group_3', make_group(2 * hid, 4 * hid, pool=True)),
                ('group_4', make_group(4 * hid, 8 * hid, pool=False)),
                ('output',
                 nn.Sequential(
                     OrderedDict([
                         ('relu', nn.ReLU()),
                         ('conv',
                          make_conv(
                              8 * hid, self.vocab_size, 1,
                              use_float16=False)),
                     ]))),
            ]))

        # Honour the flag: without mixed precision every conv runs in fp32.
        if not self.use_mixed_precision:
            for module in self.blocks.modules():
                if isinstance(module, Conv2d):
                    module.use_float16 = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on a 4-D input.

        Args:
            x: 4-D tensor whose dimension 1 must equal ``input_channels``.
                It is cast to float32 before the blocks run.

        Returns:
            The output of ``self.blocks``.

        Raises:
            ValueError: If ``x`` is not 4-D or has the wrong channel count.
        """
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model '
                             f'built for {self.input_channels}')
        # Any input dtype is accepted; the cast guarantees float32, so no
        # separate dtype check is needed.
        return self.blocks(x.float())
Read the Docs v: 0.x
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.