Empty file modified: main.py (file mode 100755 → 100644)
23 changes: 22 additions & 1 deletion src/bert_layers/configuration_bert.py
@@ -58,7 +58,7 @@ def __init__(
loss_kwargs: dict = {},
mlp_dropout_prob: float = 0.0,
mlp_in_bias: bool = False,
mlp_layer: str = "mlp",
mlp_layer: str = "glu_moe",
mlp_out_bias: bool = False,
norm_kwargs: dict = {},
normalization: str = "rmsnorm",
@@ -97,6 +97,13 @@ def __init__(
pad_logits: bool = False,
compile_model: bool = False,
masked_prediction: bool = False,
moe_num_experts: int = 8,
moe_top_k: int = 2,
moe_use_noisy_top_k: bool = True,
moe_capacity_factor: float = 1.25,
moe_compute_aux_loss: bool = True,
moe_load_balance_loss_weight: float = 0.01,
moe_router_z_loss_weight: float = 0.001,
**kwargs,
):
"""
@@ -156,6 +163,13 @@ def __init__(
            pad_logits (bool): Pad logits after calculating the loss.
            compile_model (bool): Compile the subset of the model which can be compiled.
            masked_prediction (bool): Only pass the masked tokens through the final MLM layers.
moe_num_experts (int): Number of experts for Mixture of Experts layers.
moe_top_k (int): Number of top experts to select for each token in MoE.
moe_use_noisy_top_k (bool): Use noisy top-k gating for exploration during training.
moe_capacity_factor (float): Capacity factor for expert assignment in MoE.
moe_compute_aux_loss (bool): Whether to compute and add auxiliary losses for MoE.
moe_load_balance_loss_weight (float): Weight for the load balancing auxiliary loss.
moe_router_z_loss_weight (float): Weight for the router z-loss auxiliary loss.
**kwargs: Additional keyword arguments.
"""
super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
@@ -213,6 +227,13 @@ def __init__(
self.pad_logits = pad_logits
self.compile_model = compile_model
self.masked_prediction = masked_prediction
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.moe_use_noisy_top_k = moe_use_noisy_top_k
self.moe_capacity_factor = moe_capacity_factor
self.moe_compute_aux_loss = moe_compute_aux_loss
self.moe_load_balance_loss_weight = moe_load_balance_loss_weight
self.moe_router_z_loss_weight = moe_router_z_loss_weight

if loss_kwargs.get("return_z_loss", False):
if loss_function != "fa_cross_entropy":
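To see how these knobs fit together, here is a hedged sketch of a config that selects the new MoE MLP; the import path is assumed from this PR's file layout, and every argument shown has a default above:

from src.bert_layers.configuration_bert import FlexBertConfig  # path assumed from this PR's layout

config = FlexBertConfig(
    mlp_layer="glu_moe",           # routes MLP construction to FlexBertGLUMoE (see mlp.py below)
    moe_num_experts=8,
    moe_top_k=2,
    moe_use_noisy_top_k=True,      # noise is only added while model.training is True
    moe_capacity_factor=1.25,      # capacity = num_tokens * top_k * 1.25 / num_experts
    moe_compute_aux_loss=True,
    moe_load_balance_loss_weight=0.01,
    moe_router_z_loss_weight=0.001,
)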
75 changes: 75 additions & 0 deletions src/bert_layers/loss.py
@@ -2,7 +2,9 @@
# License: Apache-2.0

import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configuration_bert import FlexBertConfig

try:
@@ -20,6 +22,79 @@
LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss


class MoELoadBalancingLoss(nn.Module):
"""Computes Switch Transformer auxiliary loss for load balancing.

Reference: https://arxiv.org/abs/2101.03961 (equations 4-6, page 7)

This loss encourages balanced token allocation across experts to avoid
scenarios where some experts are overloaded while others are underutilized.
"""

def __init__(self, num_experts: int, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k

def forward(
self,
router_logits: torch.Tensor,
expert_indices: torch.Tensor,
) -> torch.Tensor:
"""
Args:
router_logits: Router logits [batch_size, seq_len, num_experts]
expert_indices: Top-k expert indices [batch_size, seq_len, top_k]

Returns:
load_balance_loss: Scalar loss value
"""
# Compute expert probabilities
expert_probs = F.softmax(router_logits, dim=-1) # [B, C, n_exp]

# Equation (5): compute ratio of tokens allocated to each expert
with torch.no_grad():
one_hot_indices = F.one_hot(expert_indices, num_classes=self.num_experts) # [B, C, K, n_exp]
one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp]
            tokens_per_expert = torch.mean(one_hot_indices, dim=(0, 1))  # [n_exp]; already float after the sum above

# Equation (6): compute ratio of router probability allocated to each expert
prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) # [n_exp]

# Equation (4): scaled dot product between prob / token allocation vectors
load_balance_loss = self.num_experts * torch.sum(prob_per_expert * tokens_per_expert)

return load_balance_loss
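

def _load_balance_sanity_check():
    """Hypothetical usage sketch, not part of this diff: with uniform router
    probabilities and perfectly balanced round-robin assignments, tokens_per_expert
    is top_k / n_exp and prob_per_expert is 1 / n_exp, so the loss reduces to
    n_exp * n_exp * (1 / n_exp) * (top_k / n_exp) = top_k exactly."""
    B, C, n_exp, top_k = 2, 8, 4, 2
    loss_fn = MoELoadBalancingLoss(num_experts=n_exp, top_k=top_k)
    uniform_logits = torch.zeros(B, C, n_exp)  # softmax gives uniform probabilities
    balanced_indices = torch.arange(B * C * top_k).view(B, C, top_k) % n_exp  # round-robin
    assert torch.allclose(loss_fn(uniform_logits, balanced_indices), torch.tensor(float(top_k)))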


class MoERouterZLoss(nn.Module):
"""Computes router z-loss for MoE models.

Reference: https://arxiv.org/abs/2202.08906 (equation 5, page 7)

This loss constrains the size of router logits to avoid numerical instability
during training. Large logits can lead to round-off errors in the softmax computation,
even in float32 precision.
"""

def forward(self, router_logits: torch.Tensor) -> torch.Tensor:
"""
Args:
router_logits: Router logits [batch_size, seq_len, num_experts]

Returns:
router_z_loss: Scalar loss value
"""
# Numerically stable computation: logsumexp is equivalent to log(sum(exp(x)))
# This avoids overflow issues from directly exponentiating large logits
router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0 # [B, C]

# Average over all tokens
router_z_loss = torch.mean(router_z_loss)

return router_z_loss
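

def _router_z_loss_sanity_check():
    """Hypothetical usage sketch, not part of this diff: for constant logits c,
    logsumexp over n_exp experts is c + log(n_exp), so the loss is
    (c + log(n_exp)) ** 2 and grows quadratically as the logits scale up."""
    z_fn = MoERouterZLoss()
    n_exp, c = 4, 3.0
    logits = torch.full((2, 8, n_exp), c)
    expected = (c + torch.log(torch.tensor(float(n_exp)))) ** 2
    assert torch.allclose(z_fn(logits), expected)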


def get_loss_fn(config: FlexBertConfig) -> nn.Module:
try:
loss_class = LOSS2CLS[config.loss_function]
191 changes: 191 additions & 0 deletions src/bert_layers/mlp.py
@@ -15,11 +15,13 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_bert import FlexBertConfig
from .activation import get_act_fn
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights
from .loss import MoELoadBalancingLoss, MoERouterZLoss


class BertResidualGLU(nn.Module):
@@ -190,10 +192,197 @@ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
return self.Wo(self.drop(self.act(input) * gate))


class Router(nn.Module):
"""Top-K router for selecting experts."""

def __init__(
self,
d: int,
n_exp: int,
top_k: int = 2,
use_noisy_top_k: bool = True,
capacity_factor: float = 1.25,
):
super().__init__()
self.d = d
self.n_exp = n_exp
self.top_k = top_k
self.use_noisy_top_k = use_noisy_top_k
self.capacity_factor = capacity_factor

# Router weights to compute logits for each expert
self.gate = nn.Linear(d, n_exp, bias=False)

# Noise parameters for noisy top-k gating
if use_noisy_top_k:
self.w_noise = nn.Linear(d, n_exp, bias=False)

def forward(self, x: torch.Tensor):
"""
Args:
x: [batch_size, seq_len, d]
Returns:
exp_weight: Expert weights [batch_size * seq_len, top_k]
exp_mask: Expert mask [batch_size * seq_len, n_exp, exp_capacity]
exp_batches: Token assignments [n_exp, exp_capacity, d]
"""
B, C, d = x.size()
num_tokens = B * C
x_flat = x.view(num_tokens, d)

# Compute router logits
logits = self.gate(x_flat) # [num_tokens, n_exp]

# Add noise for exploration (optional)
if self.use_noisy_top_k and self.training:
noise_stddev = F.softplus(self.w_noise(x_flat))
noise = torch.randn_like(logits) * noise_stddev
logits = logits + noise

# Select top-k experts
top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
top_k_gates = F.softmax(top_k_logits, dim=-1) # [num_tokens, top_k]

# Compute expert capacity
exp_capacity = int((num_tokens * self.top_k * self.capacity_factor) / self.n_exp)

# Create expert assignment mask and batches
        exp_mask = torch.zeros(num_tokens, self.n_exp, exp_capacity, device=x.device, dtype=x.dtype)
        exp_batches = torch.zeros(self.n_exp, exp_capacity, d, device=x.device, dtype=x.dtype)

# Count tokens assigned to each expert
expert_counts = torch.zeros(self.n_exp, dtype=torch.long, device=x.device)

        # Assign tokens to experts (naive per-token Python loop: clear, but not vectorized)
for token_idx in range(num_tokens):
for k_idx in range(self.top_k):
expert_idx = top_k_indices[token_idx, k_idx]
if expert_counts[expert_idx] < exp_capacity:
pos = expert_counts[expert_idx]
exp_mask[token_idx, expert_idx, pos] = top_k_gates[token_idx, k_idx]
exp_batches[expert_idx, pos] = x_flat[token_idx]
expert_counts[expert_idx] += 1

return top_k_gates, exp_mask, exp_batches
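

def _router_shape_check():
    """Hypothetical smoke test, not part of this diff: output shapes follow the
    docstring above, with capacity = int(num_tokens * top_k * capacity_factor / n_exp)."""
    router = Router(d=64, n_exp=4, top_k=2, use_noisy_top_k=False)
    x = torch.randn(2, 8, 64)  # [B, C, d] -> 16 tokens
    gates, mask, batches = router(x)
    cap = int(16 * 2 * 1.25 / 4)  # = 10
    assert gates.shape == (16, 2)
    assert mask.shape == (16, 4, cap)
    assert batches.shape == (4, cap, 64)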


class FlexBertGLUMoE(FlexBertMLPBase):
"""Mixture of Experts with GLU activation for FlexBERT."""

def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
super().__init__(config=config, layer_id=layer_id)

self.n_exp = getattr(config, 'moe_num_experts', 4)
self.top_k = getattr(config, 'moe_top_k', 2)
self.use_noisy_top_k = getattr(config, 'moe_use_noisy_top_k', True)
self.capacity_factor = getattr(config, 'moe_capacity_factor', 1.25)
self.compute_aux_loss = getattr(config, 'moe_compute_aux_loss', True)
self.load_balance_loss_weight = getattr(config, 'moe_load_balance_loss_weight', 0.01)
self.router_z_loss_weight = getattr(config, 'moe_router_z_loss_weight', 0.001)

self.router = Router(
d=config.hidden_size,
n_exp=self.n_exp,
top_k=self.top_k,
use_noisy_top_k=self.use_noisy_top_k,
capacity_factor=self.capacity_factor,
)

# GLU experts (each projects to 2x intermediate size)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias),
nn.Identity(), # Placeholder for chunking + activation
nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity(),
nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias),
)
for _ in range(self.n_exp)
])
self.act = get_act_fn(config.hidden_act)

# Initialize auxiliary loss modules
if self.compute_aux_loss:
self.load_balance_loss = MoELoadBalancingLoss(num_experts=self.n_exp, top_k=self.top_k)
self.router_z_loss = MoERouterZLoss()

def _init_weights(self, reset_params: bool = False):
init_weights(
self.config,
self.router.gate,
layer_dim=self.config.hidden_size,
layer_id=self.layer_id,
type_of_module=ModuleType.in_module,
)

for expert in self.experts:
for i, module in enumerate(expert):
if isinstance(module, nn.Linear):
init_weights(
self.config,
module,
layer_dim=self.config.hidden_size if i == 0 else self.config.intermediate_size,
layer_id=self.layer_id,
type_of_module=ModuleType.in_module if i == 0 else ModuleType.out_module,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
original_shape = hidden_states.shape
if hidden_states.dim() == 2:
hidden_states = hidden_states.unsqueeze(0)

B, C, d = hidden_states.size()
num_tokens = B * C
x_flat = hidden_states.view(num_tokens, d)

# Compute router logits for auxiliary loss calculation
router_logits = self.router.gate(x_flat) # [num_tokens, n_exp]

exp_weight, exp_mask, exp_batches = self.router(hidden_states)

        # Recover the top-k indices for the load-balancing loss. Note: these come from
        # the noiseless logits, so when noisy gating is active during training they can
        # differ from the experts the router actually selected.
        _, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)  # [num_tokens, top_k]

# Apply GLU experts
exp_out = torch.zeros_like(exp_batches)
for i, expert in enumerate(self.experts):
x = expert[0](exp_batches[i]) # Linear projection
input, gate = x.chunk(2, dim=-1) # Split for GLU
x = self.act(input) * gate # GLU activation
x = expert[2](x) # Dropout
exp_out[i] = expert[3](x) # Output projection

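        # Weighted combine: exp_mask holds each token's gate value at its assigned
        # (expert, capacity-slot) position, so flattening both sides turns the sum
        # of gate * expert_output over a token's slots into a single matmul:
        # [num_tokens, n_exp * cap] @ [n_exp * cap, d] -> [num_tokens, d]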
exp_weight_flat = exp_mask.view(num_tokens, -1)
exp_out_flat = exp_out.view(-1, d)
output = torch.matmul(exp_weight_flat, exp_out_flat)

# Compute auxiliary losses
self.aux_loss = None
if self.compute_aux_loss:
# Reshape for loss computation
router_logits_reshaped = router_logits.view(B, C, -1)
top_k_indices_reshaped = top_k_indices.view(B, C, -1)

# Compute load balancing loss
lb_loss = self.load_balance_loss(router_logits_reshaped, top_k_indices_reshaped)

# Compute router z-loss
z_loss = self.router_z_loss(router_logits_reshaped)

# Combine auxiliary losses with weights
self.aux_loss = self.load_balance_loss_weight * lb_loss + self.router_z_loss_weight * z_loss

return output.view(*original_shape)
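

def collect_moe_aux_losses(model: nn.Module) -> torch.Tensor:
    """Hypothetical helper, an assumption rather than part of this diff: aux_loss
    is stashed on each FlexBertGLUMoE module instead of being returned, so a
    training loop would gather it after the forward pass and add it to the task
    loss, e.g. loss = task_loss + collect_moe_aux_losses(model)."""
    losses = [
        m.aux_loss
        for m in model.modules()
        if isinstance(m, FlexBertGLUMoE) and m.aux_loss is not None
    ]
    return torch.stack(losses).sum() if losses else torch.zeros(())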


# Update the MLP registry
MLP2CLS = {
"mlp": FlexBertMLP,
"glu": FlexBertGLU,
"parallel_glu": FlexBertParallelGLU,
"glu_moe": FlexBertGLUMoE,
}


@@ -212,3 +401,5 @@ def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> Fle
)
else:
raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}")

