Skip to content

biobench.webssl

DinoVisionTransformer(img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, embed_layer=PatchEmbed, act_layer=torch.nn.GELU, block_fn=Block, ffn_layer='mlp', block_chunks=1, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1)

Bases: Module

Parameters:

Name Type Description Default
img_size (int, tuple)

input image size

224
patch_size (int, tuple)

patch size

16
in_chans int

number of input channels

3
embed_dim int

embedding dimension

768
depth int

depth of transformer

12
num_heads int

number of attention heads

12
mlp_ratio float

ratio of mlp hidden dim to embedding dim

4.0
qkv_bias bool

enable bias for qkv if True

True
proj_bias bool

enable bias for proj in attn if True

True
ffn_bias bool

enable bias for ffn if True

True
drop_path_rate float

stochastic depth rate

0.0
drop_path_uniform bool

apply uniform drop rate across blocks

False
weight_init str

weight init scheme (note: listed in the original docstring but not accepted by this constructor — see the signature above)

required
init_values float

layer-scale init values

None
embed_layer Module

patch embedding layer

PatchEmbed
act_layer Module

MLP activation layer

GELU
block_fn Module

transformer block class

Block
ffn_layer str

"mlp", "swiglu", "swiglufused" or "identity"

'mlp'
block_chunks

(int) split block sequence into block_chunks units for FSDP wrap

1
num_register_tokens

(int) number of extra cls tokens (so-called "registers")

0
interpolate_antialias

(bool) flag to apply anti-aliasing when interpolating positional embeddings

False
interpolate_offset

(float) work-around offset to apply when interpolating positional embeddings

0.1
Source code in src/biobench/webssl.py
def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    ffn_bias=True,
    proj_bias=True,
    drop_path_rate=0.0,
    drop_path_uniform=False,
    init_values=None,  # for layerscale: None or 0 => no layerscale
    embed_layer=PatchEmbed,
    act_layer=torch.nn.GELU,
    block_fn=Block,
    ffn_layer="mlp",
    block_chunks=1,
    num_register_tokens=0,
    interpolate_antialias=False,
    interpolate_offset=0.1,
):
    """
    Build a DINOv2-style vision transformer.

    Args:
        img_size (int, tuple): input image size
        patch_size (int, tuple): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (float): ratio of mlp hidden dim to embedding dim
        qkv_bias (bool): enable bias for qkv if True
        proj_bias (bool): enable bias for proj in attn if True
        ffn_bias (bool): enable bias for ffn if True
        drop_path_rate (float): stochastic depth rate
        drop_path_uniform (bool): apply uniform drop rate across blocks
        init_values (float): layer-scale init values (None or 0 => no layer scale)
        embed_layer (nn.Module): patch embedding layer
        act_layer (nn.Module): MLP activation layer
        block_fn (nn.Module): transformer block class
        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
        block_chunks (int): split block sequence into block_chunks units for FSDP wrap
        num_register_tokens (int): number of extra cls tokens (so-called "registers")
        interpolate_antialias (bool): flag to apply anti-aliasing when interpolating positional embeddings
        interpolate_offset (float): work-around offset to apply when interpolating positional embeddings

    Note:
        The upstream docstring also listed a ``weight_init`` argument; this
        constructor does not accept one.
    """
    super().__init__()
    norm_layer = functools.partial(torch.nn.LayerNorm, eps=1e-6)

    self.num_features = self.embed_dim = (
        embed_dim  # num_features for consistency with other models
    )
    self.num_tokens = 1  # the single CLS token
    self.n_blocks = depth
    self.num_heads = num_heads
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
    )
    num_patches = self.patch_embed.num_patches

    # Learnable CLS token and absolute position embeddings (CLS + patches).
    self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = torch.nn.Parameter(
        torch.zeros(1, num_patches + self.num_tokens, embed_dim)
    )
    assert num_register_tokens >= 0
    # Optional extra "register" tokens; None when disabled.
    self.register_tokens = (
        torch.nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
        if num_register_tokens
        else None
    )

    # Per-block drop-path rates: either uniform or linearly increasing to drop_path_rate.
    if drop_path_uniform is True:
        dpr = [drop_path_rate] * depth
    else:
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

    # Resolve the FFN flavor from its string name to a constructor.
    if ffn_layer == "mlp":
        logger.info("using MLP layer as FFN")
        ffn_layer = Mlp
    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
        logger.info("using SwiGLU layer as FFN")
        ffn_layer = SwiGLUFFNFused
    elif ffn_layer == "identity":
        logger.info("using Identity layer as FFN")

        def f(*args, **kwargs):
            return torch.nn.Identity()

        ffn_layer = f
    else:
        raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    # Group blocks into BlockChunk units so each chunk can be FSDP-wrapped.
    if block_chunks > 0:
        self.chunked_blocks = True
        chunked_blocks = []
        chunksize = depth // block_chunks
        for i in range(0, depth, chunksize):
            # this is to keep the block index consistent if we chunk the block list
            chunked_blocks.append(
                [torch.nn.Identity()] * i + blocks_list[i : i + chunksize]
            )
        self.blocks = torch.nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
    else:
        self.chunked_blocks = False
        self.blocks = torch.nn.ModuleList(blocks_list)

    self.norm = norm_layer(embed_dim)
    self.head = torch.nn.Identity()

    # Learnable token substituted for masked patches.
    # NOTE(review): its usage is not visible in this block — confirm against forward.
    self.mask_token = torch.nn.Parameter(torch.zeros(1, embed_dim))

    logger.info("Initializing weights")
    self.init_weights()

DropPath(drop_prob=None)

Bases: Module

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Source code in src/biobench/webssl.py
def __init__(self, drop_prob=None):
    """Store the per-sample path-drop probability.

    Args:
        drop_prob (float | None): probability of dropping the residual
            path for a given sample; None means no drop configured.
    """
    # Zero-arg super() (PEP 3135) for consistency with the other modules
    # in this file; the old two-arg form was Python-2 style.
    super().__init__()
    self.drop_prob = drop_prob

NestedTensorBlock(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, proj_bias=True, ffn_bias=True, drop=0.0, attn_drop=0.0, init_values=None, drop_path=0.0, act_layer=torch.nn.GELU, norm_layer=torch.nn.LayerNorm, attn_class=Attention, ffn_layer=Mlp)

Bases: Block

Source code in src/biobench/webssl.py
def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
    qkv_bias: bool = False,
    proj_bias: bool = True,
    ffn_bias: bool = True,
    drop: float = 0.0,
    attn_drop: float = 0.0,
    init_values=None,
    drop_path: float = 0.0,
    act_layer: typing.Callable[..., torch.nn.Module] = torch.nn.GELU,
    norm_layer: typing.Callable[..., torch.nn.Module] = torch.nn.LayerNorm,
    attn_class: typing.Callable[..., torch.nn.Module] = Attention,
    ffn_layer: typing.Callable[..., torch.nn.Module] = Mlp,
) -> None:
    """
    Build a pre-norm transformer block: norm -> attention -> layer-scale ->
    drop-path for the first residual branch, then norm -> FFN -> layer-scale ->
    drop-path for the second.

    Args:
        dim: embedding dimension.
        num_heads: number of attention heads.
        mlp_ratio: FFN hidden dim is int(dim * mlp_ratio).
        qkv_bias: enable bias on the qkv projection.
        proj_bias: enable bias on the attention output projection.
        ffn_bias: enable bias in the FFN.
        drop: dropout rate for attention output projection and FFN.
        attn_drop: dropout rate on attention weights.
        init_values: layer-scale init value; falsy disables LayerScale
            (replaced by Identity).
        drop_path: stochastic-depth rate; 0 disables DropPath.
        act_layer: activation constructor for the FFN.
        norm_layer: normalization-layer constructor.
        attn_class: attention-module constructor.
        ffn_layer: FFN-module constructor.
    """
    super().__init__()
    # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
    self.norm1 = norm_layer(dim)
    self.attn = attn_class(
        dim,
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        proj_bias=proj_bias,
        attn_drop=attn_drop,
        proj_drop=drop,
    )
    # Layer scale and drop path degrade to Identity when disabled, so the
    # forward path needs no conditionals.
    self.ls1 = (
        LayerScale(dim, init_values=init_values)
        if init_values
        else torch.nn.Identity()
    )
    self.drop_path1 = (
        DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
    )

    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = ffn_layer(
        in_features=dim,
        hidden_features=mlp_hidden_dim,
        act_layer=act_layer,
        drop=drop,
        bias=ffn_bias,
    )
    self.ls2 = (
        LayerScale(dim, init_values=init_values)
        if init_values
        else torch.nn.Identity()
    )
    self.drop_path2 = (
        DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
    )

    # Used by forward_nested's stochastic-depth path.
    self.sample_drop_ratio = drop_path

forward_nested(x_list)

x_list contains a list of tensors to nest together and run

Source code in src/biobench/webssl.py
def forward_nested(self, x_list: list[torch.Tensor]) -> list[torch.Tensor]:
    """
    x_list contains a list of tensors to nest together and run

    The tensors are concatenated into a single sequence with a
    block-diagonal attention bias so each input only attends to itself;
    after the attention and FFN residual branches the result is split
    back into a list. Only supported with MemEffAttention (asserted).
    """
    assert isinstance(self.attn, MemEffAttention)

    if self.training and self.sample_drop_ratio > 0.0:
        # Stochastic-depth path: each residual branch is applied to a
        # random subset of samples, scaled by the layer-scale gamma when
        # LayerScale is enabled.

        def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            return self.attn(self.norm1(x), attn_bias=attn_bias)

        def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            return self.mlp(self.norm2(x))

        x_list = drop_add_residual_stochastic_depth_list(
            x_list,
            residual_func=attn_residual_func,
            sample_drop_ratio=self.sample_drop_ratio,
            scaling_vector=self.ls1.gamma
            if isinstance(self.ls1, LayerScale)
            else None,
        )
        x_list = drop_add_residual_stochastic_depth_list(
            x_list,
            residual_func=ffn_residual_func,
            sample_drop_ratio=self.sample_drop_ratio,
            # Fix: previously tested isinstance(self.ls1, ...) while using
            # self.ls2.gamma. ls1/ls2 are constructed from the same
            # init_values so behavior is unchanged, but the check now
            # matches the layer whose gamma it guards.
            scaling_vector=self.ls2.gamma
            if isinstance(self.ls2, LayerScale)
            else None,
        )
        return x_list
    else:

        def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

        def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        # No-drop path: fuse all inputs, run once, split back per input.
        attn_bias, x = get_attn_bias_and_cat(x_list)
        x = x + attn_residual_func(x, attn_bias=attn_bias)
        x = x + ffn_residual_func(x)
        return attn_bias.split(x)

PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten_embedding=True)

Bases: Module

2D image to patch embedding: (B,C,H,W) -> (B,N,D)

Parameters:

Name Type Description Default
img_size int | tuple[int, int]

Image size.

224
patch_size int | tuple[int, int]

Patch token size.

16
in_chans int

Number of input image channels.

3
embed_dim int

Number of linear projection output channels.

768
norm_layer Callable | None

Normalization layer.

None
Source code in src/biobench/webssl.py
def __init__(
    self,
    img_size: int | tuple[int, int] = 224,
    patch_size: int | tuple[int, int] = 16,
    in_chans: int = 3,
    embed_dim: int = 768,
    norm_layer: typing.Callable | None = None,
    flatten_embedding: bool = True,
) -> None:
    """Build the 2D image-to-patch embedding: (B, C, H, W) -> (B, N, D).

    Args:
        img_size: image size (int or (H, W) tuple).
        patch_size: patch token size (int or (H, W) tuple).
        in_chans: number of input image channels.
        embed_dim: number of linear projection output channels.
        norm_layer: optional normalization-layer constructor; Identity when None.
        flatten_embedding: whether the forward pass flattens the patch grid.
    """
    super().__init__()

    img_hw = make_2tuple(img_size)
    patch_hw = make_2tuple(patch_size)
    grid_h = img_hw[0] // patch_hw[0]
    grid_w = img_hw[1] // patch_hw[1]

    self.img_size = img_hw
    self.patch_size = patch_hw
    self.patches_resolution = (grid_h, grid_w)
    self.num_patches = grid_h * grid_w
    self.in_chans = in_chans
    self.embed_dim = embed_dim
    self.flatten_embedding = flatten_embedding

    # Non-overlapping patchification: kernel size equals stride.
    self.proj = torch.nn.Conv2d(
        in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw
    )
    self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()

get_attn_bias_and_cat(x_list, branges=None)

this will perform the index select, cat the tensors, and provide the attn_bias from cache

Source code in src/biobench/webssl.py
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Args:
        x_list: list of tensors (batch dim first, sequence dim second) to
            concatenate into one (1, total, D) sequence.
        branges: optional per-tensor index tensors; when given, each entry's
            first dimension replaces the corresponding batch size and the
            tensors are gathered via index_select_cat.

    Returns:
        (attn_bias, cat_tensors): a cached BlockDiagonalMask matching the
        concatenated sequence lengths, and the concatenated tensor.
    """
    batch_sizes = (
        [b.shape[0] for b in branges]
        if branges is not None
        else [x.shape[0] for x in x_list]
    )
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    # Build and memoize the block-diagonal mask per unique shape combination
    # (idiomatic membership test: no `.keys()`).
    if all_shapes not in attn_bias_cache:
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
            1, -1, x_list[0].shape[-1]
        )
    else:
        # Reshape each tensor to batch size 1 and concatenate along the sequence dim.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors

init_weights_vit_timm(module, name='')

ViT weight initialization, original timm impl (for reproducibility)

Source code in src/biobench/webssl.py
def init_weights_vit_timm(module: torch.nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only Linear modules are touched: weights get trunc-normal(std=0.02),
    biases (if present) are zeroed. All other module types are left as-is.
    """
    if not isinstance(module, torch.nn.Linear):
        return
    torch.nn.init.trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        torch.nn.init.zeros_(module.bias)

webssl_dino1b_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-1B DINOv2's "giant2" architecture / ViT-little g Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

Source code in src/biobench/webssl.py
def webssl_dino1b_full2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """
    Web-DINO ViT-1B
    DINOv2's "giant2" architecture / ViT-little g
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino2b_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-2B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino2b_full2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-2B (LLM-inspired scaling)"""
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=2688,
        depth=24,
        num_heads=21,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino2b_heavy2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-2B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino2b_heavy2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-2B (LLM-inspired scaling)"""
    # Same architecture as the full2b variant; the suffix refers to the
    # pretraining data mix, not the model shape.
    arch = dict(
        embed_dim=2688,
        depth=24,
        num_heads=21,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino2b_light2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-2B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino2b_light2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-2B (LLM-inspired scaling)"""
    # Same architecture as the full2b variant; the suffix refers to the
    # pretraining data mix, not the model shape.
    arch = dict(
        embed_dim=2688,
        depth=24,
        num_heads=21,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino300m_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-300M DINOv2's "large" architecture / ViT-L

Source code in src/biobench/webssl.py
def webssl_dino300m_full2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """
    Web-DINO ViT-300M
    DINOv2's "large" architecture / ViT-L
    """
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino3b_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-3B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino3b_full2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-3B (LLM-inspired scaling)"""
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=3072,
        depth=26,
        num_heads=24,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino3b_heavy2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-3B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino3b_heavy2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-3B (LLM-inspired scaling)"""
    # Same architecture as the full2b variant; the suffix refers to the
    # pretraining data mix, not the model shape.
    arch = dict(
        embed_dim=3072,
        depth=26,
        num_heads=24,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino3b_light2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-3B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino3b_light2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-3B (LLM-inspired scaling)"""
    # Same architecture as the full2b variant; the suffix refers to the
    # pretraining data mix, not the model shape.
    arch = dict(
        embed_dim=3072,
        depth=26,
        num_heads=24,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino5b_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-5B (LLM-inspired scaling)

Source code in src/biobench/webssl.py
def webssl_dino5b_full2b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-5B (LLM-inspired scaling)"""
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=3584,
        depth=32,
        num_heads=28,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino7b_full8b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 224x224 resolution

Source code in src/biobench/webssl.py
def webssl_dino7b_full8b_224(
    img_size=224, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 224x224 resolution"""
    # Fixed architecture hyper-parameters for this variant.
    arch = dict(
        embed_dim=4096,
        depth=32,
        num_heads=32,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino7b_full8b_378(img_size=378, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 378x378 resolution

Source code in src/biobench/webssl.py
def webssl_dino7b_full8b_378(
    img_size=378, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 378x378 resolution"""
    # Same 7B architecture as the 224 variant; only the default img_size differs.
    arch = dict(
        embed_dim=4096,
        depth=32,
        num_heads=32,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )

webssl_dino7b_full8b_518(img_size=518, patch_size=14, num_register_tokens=0, **kwargs)

Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 518x518 resolution

Source code in src/biobench/webssl.py
def webssl_dino7b_full8b_518(
    img_size=518, patch_size=14, num_register_tokens=0, **kwargs
):
    """Web-DINO ViT-7B (LLM-inspired scaling) pretrained with 518x518 resolution"""
    # Same 7B architecture as the 224 variant; only the default img_size differs.
    arch = dict(
        embed_dim=4096,
        depth=32,
        num_heads=32,
        mlp_ratio=4,
        ffn_layer="swiglu",
        init_values=1.0e-05,
        block_chunks=4,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    return DinoVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        block_fn=functools.partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **arch,
        **kwargs,
    )