"""Base tokenizer class.

Copyright PolyAI Limited.
"""
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm

from utils import measure_duration


class BaseTokenizer:
    @measure_duration
    def encode_files_with_model_seq(
            self, folder_path: str, destination_folder: str):
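        """Encode every file in folder_path sequentially into destination_folder."""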
        # Ensure the destination folder exists
        os.makedirs(destination_folder, exist_ok=True)

        # Go through each file in the folder
        filenames = os.listdir(folder_path)
        # Encoding a file has no side effects, so each file is processed independently
        for filename in tqdm(filenames):
            self.encode_file(
                folder_path=folder_path,
                destination_folder=destination_folder,
                filename=filename,
            )

    def get_chunk(self, folder_path: str, start_percent: int = 0,
                  end_percent: int = 100):
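        """Return the filenames between start_percent and end_percent of folder_path."""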
        filenames = os.listdir(folder_path)
        total_files = len(filenames)

        start_idx = int(total_files * (start_percent / 100))
        end_idx = int(total_files * (end_percent / 100))

        return filenames[start_idx:end_idx]

    @measure_duration
    def encode_files_with_model_concurrent(
        self, folder_path: str, destination_folder: str, start_percent: int,
        end_percent: int,
    ):
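        """Encode the selected chunk of files concurrently using a thread pool."""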
        # Ensure the destination folder exists
        os.makedirs(destination_folder, exist_ok=True)

        # Collect the filenames in the requested chunk of the folder
        filenames = self.get_chunk(folder_path, start_percent, end_percent)

        # Encoding a file has no side effects on the others, so the files can
        # safely be processed concurrently.
        with ThreadPoolExecutor(max_workers=40) as executor:
            futures = [
                executor.submit(
                    self.encode_file,
                    folder_path=folder_path,
                    destination_folder=destination_folder,
                    filename=filename,
                )
                for filename in filenames
            ]
            # Wait for all tasks to complete and surface any exceptions they raised
            for future in as_completed(futures):
                future.result()
            # The `with` block shuts the thread pool down on exit, so no
            # explicit executor.shutdown() call is needed.

    def encode_file(
            self, folder_path: str, destination_folder: str, filename: str):
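        """Encode a single file. Subclasses must implement this method."""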
        raise NotImplementedError
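

# --- Illustrative sketch (an assumption, not part of the original module) ---
# A minimal subclass showing how `encode_file` could be implemented. The
# whitespace "tokenizer" and the folder names below are placeholders chosen
# only to demonstrate the BaseTokenizer API.
class _ExampleWhitespaceTokenizer(BaseTokenizer):
    def encode_file(
            self, folder_path: str, destination_folder: str, filename: str):
        # Read the raw text, split it on whitespace, and write one token per line.
        with open(os.path.join(folder_path, filename), encoding="utf-8") as f:
            tokens = f.read().split()
        out_path = os.path.join(destination_folder, filename + ".tok")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("\n".join(tokens))


if __name__ == "__main__":
    # Hypothetical paths, shown only to illustrate the call signature.
    _ExampleWhitespaceTokenizer().encode_files_with_model_concurrent(
        folder_path="data/raw",
        destination_folder="data/encoded",
        start_percent=0,
        end_percent=100,
    )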