Fixes to mc4 fork
Browse files
- mc4/mc4.py +11 -12
- run_mlm_flax_stream.py +1 -1
mc4/mc4.py
CHANGED
@@ -283,20 +283,20 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
284 |
|
285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
286 |
-
self.
|
287 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
|
|
|
|
|
|
|
|
288 |
if self.sampling_method:
|
289 |
-
|
290 |
-
|
291 |
-
self.rng = default_rng(seed)
|
292 |
else:
|
293 |
self.rng = default_rng()
|
294 |
if self.sampling_method == "random":
|
295 |
self.should_keep_doc = self._should_keep_doc_random
|
296 |
else:
|
297 |
-
self.perplexity_model = kwargs.pop("perplexity_model", None)
|
298 |
-
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
299 |
-
self.boundaries = kwargs.pop("boundaries", None)
|
300 |
# Loading 5-gram model
|
301 |
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
302 |
logger.info("loading model = %s", self.perplexity_model)
|
@@ -305,7 +305,6 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
305 |
self.should_keep_doc = self._should_keep_doc_gaussian
|
306 |
else:
|
307 |
self.should_keep_doc = self._should_keep_doc_step
|
308 |
-
|
309 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
310 |
|
311 |
def get_perplexity(self, doc):
|
@@ -375,14 +374,14 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
375 |
for lang in self.config.languages
|
376 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
377 |
]
|
378 |
-
if "train" in self.
|
379 |
-
train_downloaded_files = self.
|
380 |
if not isinstance(train_downloaded_files, (tuple, list)):
|
381 |
train_downloaded_files = [train_downloaded_files]
|
382 |
else:
|
383 |
train_downloaded_files = dl_manager.download(data_urls["train"])
|
384 |
-
if "validation" in self.
|
385 |
-
validation_downloaded_files = self.
|
386 |
if not isinstance(validation_downloaded_files, (tuple, list)):
|
387 |
validation_downloaded_files = [validation_downloaded_files]
|
388 |
else:
|
|
|
283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
284 |
|
285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
286 |
+
self.data_files = kwargs.pop("data_files", {})
|
287 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
288 |
+
self.perplexity_model = kwargs.pop("perplexity_model", None)
|
289 |
+
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
290 |
+
self.boundaries = kwargs.pop("boundaries", None)
|
291 |
+
self.seed = kwargs.pop("seed", None)
|
292 |
if self.sampling_method:
|
293 |
+
if self.seed is not None:
|
294 |
+
self.rng = default_rng(self.seed)
|
|
|
295 |
else:
|
296 |
self.rng = default_rng()
|
297 |
if self.sampling_method == "random":
|
298 |
self.should_keep_doc = self._should_keep_doc_random
|
299 |
else:
|
|
|
|
|
|
|
300 |
# Loading 5-gram model
|
301 |
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
302 |
logger.info("loading model = %s", self.perplexity_model)
|
|
|
305 |
self.should_keep_doc = self._should_keep_doc_gaussian
|
306 |
else:
|
307 |
self.should_keep_doc = self._should_keep_doc_step
|
|
|
308 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
309 |
|
310 |
def get_perplexity(self, doc):
|
|
|
374 |
for lang in self.config.languages
|
375 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
376 |
]
|
377 |
+
if "train" in self.data_files:
|
378 |
+
train_downloaded_files = self.data_files["train"]
|
379 |
if not isinstance(train_downloaded_files, (tuple, list)):
|
380 |
train_downloaded_files = [train_downloaded_files]
|
381 |
else:
|
382 |
train_downloaded_files = dl_manager.download(data_urls["train"])
|
383 |
+
if "validation" in self.data_files:
|
384 |
+
validation_downloaded_files = self.data_files["validation"]
|
385 |
if not isinstance(validation_downloaded_files, (tuple, list)):
|
386 |
validation_downloaded_files = [validation_downloaded_files]
|
387 |
else:
|
run_mlm_flax_stream.py
CHANGED
@@ -402,7 +402,7 @@ if __name__ == "__main__":
|
|
402 |
boundaries=sampling_args.boundaries,
|
403 |
perplexity_model=sampling_args.perplexity_model,
|
404 |
seed=training_args.seed,
|
405 |
-
|
406 |
)
|
407 |
|
408 |
if model_args.config_name:
|
|
|
402 |
boundaries=sampling_args.boundaries,
|
403 |
perplexity_model=sampling_args.perplexity_model,
|
404 |
seed=training_args.seed,
|
405 |
+
data_files=filepaths,
|
406 |
)
|
407 |
|
408 |
if model_args.config_name:
|