# openmoe-8b-native-pt / configuration_hf_openmoe.py
# Source: Hugging Face Hub upload by "Phando", commit 7032fd5 ("init"), 2.25 kB.
from typing import Optional

from transformers import LlamaConfig
# Public API of this module: only the OpenMoE config class is exported
# by ``from configuration_hf_openmoe import *``.
__all__ = [
    "HFOpenMoeConfig",
]
class HFOpenMoeConfig(LlamaConfig):
    """Configuration for OpenMoE models: ``LlamaConfig`` plus mixture-of-experts options.

    Every MoE-specific keyword argument below is stored verbatim as an
    attribute of the same name. All remaining keyword arguments are forwarded
    unchanged to ``LlamaConfig.__init__``. No validation is performed here.

    The exact meaning of each option is defined by the modeling code that
    reads this config, which is not part of this file. The descriptions below
    are inferred from the parameter names and should be confirmed against
    that code.

    Args:
        num_experts: Presumably the number of experts in each MoE layer.
        moe_layer_interval: Presumably how many layers apart MoE layers are
            placed.
        router_topk: Presumably the number of experts each token is routed to.
        router_capacity_factor_train: Expert-capacity factor, presumably used
            in training mode.
        router_capacity_factor_eval: Expert-capacity factor, presumably used
            in evaluation mode.
        router_min_capacity: Presumably a lower bound on per-expert capacity.
        router_noisy_policy: Optional name of a router noise policy. ``None``
            is the default and presumably means "no noise".
        router_drop_tks: Presumably whether tokens over expert capacity are
            dropped.
        router_aux_loss_factor: Weight of the router auxiliary loss.
        router_z_loss_factor: Weight of the router z-loss.
        mlp_gated: Presumably whether the MLPs use a gated activation.
        label_smoothing: Presumably the label-smoothing amount for the LM loss.
        z_loss_factor: Weight of a (presumably output-logit) z-loss.
        enable_load_balance: Toggle for expert load balancing (off by default).
        load_balance_tolerance: Load-balancer tuning knob (presumably only
            relevant when ``enable_load_balance`` is True).
        load_balance_beam_width: Load-balancer tuning knob (as above).
        load_balance_group_swap_factor: Load-balancer tuning knob (as above).
        enable_kernel: Toggle for optimized kernels (off by default).
        enable_comm_overlap: Toggle for communication overlap (off by default).
        enable_hierarchical_alltoall: Toggle for hierarchical all-to-all
            (off by default).
        **kwargs: Forwarded unchanged to ``LlamaConfig.__init__`` (standard
            Llama fields such as ``hidden_size`` or ``num_hidden_layers``).
    """

    # Overrides the inherited LlamaConfig model_type so that configs of this
    # class identify themselves as "openmoe" rather than "llama".
    model_type = "openmoe"

    def __init__(
        self,
        num_experts: int = 32,
        moe_layer_interval: int = 6,
        router_topk: int = 2,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        router_aux_loss_factor: float = 0.01,
        router_z_loss_factor: float = 0.0001,
        mlp_gated: bool = True,
        label_smoothing: float = 0.001,
        z_loss_factor: float = 0.01,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_alltoall: bool = False,
        **kwargs,
    ):
        # MoE / router settings.
        self.num_experts = num_experts
        self.moe_layer_interval = moe_layer_interval
        self.router_topk = router_topk
        self.router_capacity_factor_train = router_capacity_factor_train
        self.router_capacity_factor_eval = router_capacity_factor_eval
        self.router_min_capacity = router_min_capacity
        self.router_noisy_policy = router_noisy_policy
        self.router_drop_tks = router_drop_tks
        self.router_aux_loss_factor = router_aux_loss_factor
        self.router_z_loss_factor = router_z_loss_factor
        self.mlp_gated = mlp_gated

        # Loss settings.
        self.label_smoothing = label_smoothing
        self.z_loss_factor = z_loss_factor

        # Load-balancing settings.
        self.enable_load_balance = enable_load_balance
        self.load_balance_tolerance = load_balance_tolerance
        self.load_balance_beam_width = load_balance_beam_width
        self.load_balance_group_swap_factor = load_balance_group_swap_factor

        # Kernel / communication toggles.
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        self.enable_hierarchical_alltoall = enable_hierarchical_alltoall

        # Everything not consumed above (the standard Llama fields) is handled
        # by LlamaConfig's own initializer.
        super().__init__(**kwargs)