-
Notifications
You must be signed in to change notification settings - Fork 74
[#722] fix segfault and hung threads on KeyboardInterrupt during parallel get #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a3dd06e
481952c
1213d2b
59fd131
ba0407a
dddcc95
b96f805
ed25eda
9d2ff7a
0aaf747
ade42ea
8272b5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -131,12 +131,13 @@ def __init__(self, *a, **kwd): | |
| self._iRODS_session = kwd.pop("_session", None) | ||
| super(ManagedBufferedRandom, self).__init__(*a, **kwd) | ||
| import irods.session | ||
| self.no_close = False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we avoid the use of negatives in variable names where possible for clarity? Perhaps this could be |
||
|
|
||
| with irods.session._fds_lock: | ||
| irods.session._fds[self] = None | ||
|
|
||
| def __del__(self): | ||
| if not self.closed: | ||
| if not self.no_close and not self.closed: | ||
| self.close() | ||
| call___del__if_exists(super(ManagedBufferedRandom, self)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,14 +9,24 @@ | |
| import concurrent.futures | ||
| import threading | ||
| import multiprocessing | ||
| from typing import List, Union | ||
| from typing import List, Union, Any | ||
| import weakref | ||
|
|
||
| from irods.data_object import iRODSDataObject | ||
| from irods.exception import DataObjectDoesNotExist | ||
| import irods.keywords as kw | ||
| from queue import Queue, Full, Empty | ||
|
|
||
|
|
||
# Registry of parallel-transfer managers that may need to be aborted.  Keys are
# held weakly, so an entry disappears once its manager is garbage collected.
transfer_managers: weakref.WeakKeyDictionary["_Multipart_close_manager", Any] = (
    weakref.WeakKeyDictionary()
)


def abort_parallel_transfers(dry_run=False):
    """Ask every registered parallel transfer to stop.

    A strong snapshot of the registry is taken first, so that the registry
    itself is not being iterated while ``quit()`` runs, and so that the
    returned mapping lists exactly the managers that were signalled.

    Args:
        dry_run: when true, no manager is signalled; the registry contents
            are only reported.

    Returns:
        A plain ``dict`` copy of ``transfer_managers`` taken at the start of
        the call.
    """
    snapshot = dict(transfer_managers)
    if not dry_run:
        for mgr in snapshot:
            mgr.quit()
    return snapshot
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
| _nullh = logging.NullHandler() | ||
| logger.addHandler(_nullh) | ||
|
|
@@ -91,9 +101,11 @@ def __init__( | |
| for future in self._futures: | ||
| future.add_done_callback(self) | ||
| else: | ||
| self.__invoke_done_callback() | ||
| self.__invoke_futures_done_logic() | ||
| return | ||
|
|
||
| self.progress = [0, 0] | ||
|
|
||
| if (progress_Queue) and (total is not None): | ||
| self.progress[1] = total | ||
|
|
||
|
|
@@ -112,7 +124,7 @@ def _progress(Q, this): # - thread to update progress indicator | |
|
|
||
| self._progress_fn = _progress | ||
| self._progress_thread = threading.Thread( | ||
| target=self._progress_fn, args=(progress_Queue, self) | ||
| target=self._progress_fn, args=(progress_Queue, self), daemon=True | ||
| ) | ||
| self._progress_thread.start() | ||
|
|
||
|
|
@@ -153,11 +165,13 @@ def __call__( | |
| with self._lock: | ||
| self._futures_done[future] = future.result() | ||
| if len(self._futures) == len(self._futures_done): | ||
| self.__invoke_done_callback() | ||
| self.__invoke_futures_done_logic( | ||
| skip_user_callback=(None in self._futures_done.values()) | ||
| ) | ||
|
|
||
| def __invoke_done_callback(self): | ||
| def __invoke_futures_done_logic(self, skip_user_callback=False): | ||
| try: | ||
| if callable(self.done_callback): | ||
| if not skip_user_callback and callable(self.done_callback): | ||
| self.done_callback(self) | ||
| finally: | ||
| self.keep.pop("mgr", None) | ||
|
|
@@ -240,6 +254,10 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()): | |
| bytecount = 0 | ||
| accum = 0 | ||
| while True and bytecount < length: | ||
| print (('T' if mgr._quit else 'F'), end = '', flush=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove? |
||
| if mgr._quit: | ||
| bytecount = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So a return value of |
||
| break | ||
| buf = src.read(min(COPY_BUF_SIZE, length - bytecount)) | ||
| buf_len = len(buf) | ||
| if 0 == buf_len: | ||
|
|
@@ -274,11 +292,39 @@ class _Multipart_close_manager: | |
|
|
||
| """ | ||
|
|
||
| def __init__(self, initial_io_, exit_barrier_): | ||
| def __init__(self, initial_io_, exit_barrier_, executor = None): | ||
| self._quit = False | ||
| self.exit_barrier = exit_barrier_ | ||
| self.initial_io = initial_io_ | ||
| self.__lock = threading.Lock() | ||
| self.aux = [] | ||
| self.futures = set() | ||
| self.executor = executor | ||
|
|
||
| def add_future(self, future): self.futures.add(future) | ||
|
|
||
| @property | ||
| def active_futures(self): | ||
| return tuple(_ for _ in self.futures if not _.done()) | ||
|
|
||
| def shutdown(self): | ||
| if self.executor: | ||
| self.executor.shutdown(cancel_futures = True) | ||
|
|
||
| def quit(self): | ||
| from irods.manager.data_object_manager import ManagedBufferedRandom | ||
| # remove all descriptors from consideration for auto_close. | ||
| import irods.session | ||
| with irods.session._fds_lock: | ||
| for fd in self.aux + [self.initial_io]: | ||
| irods.session._fds.pop(fd, ()) | ||
| if type(fd) is ManagedBufferedRandom: | ||
| fd.no_close = True | ||
| # abort threads. | ||
| self._quit = True | ||
| self.exit_barrier.abort() | ||
| self.shutdown() | ||
| return self.active_futures | ||
|
|
||
| def __contains__(self, Io): | ||
| with self.__lock: | ||
|
|
@@ -297,15 +343,20 @@ def add_io(self, Io): | |
| # synchronizes all of the parallel threads just before exit, so that we know | ||
| # exactly when to perform a finalizing close on the data object | ||
|
|
||
|
|
||
| def remove_io(self, Io): | ||
| is_initial = True | ||
| with self.__lock: | ||
| if Io is not self.initial_io: | ||
| Io.close() | ||
| self.aux.remove(Io) | ||
| is_initial = False | ||
| self.exit_barrier.wait() | ||
| if is_initial: | ||
| broken = False | ||
| try: | ||
| self.exit_barrier.wait() | ||
| except threading.BrokenBarrierError: | ||
| broken = True | ||
| if is_initial and not (broken or self._quit): | ||
| self.finalize() | ||
|
|
||
| def finalize(self): | ||
|
|
@@ -393,7 +444,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): | |
| futures = [] | ||
| executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) | ||
| num_threads = min(num_threads, len(ranges)) | ||
| mgr = _Multipart_close_manager(Io, Barrier(num_threads)) | ||
| mgr = _Multipart_close_manager(Io, Barrier(num_threads), executor) | ||
| counter = 1 | ||
| gen_file_handle = lambda: open( | ||
| fname, Operation.disk_file_mode(initial_open=(counter == 1)) | ||
|
|
@@ -425,7 +476,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): | |
| if File is None: | ||
| File = gen_file_handle() | ||
| futures.append( | ||
| executor.submit( | ||
| f := executor.submit( | ||
| _io_part, | ||
| Io, | ||
| byte_range, | ||
|
|
@@ -436,17 +487,26 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): | |
| **thread_opts | ||
| ) | ||
| ) | ||
| mgr.add_future(f) | ||
| counter += 1 | ||
| Io = File = None | ||
|
|
||
| if Operation.isNonBlocking(): | ||
| if queueLength: | ||
| return futures, queueObject, mgr | ||
| else: | ||
| return futures | ||
| return futures, queueObject, mgr | ||
| else: | ||
| bytecounts = [f.result() for f in futures] | ||
| return sum(bytecounts), total_size | ||
| bytes_transferred = 0 | ||
| try: | ||
| transfer_managers[mgr] = 1 | ||
| bytecounts = [f.result() for f in futures] | ||
| if None not in bytecounts: | ||
| bytes_transferred = sum(bytecounts) | ||
| except (KeyboardInterrupt, #SystemExit | ||
| ): | ||
| print ('\nraising KBI\n') | ||
| raise | ||
| finally: | ||
| pass | ||
|
Comment on lines
+503
to
+508
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this debugging? |
||
| return bytes_transferred, total_size | ||
|
|
||
|
|
||
| def io_main(session, Data, opr_, fname, R="", **kwopt): | ||
|
|
@@ -559,10 +619,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt): | |
|
|
||
| if Operation.isNonBlocking(): | ||
|
|
||
| if queueLength > 0: | ||
| (futures, chunk_notify_queue, mgr) = retval | ||
| else: | ||
| futures = retval | ||
| (futures, chunk_notify_queue, mgr) = retval | ||
| transfer_managers[mgr] = None | ||
|
|
||
| if queueLength <= 0: | ||
| chunk_notify_queue = total_bytes = None | ||
|
|
||
| return AsyncNotify( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| import os | ||
| import re | ||
| import signal | ||
| import subprocess | ||
| import sys | ||
| import tempfile | ||
| import time | ||
|
|
||
| import irods.helpers | ||
| from irods.test import modules as test_modules | ||
| from irods.parallel import abort_parallel_transfers | ||
|
|
||
| OBJECT_SIZE = 2 * 1024**3 | ||
| OBJECT_NAME = "data_get_issue__722" | ||
| LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat" | ||
|
|
||
|
|
||
# Sleep period between polls in wait_till_true(): at least 10 ms, and never
# finer than the resolution the monotonic clock reports.
_clock_polling_interval = max(0.01, time.get_clock_info("monotonic").resolution)


def wait_till_true(function, timeout=None):
    """Poll ``function`` until it returns a true value or ``timeout`` elapses.

    Elapsed time is measured with ``time.monotonic()``, which is available on
    every platform, unlike the ``CLOCK_BOOTTIME`` clock id.

    Args:
        function: zero-argument callable that is invoked repeatedly.
        timeout: maximum number of seconds to keep polling, or ``None`` to
            poll without limit.

    Returns:
        The last value returned by ``function``: a true value on success, or
        a false value if the timeout was exceeded first.
    """
    # NOTE(review): a reviewer asked about setting a default here (presumably
    # a finite default for `timeout`); it is left as None so that existing
    # callers keep their behavior -- confirm before changing.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not (truth_value := function()):
        if deadline is not None and time.monotonic() > deadline:
            break
        time.sleep(_clock_polling_interval)
    return truth_value
|
|
||
|
|
||
def test(test_case, signal_names=("SIGTERM", "SIGINT")):
    """Creates a child process executing a long get() and ensures the process can be
    terminated using SIGINT or SIGTERM.

    Args:
        test_case: object providing the assertion API used below (``subTest``,
            ``assertTrue``, ``assertEqual`` and ``fail``).
        signal_names: names of attributes of the ``signal`` module; one subtest
            is run for each name.
    """
    program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))

    def shell_style_return_code(return_code):
        # Popen reports "killed by signal N" as -N; fold that into 128 + N so
        # it can be compared with an explicit exit(128 + N) from the child.
        return 128 - return_code if return_code < 0 else return_code

    for signal_name in signal_names:

        with test_case.subTest(f"Testing with signal {signal_name}"):

            sig = getattr(signal, signal_name)
            expected_return_code = 128 + sig

            # Call into this same module as a command.  This will initiate another Python process that
            # performs a lengthy data object "get" operation (see the main body of the script, below.)
            process = subprocess.Popen(
                [sys.executable, program],
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
            )

            try:
                # The child prints the download target's path as its first line of output.
                # An empty string here means the child closed stdout without printing it.
                localfile = process.stdout.readline().strip()

                def download_is_under_way():
                    # A single stat() call avoids the race between an existence check
                    # and a separate size query; a missing file simply counts as "not yet".
                    try:
                        return os.stat(localfile).st_size > OBJECT_SIZE // 2
                    except OSError:
                        return False

                # Wait for download process to reach the point of spawning data transfer threads.  In Python 3.9+ versions
                # of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread
                # unless measures are taken (#722).
                # Polling also stops if the child has already exited, so a child that
                # fails early produces a test failure instead of an endless wait.
                wait_till_true(
                    lambda: download_is_under_way() or process.poll() is not None
                )
                test_case.assertTrue(
                    process.poll() is None and download_is_under_way(),
                    "Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
                )

                # Interrupt the subprocess with the given signal.
                process.send_signal(sig)

                # Assert that this signal is what killed the subprocess, rather than a timed out process "wait" or a natural exit
                # due to improper or incomplete handling of the signal.
                try:
                    return_code = process.wait(timeout=15)
                except subprocess.TimeoutExpired:
                    test_case.fail(
                        "Subprocess timed out before terminating. "
                        "Non-daemon thread(s) probably prevented subprocess's main thread from exiting."
                    )
                translated_return_code = shell_style_return_code(return_code)
                test_case.assertEqual(
                    translated_return_code,
                    expected_return_code,
                    f"Expected subprocess return code of {expected_return_code = }; got {translated_return_code = }",
                )

                # Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
                if sig == signal.SIGINT:
                    test_case.assertTrue(
                        re.search("KeyboardInterrupt", process.stderr.read()),
                        "Did not find expected string 'KeyboardInterrupt' in log output.",
                    )
            finally:
                # Never leave a child (or its pipes) behind, whatever the outcome above.
                # A child that is still running at this point ignored the signal, so it
                # is killed outright and reaped.
                if process.poll() is None:
                    process.kill()
                    process.wait()
                process.stdout.close()
                process.stderr.close()
|
|
||
|
|
||
if __name__ == "__main__":
    # These lines are run only if the module is launched as a process.
    session = irods.helpers.make_session()
    hc = irods.helpers.home_collection(session)
    TESTFILE_FILL = b"_" * (1024 * 1024)
    object_path = f"{hc}/{OBJECT_NAME}"
    local_path = None

    def handler(sig, *_):
        # Stop the transfer first, then exit with the conventional
        # 128 + signal-number status.  sys.exit is used rather than the `exit`
        # helper, which is only injected by the `site` module.
        abort_parallel_transfers()
        sys.exit(128 + sig)

    try:
        # Create the object to be downloaded.  This happens inside the try so that
        # a partially written object is still removed by the cleanup below.
        with session.data_objects.open(object_path, "w") as f:
            for _ in range(OBJECT_SIZE // len(TESTFILE_FILL)):
                f.write(TESTFILE_FILL)

        # Establish where (ie absolute path) to place the downloaded file, i.e. the get() target.
        # Only the generated name is kept; the temporary file itself is deleted
        # when the with-block exits.
        with tempfile.NamedTemporaryFile(
            prefix="local_file_issue_722.dat", delete=True
        ) as t:
            local_path = t.name

        # Tell the parent process the name of the local file being "get"ted (got) from iRODS
        print(local_path)
        sys.stdout.flush()

        signal.signal(signal.SIGTERM, handler)

        try:
            # download the object
            session.data_objects.get(object_path, local_path)
        except KeyboardInterrupt:
            abort_parallel_transfers()
            raise

    finally:
        # Clean up, whether or not the download succeeded.
        if local_path is not None and os.path.exists(local_path):
            os.unlink(local_path)
        if session.data_objects.exists(object_path):
            session.data_objects.unlink(object_path, force=True)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.