@@ -90,7 +90,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
 
         num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
+
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+        num_cls_tokens = 1 if pool == 'cls' else 0
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +101,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             nn.LayerNorm(dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+
         self.dropout = nn.Dropout(emb_dropout)
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -114,7 +117,7 @@ def forward(self, img):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
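
A minimal sketch (not part of the commit; batch, patch count, and dim values are illustrative) of why the shapes still line up in the pool == 'mean' case: the cls token becomes zero-width, the concat is a no-op, and the positional-embedding slice clamps to the number of patches.

# assumed illustration of the shape mechanics, using the same einops pattern as the diff
import torch
from einops import repeat

b, n, dim = 2, 64, 128           # batch size, number of patches, embedding dim (arbitrary)
num_cls_tokens = 0               # pool == 'mean' -> no cls token

cls_token = torch.randn(1, num_cls_tokens, dim)           # shape (1, 0, dim)
pos_embedding = torch.randn(1, n + num_cls_tokens, dim)   # shape (1, n, dim)

x = torch.randn(b, n, dim)                                # stand-in for the patch embeddings
cls_tokens = repeat(cls_token, '1 ... d -> b ... d', b = b)  # (b, 0, dim)
x = torch.cat((cls_tokens, x), dim = 1)                   # concatenating a zero-width tensor is a no-op
x = x + pos_embedding[:, :(n + 1)]                        # slice clamps to the n entries that exist

print(x.shape)  # torch.Size([2, 64, 128])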