feat: Added initial MoE support
mali-git committed Apr 8, 2024
1 parent 2f00705 commit 3585aed
Showing 2 changed files with 173 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -2,7 +2,7 @@
 from copy import deepcopy
 from enum import Enum
 from functools import partial
-from typing import Annotated, Dict, List, Tuple
+from typing import Annotated, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -13,6 +13,7 @@
 from modalities.config.config import PydanticPytorchModuleType
 from modalities.config.utils import convert_base_model_config_to_dict
 from modalities.models.model import NNModel
+from modalities.nn.moe import MoEFFN, MoEFFNConfig
 from modalities.util import parse_enum_by_name
 
 # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -309,6 +310,7 @@ def __init__(
         ffn_hidden: int,
         attention_norm: nn.Module,
         ffn_norm: nn.Module,
+        moe_config: Optional[MoEFFNConfig] = None,
     ):
         super().__init__()
         self.attention_norm = attention_norm
@@ -322,13 +324,16 @@
             dropout=dropout,
             block_size=block_size,
         )
-        if activation_type == ActivationType.GELU:
-            self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
-        elif activation_type == ActivationType.FUSED_SWIGLU:
-            hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
-            self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
+        if moe_config:
+            self.mlp = MoEFFN(hidden_size=ffn_hidden, config=moe_config)
         else:
-            raise NotImplementedError("unimplemented activation")
+            if activation_type == ActivationType.GELU:
+                self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
+            elif activation_type == ActivationType.FUSED_SWIGLU:
+                hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
+                self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
+            else:
+                raise NotImplementedError("unimplemented activation")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.attention_norm(x)
@@ -359,6 +364,7 @@ def __init__(
         attention_norm: nn.Module,
         ffn_norm: nn.Module,
         lm_head_norm: nn.Module,
+        moe_config: Optional[MoEFFNConfig] = None,
     ):
         super().__init__()
         self.sample_key = sample_key
@@ -404,6 +410,7 @@ def __init__(
                     ffn_hidden=ffn_hidden,
                     attention_norm=deepcopy(attention_norm),
                     ffn_norm=deepcopy(ffn_norm),
+                    moe_config=moe_config,
                 )
                 for _ in range(n_layer)
             ]
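
A side note on the retained fused-SwiGLU branch above: its hidden width is rounded up to the next multiple of 256. A quick check of that formula, using an assumed n_embd of 512 (the value is illustrative, not taken from this commit):

n_embd = 512  # assumed example value, not from the commit
hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
print(hidden_dim)  # int(2 * 4 * 512 / 3) = 1365, rounded up to 1536
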
159 changes: 159 additions & 0 deletions src/modalities/nn/moe.py
@@ -0,0 +1,159 @@
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
from pydantic import BaseModel


# MoE implementation inspired by DBRX: https://github.com/databricks/dbrx/blob/main/model/modeling_dbrx.py
class MoEFFNConfig(BaseModel):
    moe_num_experts: int
    moe_top_k: int
    moe_normalize_expert_weights: float
    uniform_expert_assignment: bool
    ffn_hidden_size: int
    act_fn: Callable[[], nn.Module] = nn.SiLU
    moe_jitter_eps: float
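
A minimal sketch of instantiating this config; every value below is an illustrative assumption, not a default shipped with this commit:

example_config = MoEFFNConfig(
    moe_num_experts=8,
    moe_top_k=2,
    moe_normalize_expert_weights=1.0,  # p-norm used to renormalize the top-k weights
    uniform_expert_assignment=False,
    ffn_hidden_size=3072,
    moe_jitter_eps=0.01,  # multiplicative input jitter applied only during training
)  # act_fn falls back to the nn.SiLU default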


class MoERouter(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        moe_num_experts: int,
        moe_top_k: int,
        moe_normalize_expert_weights: Optional[float],
        uniform_expert_assignment: bool,
        moe_jitter_eps: float,

Review comment from @thomaschhh (Member), Apr 17, 2024, on moe_jitter_eps:
Shouldn't we set restrictions for this value like Annotated[float, Field(strict=True, gt=0)]? That way we could also remove lines 65-66.
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment
        self.moe_jitter_eps = moe_jitter_eps

        self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
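
A minimal sketch of the restriction proposed in the review comment above, assuming pydantic v2's Annotated/Field validators; the class name is hypothetical and the constraint is the reviewer's suggestion, not part of this commit:

from typing import Annotated

from pydantic import BaseModel, Field

class MoEFFNConfigConstrained(BaseModel):  # hypothetical variant for illustration
    # A strictly positive, strictly-typed jitter would let the `is None` guard in __jitter below be dropped.
    moe_jitter_eps: Annotated[float, Field(strict=True, gt=0)]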

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
        if self.training and self.moe_jitter_eps is not None:
            x = x * self.__jitter(x)

        weights = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1, dtype=torch.float32)
        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)

        if self.moe_normalize_expert_weights:
            top_weights = top_weights / torch.norm(
                top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
            )

        if self.uniform_expert_assignment:
            with torch.no_grad():
                uniform_tensor = (
                    torch.arange(0, top_experts.numel(), device=top_experts.device, dtype=top_experts.dtype)
                    % self.moe_num_experts
                )
                top_experts = uniform_tensor.reshape(top_experts.shape)
                # Note, weights and top_weights are not changed

        weights = weights.to(x.dtype)
        top_weights = top_weights.to(x.dtype)
        return weights, top_weights, top_experts  # type: ignore

    def __jitter(self, x: torch.Tensor) -> torch.Tensor:
        if self.moe_jitter_eps is None:
            raise RuntimeError("The router does not have moe_jitter_eps set.")
        low = 1.0 - self.moe_jitter_eps
        high = 1.0 + self.moe_jitter_eps
        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)
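
To make the router's contract concrete, a small shape check; the sizes (batch 2, sequence 5, hidden 16, 4 experts, top-2 routing) are arbitrary examples, not values used by the repository:

router = MoERouter(
    hidden_size=16,
    moe_num_experts=4,
    moe_top_k=2,
    moe_normalize_expert_weights=1.0,
    uniform_expert_assignment=False,
    moe_jitter_eps=0.01,
)
x = torch.randn(2, 5, 16)
weights, top_weights, top_experts = router(x)
print(weights.shape)      # torch.Size([10, 4])  softmax over all experts, one row per token
print(top_weights.shape)  # torch.Size([10, 2])  weights of the selected experts
print(top_experts.shape)  # torch.Size([10, 2])  indices of the selected experts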


class MoEExpertGLU(nn.Module):
    def __init__(
        self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, act_fn: Callable[[], nn.Module] = nn.GELU
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.activation_fn = act_fn()

        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]

        x1 = x.matmul(expert_w1.t())
        x2 = x.matmul(expert_v1.t())
        x1 = self.activation_fn(x1)
        x1 = x1 * x2
        x1 = x1.matmul(expert_w2)
        return x1
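
For orientation, each expert applies a gated linear unit over per-expert slices of the shared parameter tensors; a shape walk-through with an arbitrary token count t:

# x:                               [t, hidden_size]
# expert_w1, expert_v1, expert_w2: [ffn_hidden_size, hidden_size] each
# act(x @ expert_w1.T) * (x @ expert_v1.T)  -> [t, ffn_hidden_size]
# (...) @ expert_w2                         -> [t, hidden_size]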


class MoEExperts(nn.Module):
    def __init__(
        self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, act_fn: Callable[[], nn.Module] = nn.GELU
    ):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        self.mlp = MoEExpertGLU(
            hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, act_fn=act_fn
        )

    def forward(
        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
    ) -> torch.Tensor:
        bsz, q_len, hidden_size = x.shape
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)

        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(0, self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue

            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()

            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            expert_out = self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]

            out.index_add_(0, token_idx, expert_out)

        out = out.reshape(bsz, q_len, hidden_size)
        return out
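
The dispatch above operates on flattened tokens. As a concrete sizing example (all numbers assumed for illustration): with bsz=2, q_len=3, moe_top_k=2 and moe_num_experts=4, top_experts has shape [6, 2], one_hot produces [6, 2, 4], and the permute yields an expert_mask of shape [4, 2, 6]. torch.where(expert_mask[expert_idx]) then lists the (top-k slot, token) pairs routed to that expert, whose GLU outputs are scaled by the corresponding top_weights and accumulated back into out via index_add_.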


class MoEFFN(nn.Module):
    def __init__(self, hidden_size: int, config: MoEFFNConfig):
        super().__init__()
        self.config = config

        self.router = MoERouter(
            hidden_size,
            moe_num_experts=self.config.moe_num_experts,
            moe_top_k=self.config.moe_top_k,
            moe_normalize_expert_weights=self.config.moe_normalize_expert_weights,
            uniform_expert_assignment=self.config.uniform_expert_assignment,
            moe_jitter_eps=self.config.moe_jitter_eps,
        )

        self.experts = MoEExperts(
            hidden_size=hidden_size,
            ffn_hidden_size=self.config.ffn_hidden_size,
            moe_num_experts=self.config.moe_num_experts,
            act_fn=self.config.act_fn,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights, top_weights, top_experts = self.router(x)
        out = self.experts(x, weights, top_weights, top_experts)
        return out
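
A minimal end-to-end sketch of using the module; the sizes and config values are illustrative assumptions, and the explicit weight init is only needed because w1/v1/w2 are allocated with torch.empty and never initialized in this file:

config = MoEFFNConfig(
    moe_num_experts=4,
    moe_top_k=2,
    moe_normalize_expert_weights=1.0,
    uniform_expert_assignment=False,
    ffn_hidden_size=64,
    moe_jitter_eps=0.01,
)
moe = MoEFFN(hidden_size=32, config=config)
for p in moe.experts.mlp.parameters():  # w1, v1, w2 start out uninitialized
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(2, 5, 32)  # [batch, sequence, hidden]
y = moe(x)
print(y.shape)  # torch.Size([2, 5, 32]): weighted sum of each token's top-2 expert outputs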
