from typing import Optional

from transformers import LlamaConfig

__all__ = ["HFOpenMoeConfig"]


class HFOpenMoeConfig(LlamaConfig):
    """LlamaConfig extended with OpenMoE mixture-of-experts settings."""

    model_type = "openmoe"

    def __init__(
        self,
        num_experts: int = 32,
        moe_layer_interval: int = 6,
        router_topk: int = 2,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        router_aux_loss_factor: float = 0.01,
        router_z_loss_factor: float = 0.0001,
        mlp_gated: bool = True,
        label_smoothing: float = 0.001,
        z_loss_factor: float = 0.01,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_alltoall: bool = False,
        **kwargs,
    ):
        # MoE layout: number of experts and how often an MoE layer replaces a dense MLP.
        self.num_experts = num_experts
        self.moe_layer_interval = moe_layer_interval
        # Router behaviour: top-k expert selection, capacity factors, and token dropping.
        self.router_topk = router_topk
        self.router_capacity_factor_train = router_capacity_factor_train
        self.router_capacity_factor_eval = router_capacity_factor_eval
        self.router_min_capacity = router_min_capacity
        self.router_noisy_policy = router_noisy_policy
        self.router_drop_tks = router_drop_tks
        # Auxiliary loss weights that keep the router balanced and its logits small.
        self.router_aux_loss_factor = router_aux_loss_factor
        self.router_z_loss_factor = router_z_loss_factor
        self.mlp_gated = mlp_gated
        self.label_smoothing = label_smoothing
        self.z_loss_factor = z_loss_factor
        # Optional expert load balancing across devices.
        self.enable_load_balance = enable_load_balance
        self.load_balance_tolerance = load_balance_tolerance
        self.load_balance_beam_width = load_balance_beam_width
        self.load_balance_group_swap_factor = load_balance_group_swap_factor
        # Performance switches: fused kernels, communication overlap, hierarchical all-to-all.
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        self.enable_hierarchical_alltoall = enable_hierarchical_alltoall
        super().__init__(**kwargs)
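

# A minimal usage sketch (illustrative, not part of the original module). The specific
# values below are assumptions chosen for the example; standard LlamaConfig fields such
# as hidden_size and num_hidden_layers are forwarded through **kwargs to the parent class.
if __name__ == "__main__":
    config = HFOpenMoeConfig(
        num_experts=16,
        moe_layer_interval=4,
        router_topk=2,
        hidden_size=2048,
        num_hidden_layers=24,
    )
    # The config behaves like any Hugging Face config: attributes are accessible
    # directly and the whole object can be serialized to JSON.
    print(config.model_type, config.num_experts, config.router_topk)
    print(config.to_json_string())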