Skip to content

Commit 0ebd4ed

Browse files
committed
address #351
1 parent aa49c27 commit 0ebd4ed

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vit-pytorch"
7-
version = "1.15.7"
7+
version = "1.16.0"
88
description = "Vision Transformer (ViT) - Pytorch"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

vit_pytorch/vaat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def forward(
735735
mlp_dim = 384 * 4
736736
)
737737

738-
vat = VAAT(
738+
vaat = VAAT(
739739
vit,
740740
ast,
741741
dim = 512,
@@ -767,11 +767,11 @@ def forward(
767767

768768
actions = torch.randn(2, 7, 20) # actions for learning
769769

770-
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
770+
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
771771
loss.backward()
772772

773773
# after much training
774774

775-
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
775+
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
776776

777777
assert pred_actions.shape == (2, 7, 20)

vit_pytorch/vit.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
9090

9191
num_patches = (image_height // patch_height) * (image_width // patch_width)
9292
patch_dim = channels * patch_height * patch_width
93+
9394
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
95+
num_cls_tokens = 1 if pool == 'cls' else 0
9496

9597
self.to_patch_embedding = nn.Sequential(
9698
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +101,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
99101
nn.LayerNorm(dim),
100102
)
101103

102-
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
103-
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
104+
self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
105+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
106+
104107
self.dropout = nn.Dropout(emb_dropout)
105108

106109
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -114,7 +117,7 @@ def forward(self, img):
114117
x = self.to_patch_embedding(img)
115118
b, n, _ = x.shape
116119

117-
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
120+
cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
118121
x = torch.cat((cls_tokens, x), dim=1)
119122
x += self.pos_embedding[:, :(n + 1)]
120123
x = self.dropout(x)

0 commit comments

Comments
 (0)