You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					106 lines
				
				3.9 KiB
			
		
		
			
		
	
	
					106 lines
				
				3.9 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								"""
							 | 
						||
| 
								 | 
							
								Thin wrappers around `concurrent.futures`.
							 | 
						||
| 
								 | 
							
								"""
							 | 
						||
| 
								 | 
							
								from contextlib import contextmanager
							 | 
						||
| 
								 | 
							
								from operator import length_hint
							 | 
						||
| 
								 | 
							
								from os import cpu_count
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from ..auto import tqdm as tqdm_auto
							 | 
						||
| 
								 | 
							
								from ..std import TqdmWarning
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								__author__ = {"github.com/": ["casperdcl"]}
							 | 
						||
| 
								 | 
							
								__all__ = ['thread_map', 'process_map']
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@contextmanager
								def ensure_lock(tqdm_class, lock_name=""):
								    """
								    Get (create if necessary) and then restore `tqdm_class`'s lock.

								    Parameters
								    ----------
								    tqdm_class  : class
								        Must provide `get_lock()` and `set_lock(lock)`; its `_lock`
								        attribute is inspected (and deleted on exit if it was unset).
								    lock_name  : str, optional
								        Attribute of the lock object to install instead of the lock
								        itself; the lock is used as-is when it has no such attribute
								        [default: ""].

								    Yields
								    ------
								    The lock installed on `tqdm_class` while the context is active.
								    """
								    old_lock = getattr(tqdm_class, '_lock', None)  # don't create a new lock
								    lock = old_lock or tqdm_class.get_lock()  # maybe create a new lock
								    lock = getattr(lock, lock_name, lock)  # maybe subtype
								    tqdm_class.set_lock(lock)
								    try:
								        yield lock
								    finally:
								        # restore the previous state even when the `with` body raises,
								        # otherwise the temporary lock would stay installed on the class
								        if old_lock is None:
								            del tqdm_class._lock
								        else:
								            tqdm_class.set_lock(old_lock)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
								    """
								    Implementation of `thread_map` and `process_map`.

								    Parameters
								    ----------
								    PoolExecutor  : executor class
								        Instantiated with `max_workers`, `initializer` and `initargs`,
								        and used as a context manager.
								    fn  : callable
								        Passed to the executor's `map` together with `iterables`.
								    tqdm_class  : [default: tqdm.auto.tqdm].
								    max_workers  : [default: min(32, cpu_count() + 4)].
								    chunksize  : [default: 1].
								    lock_name  : [default: "":str].
								    total  : [default: length_hint(iterables[0])].

								    Any remaining keyword arguments are passed to `tqdm_class`.

								    Returns
								    -------
								    list of the results produced by the executor's `map`.
								    """
								    kwargs = tqdm_kwargs.copy()
								    if "total" not in kwargs:
								        kwargs["total"] = length_hint(iterables[0])
								    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
								    if "max_workers" in kwargs:
								        # an explicit value (including None) is forwarded untouched
								        max_workers = kwargs.pop("max_workers")
								    else:
								        # `os.cpu_count()` may return None when the count is undeterminable;
								        # compute the default lazily so it can neither fail nor be wasted
								        max_workers = min(32, (cpu_count() or 1) + 4)
								    chunksize = kwargs.pop("chunksize", 1)
								    lock_name = kwargs.pop("lock_name", "")
								    with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
								        # share lock in case workers are already using `tqdm`
								        with PoolExecutor(max_workers=max_workers, initializer=tqdm_class.set_lock,
								                          initargs=(lk,)) as ex:
								            return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def thread_map(fn, *iterables, **tqdm_kwargs):
								    """
								    Equivalent of `list(map(fn, *iterables))`
								    driven by `concurrent.futures.ThreadPoolExecutor`.

								    Parameters
								    ----------
								    tqdm_class  : optional
								        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
								    max_workers  : int, optional
								        Maximum number of workers to spawn; passed to
								        `concurrent.futures.ThreadPoolExecutor.__init__`.
								        [default: min(32, cpu_count() + 4)].

								    All keyword arguments are forwarded to `_executor_map`, which also
								    consumes `chunksize`, `lock_name` and `total`; whatever remains is
								    passed to `tqdm_class`.

								    Returns
								    -------
								    list of results, as returned by `_executor_map`.
								    """
								    # imported lazily, only when the function is actually called
								    from concurrent.futures import ThreadPoolExecutor
								    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def process_map(fn, *iterables, **tqdm_kwargs):
								    """
								    Equivalent of `list(map(fn, *iterables))`
								    driven by `concurrent.futures.ProcessPoolExecutor`.

								    Parameters
								    ----------
								    tqdm_class  : optional
								        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
								    max_workers  : int, optional
								        Maximum number of workers to spawn; passed to
								        `concurrent.futures.ProcessPoolExecutor.__init__`.
								        [default: min(32, cpu_count() + 4)].
								    chunksize  : int, optional
								        Size of chunks sent to worker processes; passed to
								        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
								        A `TqdmWarning` is issued when it is omitted and the largest
								        length hint among `iterables` exceeds 1000.
								    lock_name  : str, optional
								        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
								    """
								    from concurrent.futures import ProcessPoolExecutor

								    chunksize_given = "chunksize" in tqdm_kwargs
								    if iterables and not chunksize_given:
								        # with the default `chunksize=1`, large inputs spend most of
								        # their time dispatching single items to the workers
								        size_hint = max(length_hint(it) for it in iterables)
								        if size_hint > 1000:
								            from warnings import warn
								            message = ("Iterable length %d > 1000 but `chunksize` is not set."
								                       " This may seriously degrade multiprocess performance."
								                       " Set `chunksize=1` or more." % size_hint)
								            warn(message, TqdmWarning, stacklevel=2)

								    # work on a copy; a caller-supplied `lock_name` takes precedence
								    forwarded = dict(tqdm_kwargs)
								    forwarded.setdefault("lock_name", "mp_lock")
								    return _executor_map(ProcessPoolExecutor, fn, *iterables, **forwarded)
							 |