diff --git a/README.md b/README.md
index d8ee206de..466d78d51 100644
--- a/README.md
+++ b/README.md
@@ -312,6 +312,26 @@
 will spawn a number of threads in order to optimize performance for iRODS server
 versions 4.2.9+ and file sizes larger than a default threshold value of 32
 Megabytes.
 
+Because multithreaded processes under Unix-like operating systems sometimes
+need special handling, it is recommended that an application performing a put
+or get of a large file be prepared to abort the transfer cleanly if a
+terminating signal arrives:
+
+```python
+from signal import signal, SIGTERM
+from irods.parallel import abort_parallel_transfers
+
+def handler(*arguments):
+    abort_parallel_transfers()
+
+# SIGTERM does not raise a Python exception by default, so install an explicit handler.
+signal(SIGTERM, handler)
+
+try:
+    # A multi-1247 put or get can leave non-daemon threads running if not treated with care.
+    session.data_objects.put(...)
+except KeyboardInterrupt:
+    # SIGINT, by contrast, surfaces as a KeyboardInterrupt and can be handled here.
+    abort_parallel_transfers()
+```
+
 Progress bars
 -------------
diff --git a/irods/manager/data_object_manager.py b/irods/manager/data_object_manager.py
index f2c5ed31b..42fe15f4c 100644
--- a/irods/manager/data_object_manager.py
+++ b/irods/manager/data_object_manager.py
@@ -131,12 +131,13 @@ def __init__(self, *a, **kwd):
         self._iRODS_session = kwd.pop("_session", None)
         super(ManagedBufferedRandom, self).__init__(*a, **kwd)
         import irods.session
+        self.do_close = True
         with irods.session._fds_lock:
             irods.session._fds[self] = None
 
     def __del__(self):
-        if not self.closed:
+        if self.do_close and not self.closed:
             self.close()
         call___del__if_exists(super(ManagedBufferedRandom, self))
diff --git a/irods/parallel.py b/irods/parallel.py
index 2ad03492d..293d2bf28 100644
--- a/irods/parallel.py
+++ b/irods/parallel.py
@@ -9,13 +9,29 @@
 import concurrent.futures
 import threading
 import multiprocessing
-from typing import List, Union
+from typing import List, Union, Any
+import weakref
 
 from irods.data_object import iRODSDataObject
 from irods.exception import DataObjectDoesNotExist
 import irods.keywords as kw
 from queue import Queue, Full, Empty
 
+paths_active: weakref.WeakValueDictionary[str, "AsyncNotify"] = weakref.WeakValueDictionary()
+transfer_managers: weakref.WeakKeyDictionary["_Multipart_close_manager", Any] = weakref.WeakKeyDictionary()
+
+def abort_parallel_transfers(dry_run=False, filter_function=None):
+    """Abort all currently tracked parallel transfers.
+
+    If filter_function is given, it narrows the selection of (manager, value) pairs
+    taken from the tracking dictionary. If dry_run is True, the matching transfer
+    managers are returned without being aborted.
+    """
+    mgrs = dict(filter(filter_function, transfer_managers.items()))
+    if not dry_run:
+        for mgr, item in mgrs.items():
+            if isinstance(item, tuple):
+                quit_func, args = item[:2]
+                quit_func(*args)
+            else:
+                mgr.quit()
+    return mgrs
+
 logger = logging.getLogger(__name__)
 _nullh = logging.NullHandler()
@@ -91,9 +107,11 @@ def __init__(
             for future in self._futures:
                 future.add_done_callback(self)
         else:
-            self.__invoke_done_callback()
+            self.__invoke_futures_done_logic()
+            return
 
         self.progress = [0, 0]
+
         if (progress_Queue) and (total is not None):
             self.progress[1] = total
@@ -112,7 +130,7 @@ def _progress(Q, this):
         # - thread to update progress indicator
         self._progress_fn = _progress
         self._progress_thread = threading.Thread(
-            target=self._progress_fn, args=(progress_Queue, self)
+            target=self._progress_fn, args=(progress_Queue, self), daemon=True
         )
         self._progress_thread.start()
@@ -153,11 +171,14 @@ def __call__(
         with self._lock:
             self._futures_done[future] = future.result()
             if len(self._futures) == len(self._futures_done):
-                self.__invoke_done_callback()
+                # If a future returns None rather than an integer byte count, it has aborted the transfer.
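+                # In that case, __invoke_futures_done_logic still performs its internal cleanup but skips the user's done_callback.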
+                self.__invoke_futures_done_logic(
+                    skip_user_callback=(None in self._futures_done.values())
+                )
 
-    def __invoke_done_callback(self):
+    def __invoke_futures_done_logic(self, skip_user_callback=False):
         try:
-            if callable(self.done_callback):
+            if not skip_user_callback and callable(self.done_callback):
                 self.done_callback(self)
         finally:
             self.keep.pop("mgr", None)
@@ -240,6 +261,12 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
     bytecount = 0
     accum = 0
     while True and bytecount < length:
+        if mgr._quit:
+            # Indicate by the return value that we are aborting (this part of) the data transfer.
+            # In the great majority of cases, this should be seen by the application as an overall
+            # abort of the PUT or GET of the requested object.
+            bytecount = None
+            break
         buf = src.read(min(COPY_BUF_SIZE, length - bytecount))
         buf_len = len(buf)
         if 0 == buf_len:
             break
@@ -274,11 +301,39 @@ class _Multipart_close_manager:
     """
 
-    def __init__(self, initial_io_, exit_barrier_):
+    def __init__(self, initial_io_, exit_barrier_, executor=None):
+        self._quit = False
         self.exit_barrier = exit_barrier_
         self.initial_io = initial_io_
         self.__lock = threading.Lock()
         self.aux = []
+        self.futures = set()
+        self.executor = executor
+
+    def add_future(self, future):
+        self.futures.add(future)
+
+    @property
+    def active_futures(self):
+        return tuple(_ for _ in self.futures if not _.done())
+
+    def shutdown(self):
+        if self.executor:
+            self.executor.shutdown(cancel_futures=True)
+
+    def quit(self):
+        from irods.manager.data_object_manager import ManagedBufferedRandom
+        # Remove all descriptors from consideration for auto_close.
+        import irods.session
+        with irods.session._fds_lock:
+            for fd in self.aux + [self.initial_io]:
+                irods.session._fds.pop(fd, ())
+                if type(fd) is ManagedBufferedRandom:
+                    fd.do_close = False
+        # Abort threads.
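+        # Setting _quit causes each running _copy_part loop to exit early, while aborting the
+        # barrier unblocks any worker thread already waiting in remove_io.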
+        self._quit = True
+        self.exit_barrier.abort()
+        self.shutdown()
+        return self.active_futures
 
     def __contains__(self, Io):
         with self.__lock:
@@ -297,6 +352,7 @@ def add_io(self, Io):
     # synchronizes all of the parallel threads just before exit, so that we know
     # exactly when to perform a finalizing close on the data object
+
     def remove_io(self, Io):
         is_initial = True
         with self.__lock:
@@ -304,8 +360,12 @@ def remove_io(self, Io):
             Io.close()
             self.aux.remove(Io)
             is_initial = False
-        self.exit_barrier.wait()
-        if is_initial:
+        broken = False
+        try:
+            self.exit_barrier.wait()
+        except threading.BrokenBarrierError:
+            broken = True
+        if is_initial and not (broken or self._quit):
             self.finalize()
 
     def finalize(self):
@@ -393,7 +453,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
     futures = []
     executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
     num_threads = min(num_threads, len(ranges))
-    mgr = _Multipart_close_manager(Io, Barrier(num_threads))
+    mgr = _Multipart_close_manager(Io, Barrier(num_threads), executor)
     counter = 1
     gen_file_handle = lambda: open(
         fname, Operation.disk_file_mode(initial_open=(counter == 1))
@@ -405,49 +465,86 @@
         "queueObject": queueObject,
     }
 
-    for byte_range in ranges:
-        if Io is None:
-            Io = session.data_objects.open(
-                Data_object.path,
-                Operation.data_object_mode(initial_open=False),
-                create=False,
-                finalize_on_close=False,
-                allow_redirect=False,
-                **{
-                    kw.NUM_THREADS_KW: str(num_threads),
-                    kw.DATA_SIZE_KW: str(total_size),
-                    kw.RESC_HIER_STR_KW: hier_str,
-                    kw.REPLICA_TOKEN_KW: replica_token,
-                }
-            )
-        mgr.add_io(Io)
-        logger.debug("target_host = %s", Io.raw.session.pool.account.host)
-        if File is None:
-            File = gen_file_handle()
-        futures.append(
-            executor.submit(
-                _io_part,
-                Io,
-                byte_range,
-                File,
-                Operation,
-                mgr,
-                thread_debug_id=str(counter),
-                **thread_opts
-            )
-        )
-        counter += 1
-        Io = File = None
+    transfer_managers[mgr] = (_quit_current_transfer, [id(mgr)])
 
-    if Operation.isNonBlocking():
-        if queueLength:
-            return futures, queueObject, mgr
+    try:
+        transfer_aborted = False
+
+        for byte_range in ranges:
+            if Io is None:
+                Io = session.data_objects.open(
+                    Data_object.path,
+                    Operation.data_object_mode(initial_open=False),
+                    create=False,
+                    finalize_on_close=False,
+                    allow_redirect=False,
+                    **{
+                        kw.NUM_THREADS_KW: str(num_threads),
+                        kw.DATA_SIZE_KW: str(total_size),
+                        kw.RESC_HIER_STR_KW: hier_str,
+                        kw.REPLICA_TOKEN_KW: replica_token,
+                    }
+                )
+            mgr.add_io(Io)
+            logger.debug("target_host = %s", Io.raw.session.pool.account.host)
+            if File is None:
+                File = gen_file_handle()
+            try:
+                futures.append(
+                    f := executor.submit(
+                        _io_part,
+                        Io,
+                        byte_range,
+                        File,
+                        Operation,
+                        mgr,
+                        thread_debug_id=str(counter),
+                        **thread_opts
+                    )
+                )
+            except RuntimeError:
+                # The executor was probably shut down before the parallel transfer could be initiated.
+                transfer_aborted = True
+                break
+            else:
+                mgr.add_future(f)
+
+            counter += 1
+            Io = File = None
+
+        if transfer_aborted:
+            return (0, total_size)
+
+        if Operation.isNonBlocking():
+            transfer_managers[mgr] = None
+            return (futures, mgr, queueObject)
         else:
-            return futures
-    else:
-        bytecounts = [f.result() for f in futures]
-        return sum(bytecounts), total_size
+            bytes_transferred = 0
+            # Enable user attempts to cancel the current synchronous transfer.
+            # At any given time, only one transfer manager key should map to a tuple object T.
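+            # (The tuple is installed under this transfer's manager key just before the enclosing try block.)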
+            # You should be able to quit all threads of the current transfer by calling T[0](*T[1]).
+            bytecounts = [f.result() for f in futures]
+            # If, rather than an integer byte count, the None object was included among the futures'
+            # return values, the PUT or GET operation should be marked as aborted, i.e. no bytes transferred.
+            if None not in bytecounts:
+                bytes_transferred = sum(bytecounts)
+
+            return (bytes_transferred, total_size)
+
+    except BaseException as e:
+        # TODO: examine this experimentally restored code; the library should react to these
+        # two exception types (and perhaps others) by quitting all transfer threads.
+        if isinstance(e, (SystemExit, KeyboardInterrupt)):
+            mgr.quit()
+        raise
+
+
+def _quit_current_transfer(obj_id):
+    matches = [_ for _ in transfer_managers if id(_) == obj_id]
+    if matches:
+        matches[0].quit()
 
 def io_main(session, Data, opr_, fname, R="", **kwopt):
     """
@@ -559,18 +656,21 @@
 
     if Operation.isNonBlocking():
-        if queueLength > 0:
-            (futures, chunk_notify_queue, mgr) = retval
-        else:
-            futures = retval
-            chunk_notify_queue = total_bytes = None
+        (futures, mgr, chunk_notify_queue) = retval
+        # TODO: investigate why total_bytes was previously zeroed out when there was no progress queue.
 
-        return AsyncNotify(
+        transfer_managers[mgr] = Data.path
+        paths_active[Data.path] = async_notify = AsyncNotify(
             futures,  # individual futures, one per transfer thread
             progress_Queue=chunk_notify_queue,  # for notifying the progress indicator thread
             total=total_bytes,  # total number of bytes for parallel transfer
             keep_={"mgr": mgr},
         )  # an open raw i/o object needing to be persisted, if any
+        return async_notify
     else:
         (_bytes_transferred, _bytes_total) = retval
         return _bytes_transferred == _bytes_total
diff --git a/irods/test/data_obj_test.py b/irods/test/data_obj_test.py
index 071771717..6dd3663c1 100644
--- a/irods/test/data_obj_test.py
+++ b/irods/test/data_obj_test.py
@@ -3320,6 +3320,14 @@ def test_access_time__issue_700(self):
         # Test that access_time is there, and of the right type.
         self.assertIs(type(data.access_time), datetime)
 
+    def test_handling_of_termination_signals_during_multithread_get__issue_722(self):
+        from irods.test.modules.test_signal_handling_in_multithread_get import (
+            test as test__issue_722,
+        )
+
+        test__issue_722(self)
+
+
 if __name__ == "__main__":
     # let the tests find the parent irods lib
     sys.path.insert(0, os.path.abspath("../.."))
diff --git a/irods/test/modules/test_signal_handling_in_multithread_get.py b/irods/test/modules/test_signal_handling_in_multithread_get.py
new file mode 100644
index 000000000..0a5f1863d
--- /dev/null
+++ b/irods/test/modules/test_signal_handling_in_multithread_get.py
@@ -0,0 +1,151 @@
+import os
+import re
+import signal
+import subprocess
+import sys
+import tempfile
+import time
+
+import irods.helpers
+from irods.test import modules as test_modules
+from irods.parallel import abort_parallel_transfers
+
+OBJECT_SIZE = 2 * 1024**3
+OBJECT_NAME = "data_get_issue__722"
+LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat"
+
+
+_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME))
+
+LARGE_TEST_TIMEOUT = 10 * 60.0  # ten minutes
+
+def wait_till_true(function, timeout=LARGE_TEST_TIMEOUT, msg=''):
+    """Wait, for test purposes, until a condition becomes true, as determined by the
+    return value of the provided test function.
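+    The condition is polled roughly once every _clock_polling_interval seconds.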
+
+    By default, we wait at most LARGE_TEST_TIMEOUT seconds for the function to return a
+    true value, and then give up. A timeout of None is a request never to time out.
+
+    If msg is a nonempty string, it is used to raise a TimeoutError on timeout;
+    otherwise, timing out causes a normal return, relaying as the return value the last
+    value returned from the test function.
+    """
+    start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME)
+    while not (truth_value := function()):
+        if (
+            timeout is not None
+            and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9
+            > timeout
+        ):
+            if msg:
+                raise TimeoutError(msg)
+            else:
+                break
+        time.sleep(_clock_polling_interval)
+    return truth_value
+
+
+def test(test_case, signal_names=("SIGTERM", "SIGINT")):
+    """Create a child process executing a long get() and ensure that the process can be
+    terminated using SIGINT or SIGTERM.
+    """
+    program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))
+
+    for signal_name in signal_names:
+
+        with test_case.subTest(f"Testing with signal {signal_name}"):
+
+            # Call into this same module as a command. This will initiate another Python process that
+            # performs a lengthy data object "get" operation (see the main body of the script, below).
+            process = subprocess.Popen(
+                [sys.executable, program],
+                stderr=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                text=True,
+            )
+
+            # Wait for the download process to reach the point of spawning data transfer threads. In Python
+            # 3.9+ versions of the concurrent.futures module, these are non-daemon threads and will block the
+            # exit of the main thread unless measures are taken (#722).
+            localfile = process.stdout.readline().strip()
+            # Use a timeout of ten minutes for the test transfer, which should be more than enough.
+            test_case.assertTrue(
+                wait_till_true(
+                    lambda: os.path.exists(localfile)
+                    and os.stat(localfile).st_size > OBJECT_SIZE // 2,
+                ),
+                "Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
+            )
+
+            sig = getattr(signal, signal_name)
+
+            signal_offset_return_code = lambda s: 128 - s if s < 0 else s
+            signal_plus_128 = lambda sig: 128 + sig
+
+            # Interrupt the subprocess with the given signal.
+            process.send_signal(sig)
+
+            # Assert that this signal is what killed the subprocess, rather than a timed-out process "wait"
+            # or a natural exit due to improper or incomplete handling of the signal.
+            try:
+                translated_return_code = signal_offset_return_code(process.wait(timeout=15))
+                test_case.assertEqual(
+                    translated_return_code,
+                    signal_plus_128(sig),
+                    f"Expected subprocess return code of {signal_plus_128(sig) = }; got {translated_return_code = }",
+                )
+            except subprocess.TimeoutExpired:
+                test_case.fail(
+                    "Subprocess timed out before terminating. "
+                    "Non-daemon thread(s) probably prevented the subprocess's main thread from exiting."
+                )
+            # Assert that, in the case of SIGINT, the process registered a KeyboardInterrupt.
+            if sig == signal.SIGINT:
+                test_case.assertTrue(
+                    re.search("KeyboardInterrupt", process.stderr.read()),
+                    "Did not find expected string 'KeyboardInterrupt' in log output.",
+                )
+
+
+if __name__ == "__main__":
+    # These lines are run only if the module is launched as a process.
+    session = irods.helpers.make_session()
+    hc = irods.helpers.home_collection(session)
+    TESTFILE_FILL = b"_" * (1024 * 1024)
+    object_path = f"{hc}/{OBJECT_NAME}"
+
+    # Create the object to be downloaded.
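+    # (Writing in 1 MiB chunks keeps client memory use modest while creating the 2 GiB object.)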
+    with session.data_objects.open(object_path, "w") as f:
+        for _ in range(OBJECT_SIZE // len(TESTFILE_FILL)):
+            f.write(TESTFILE_FILL)
+
+    local_path = None
+    # Establish the absolute path at which to place the downloaded file, i.e. the get() target.
+    try:
+        with tempfile.NamedTemporaryFile(
+            prefix="local_file_issue_722.dat", delete=True
+        ) as t:
+            local_path = t.name
+
+        # Tell the parent process the name of the local file being "get"ted (got) from iRODS.
+        print(local_path)
+        sys.stdout.flush()
+
+        def handler(sig, *_):
+            abort_parallel_transfers()
+            sys.exit(128 + sig)
+
+        signal.signal(signal.SIGTERM, handler)
+
+        try:
+            # Download the object.
+            session.data_objects.get(object_path, local_path)
+        except KeyboardInterrupt:
+            abort_parallel_transfers()
+            raise
+
+    finally:
+        # Clean up, whether or not the download succeeded.
+        if local_path is not None and os.path.exists(local_path):
+            os.unlink(local_path)
+        if session.data_objects.exists(object_path):
+            session.data_objects.unlink(object_path, force=True)
diff --git a/irods/test/multithread_put_test.py b/irods/test/multithread_put_test.py
new file mode 100644
index 000000000..55abcbead
--- /dev/null
+++ b/irods/test/multithread_put_test.py
@@ -0,0 +1,86 @@
+import os
+from signal import signal, SIGUSR1
+import tempfile
+import threading
+import time
+import unittest
+import irods.helpers
+import irods.test.helpers
+from irods.parallel import abort_parallel_transfers
+
+OBJECT_SIZE = 3 * 1024**3
+TESTFILE_FILL = b"_" * (1024 * 1024)
+
+def wait_until_condition_true(func, interval, sleep=.1):
+    value = False
+    t0 = time.time()
+    while (time.time() - t0) < interval:
+        if (value := func()):
+            break
+        time.sleep(sleep)
+    return value
+
+class Test(unittest.TestCase):
+
+    def test_put__issue_722(self):
+        # Create the local file to be uploaded. delete=False lets it outlive the
+        # with-block; it is removed in the cleanup below.
+        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
+            for _ in range(OBJECT_SIZE // len(TESTFILE_FILL)):
+                f.write(TESTFILE_FILL)
+            local_path = f.name
+
+        def _abort_them(*_):
+            abort_parallel_transfers()
+
+        signal(SIGUSR1, _abort_them)
+
+        session = irods.helpers.make_session()
+        hc = irods.helpers.home_collection(session)
+        object_path = f"{hc}/put_target_issue_722_{irods.test.helpers.unique_name(time.time())}"
+
+        try:
+            tsession = session.clone()
+            data_object_exists = lambda: tsession.data_objects.exists(object_path)
+            pid = os.getpid()
+            dc = {}
+
+            def signal_after_object_exists():
+                nonlocal dc
+                # Wait for the data object to appear at the iRODS end of the put().
+                while not data_object_exists():
+                    time.sleep(.01)
+                # Wait until a transfer manager with registered futures exists. Note that the
+                # managers are the returned dictionary's keys, not its values.
+                while True:
+                    dc = abort_parallel_transfers(dry_run=True)
+                    if [m for m in dc if m.futures]:
+                        break
+                    time.sleep(.01)
+                # Interrupt the put() in the main thread with the abort-inducing signal.
+                os.kill(pid, SIGUSR1)
+
+            t = threading.Thread(target=signal_after_object_exists)
+            t.start()
+            session.data_objects.put(local_path, object_path)
+            t.join()
+
+            # Assert that transfer threads terminate.
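+            # Once the abort has taken effect, only the main thread should remain; poll for up to ten minutes.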
+            self.assertTrue(
+                wait_until_condition_true(
+                    lambda: threading.enumerate() == [threading.current_thread()],
+                    10 * 60.0,
+                )
+            )
+        finally:
+            # Clean up, whether or not the upload succeeded.
+            if os.path.exists(local_path):
+                os.unlink(local_path)
+            if session.data_objects.exists(object_path):
+                session.data_objects.unlink(object_path, force=True)