from __future__ import annotations

import concurrent.futures
import contextlib
import filecmp
import logging
import math
import multiprocessing
import os
import platform
import queue
import shutil
import subprocess
import sys
import tempfile
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import FIRST_COMPLETED, Future, wait
from dataclasses import dataclass
from enum import Enum, auto, unique
from pathlib import Path
from typing import Any, Callable

import pebble

from cvise.cvise import CVise
from cvise.passes.abstract import AbstractPass, PassResult
from cvise.passes.hint_based import HintBasedPass, HintState
from cvise.utils import cache, fileutil, mplogging, sigmonitor
from cvise.utils.error import (
    AbsolutePathTestCaseError,
    InsaneTestCaseError,
    InvalidInterestingnessTestError,
    InvalidTestCaseError,
    PassBugError,
    ScriptInsideTestCaseError,
    ZeroSizeError,
)
from cvise.utils.folding import FoldingManager, FoldingStateIn, FoldingStateOut
from cvise.utils.hint import is_special_hint_type, load_hints
from cvise.utils.process import MPContextHook, MPTaskLossWorkaround, ProcessEventNotifier, ProcessMonitor
from cvise.utils.readkey import KeyLogger

# NOTE(review): this constant isn't referenced in the visible part of the module; the name suggests a limit on how
# often a pass may increase the test case size - confirm against its users before relying on that reading.
MAX_PASS_INCREASEMENT_THRESHOLD = 3


@unique
class PassCheckingOutcome(Enum):
    """Outcome of checking the result of an invocation of a pass.

    Produced by TestManager.check_pass_result() and acted upon by TestManager.handle_finished_transform_job().
    """

    # The attempt is recorded as a success and offered as a new success candidate.
    ACCEPT = auto()
    # The attempt is discarded and recorded as a failure.
    IGNORE = auto()
    # The pass must stop: the caller resets the pass state to None.
    STOP = auto()


def rmfolder(name):
    """Recursively delete the directory `name` on a best-effort basis.

    Filesystem errors (OSError) are deliberately ignored. As a safety net against deleting unrelated directories,
    the path must contain the substring 'cvise'.
    """
    assert 'cvise' in str(name)
    with contextlib.suppress(OSError):
        shutil.rmtree(name)


@dataclass
class InitEnvironment:
    """Bundles the arguments needed to invoke a pass's new() method inside a worker."""

    pass_new: Callable
    test_case: Path
    tmp_dir: Path
    job_timeout: int
    pid_queue: queue.Queue
    dependee_bundle_paths: list[Path]

    def run(self) -> Any:
        """Load the dependee hint bundles and call pass_new(); returns its result, or None on a unicode failure."""
        loaded_hints = []
        for bundle_path in self.dependee_bundle_paths:
            loaded_hints.append(load_hints(bundle_path, begin_index=None, end_index=None))

        try:
            new_state = self.pass_new(
                self.test_case,
                tmp_dir=self.tmp_dir,
                job_timeout=self.job_timeout,
                process_event_notifier=ProcessEventNotifier(self.pid_queue),
                dependee_hints=loaded_hints,
            )
        except UnicodeDecodeError:
            # The pass most likely can't cope with non-UTF input - give up on it.
            logging.debug('Skipping pass due to a unicode issue')
            return None
        return new_state


@dataclass
class AdvanceOnSuccessEnvironment:
    """Bundles the arguments needed to invoke a pass's advance_on_success() method inside a worker."""

    pass_advance_on_success: Callable
    test_case: Path
    pass_previous_state: Any
    new_tmp_dir: Path
    pass_succeeded_state: Any
    job_timeout: int
    pid_queue: queue.Queue
    dependee_bundle_paths: list[Path]

    def run(self) -> Any:
        """Load the dependee hint bundles and return whatever pass_advance_on_success() returns."""
        loaded_hints = []
        for bundle_path in self.dependee_bundle_paths:
            loaded_hints.append(load_hints(bundle_path, begin_index=None, end_index=None))

        call_kwargs = {
            'state': self.pass_previous_state,
            'new_tmp_dir': self.new_tmp_dir,
            'succeeded_state': self.pass_succeeded_state,
            'job_timeout': self.job_timeout,
            'process_event_notifier': ProcessEventNotifier(self.pid_queue),
            'dependee_hints': loaded_hints,
        }
        return self.pass_advance_on_success(self.test_case, **call_kwargs)


class TestEnvironment:
    """Holds data for running a Pass transform() method and the interestingness test in a worker.

    The transform call is optional - in that case, the interestingness test is simply executed for the unchanged input
    (this is useful for implementing the "sanity check" of the input on the C-Vise startup).
    """

    def __init__(
        self,
        state,
        order,
        test_script,
        folder: Path,
        test_case: Path,
        all_test_cases: set[Path],
        should_copy_test_cases: bool,
        transform,
        pid_queue: queue.Queue | None = None,
    ):
        # Inputs describing the job.
        self.state = state
        self.order = order
        self.test_script = test_script
        self.folder: Path = folder
        self.test_case: Path = test_case
        self.all_test_cases: set[Path] = all_test_cases
        self.should_copy_test_cases = should_copy_test_cases
        self.transform = transform
        self.pid_queue = pid_queue
        # Outputs, filled in by run().
        self.result = None
        self.exitcode = None
        # Size of the test case before the transformation, used for computing size_improvement.
        self.base_size = fileutil.get_file_size(test_case)

    @property
    def size_improvement(self):
        """How much smaller the test case in the job's folder is compared to the original (may be negative)."""
        current_size = fileutil.get_file_size(self.test_case_path)
        return self.base_size - current_size

    @property
    def test_case_path(self) -> Path:
        """Location of the test case inside the job's folder."""
        return self.folder / self.test_case

    @property
    def success(self):
        """True when the transform reported OK and the interestingness test exited with zero."""
        return self.result == PassResult.OK and self.exitcode == 0

    def dump(self, dst):
        """Copy all test cases from the job's folder, plus the test script, into dst."""
        for test_case in self.all_test_cases:
            shutil.copy(self.folder / test_case, dst)

        shutil.copy(self.test_script, dst)

    def copy_test_cases(self):
        """Copy every test case into the job's folder."""
        for source in self.all_test_cases:
            fileutil.copy_test_case(source, self.folder)

    def run(self):
        """Apply the transform and, if it reported OK, run the interestingness test; always returns self."""
        try:
            # Only some passes need a copy of the files in the job's folder (e.g., hint-based passes don't).
            if self.should_copy_test_cases:
                self.copy_test_cases()

            # Let the pass transform the test case according to the state.
            written_paths: set[Path] = set()
            transform_result, new_state = self.transform(
                self.test_case_path,
                self.state,
                process_event_notifier=ProcessEventNotifier(self.pid_queue),
                original_test_case=self.test_case.resolve(),
                written_paths=written_paths,
            )
            self.state = new_state
            self.result = transform_result
            if transform_result != PassResult.OK:
                return self

            # Run the interestingness test.
            self.exitcode = self.run_test(False)

            # Cleanup is only worthwhile on success - otherwise the main process deletes the job's dir anyway.
            if self.exitcode == 0:
                fileutil.remove_extraneous_files(self.test_case_path, written_paths)
        except UnicodeDecodeError:
            # The pass is most likely incompatible with non-UTF files - terminate it.
            logging.debug('Skipping pass due to a unicode issue')
            self.result = PassResult.STOP
        except Exception:
            logging.exception('Unexpected TestEnvironment::run failure')
        return self

    def run_test(self, verbose):
        """Execute the test script inside the job's folder and return its exit code.

        With verbose set, the script's stdout/stderr are logged at debug level when it exits with non-zero.
        """
        with fileutil.chdir(self.folder):
            # Point the job at a private temp dir instead of the standard one, so that leftovers (caused by an
            # oversight in the interestingness test, or by C-Vise killing the job abruptly) don't clutter the
            # standard location.
            with tempfile.TemporaryDirectory(dir=self.folder, prefix='overridetmp') as private_tmp:
                job_env = override_tmpdir_env(os.environ.copy(), Path(private_tmp))
                notifier = ProcessEventNotifier(self.pid_queue)
                stdout, stderr, returncode = notifier.run_process(str(self.test_script), shell=True, env=job_env)
            if returncode != 0 and verbose:
                # Invalid UTF sequences are dropped.
                logging.debug('stdout:\n%s', stdout.decode('utf-8', 'ignore'))
                logging.debug('stderr:\n%s', stderr.decode('utf-8', 'ignore'))
        return returncode


@unique
class PassStage(Enum):
    """Lifecycle stage of a pass, as tracked in PassContext.stage."""

    # Initial stage of a freshly created context; the only stage in which the pass new() method can be scheduled.
    BEFORE_INIT = auto()
    # An init job is in flight (handle_finished_init_job() asserts this stage).
    IN_INIT = auto()
    # The init job has finished; transform jobs can be scheduled while the pass state is not None.
    ENUMERATING = auto()


@dataclass
class PassContext:
    """Runtime bookkeeping for a single currently active pass."""

    pass_: AbstractPass
    stage: PassStage
    # Whether the pass is enabled for the current test case.
    enabled: bool
    # Stores pass-specific files to be used during transform jobs (e.g., hints generated during initialization), and
    # temporary folders for each transform job.
    temporary_root: Path | None
    # The pass state as returned by the pass new()/advance()/advance_on_success() methods.
    state: Any
    # The state that succeeded in the previous batch of jobs - to be passed as succeeded_state to advance_on_success().
    taken_succeeded_state: Any
    # States that succeeded in the current batch of jobs.
    current_batch_succeeded_states: list[Any]
    # The overall number of jobs that have been started for the pass, throughout the whole run_passes() invocation.
    pass_job_counter: int
    # The value of pass_job_counter used for scheduling the most recent successful job.
    last_success_pass_job_counter: int
    # The value of pass_job_counter when the current batch (run_parallel_tests()) started.
    current_batch_start_job_counter: int
    # Currently running transform jobs, as the (order, state) mapping.
    running_transform_order_to_state: dict[int, Any]
    # When True, the pass is considered dysfunctional (due to an issue) and shouldn't be used anymore.
    defunct: bool
    # How many times a job for this pass timed out.
    timeout_count: int
    # Mapping from a hint type to a bundle path that the pass generated, for a hint-based pass.
    hint_bundle_paths: dict[bytes, Path]

    @staticmethod
    def create(pass_: AbstractPass) -> PassContext:
        """Build the context of a pass that hasn't been initialized yet: enabled, all counters at zero."""
        return PassContext(
            pass_=pass_,
            stage=PassStage.BEFORE_INIT,
            enabled=True,
            temporary_root=None,
            state=None,
            taken_succeeded_state=None,
            current_batch_succeeded_states=[],
            pass_job_counter=0,
            last_success_pass_job_counter=0,
            current_batch_start_job_counter=0,
            running_transform_order_to_state={},
            defunct=False,
            timeout_count=0,
            hint_bundle_paths={},
        )

    def jobs_in_current_batch(self) -> int:
        """Number of jobs started for the pass since the current batch began."""
        batch_start = self.current_batch_start_job_counter
        assert self.pass_job_counter >= batch_start
        return self.pass_job_counter - batch_start

    def jobs_since_last_success(self) -> int:
        """Number of jobs started for the pass after the most recent successful one was scheduled."""
        last_success = self.last_success_pass_job_counter
        assert self.pass_job_counter >= last_success
        return self.pass_job_counter - last_success

    def can_init_now(self, ready_hint_types: set[bytes]) -> bool:
        """Whether the pass new() method can be scheduled."""
        if not self.enabled or self.defunct or self.stage != PassStage.BEFORE_INIT:
            return False
        if not isinstance(self.pass_, HintBasedPass):
            return True
        # A hint-based pass has to wait until all dependee passes completed their initialization.
        return ready_hint_types.issuperset(self.pass_.input_hint_types())

    def can_transform_now(self) -> bool:
        """Whether the pass transform() method can be scheduled."""
        return self.enabled and not self.defunct and self.stage == PassStage.ENUMERATING and self.state is not None

    def can_start_job_now(self, ready_hint_types: set[bytes]) -> bool:
        """Whether any of the pass methods can be scheduled."""
        return self.can_init_now(ready_hint_types) or self.can_transform_now()

    def can_schedule_for_restart(self) -> bool:
        """Whether the restart of the pass could be scheduled.

        A restart reinitializes the pass and iterates through its states again; this helps in the interleaving mode,
        because reductions made by other passes after this pass finished enumerating might have unblocked new
        reduction possibilities for it.

        Passes that only produce "special" hints aren't worth restarting: such hints merely convey information to
        other passes and aren't enumerated as reduction attempts themselves.
        """
        if self.stage != PassStage.ENUMERATING or self.state is not None:
            return False
        return not all(is_special_hint_type(t) for t in self.hint_bundle_paths)

    def should_reinit_after_test_case_update(self, other_contexts: Sequence[PassContext]) -> bool:
        """Whether the pass should be initialized in the next round, after a reduction is applied to the test case."""
        # A pass in the middle of its enumeration must be reinitialized immediately in order to continue it.
        if self.stage == PassStage.ENUMERATING and self.state is not None:
            return True
        if not isinstance(self.pass_, HintBasedPass):
            return False
        # A pass is also reinitialized when it generates input data for another usable hint-based pass.
        produced = set(self.pass_.output_hint_types())
        return any(
            not produced.isdisjoint(set(other.pass_.input_hint_types()))
            for other in other_contexts
            if other.enabled and not other.defunct and isinstance(other.pass_, HintBasedPass)
        )


@unique
class JobType(Enum):
    """Kind of work a Job performs; decides which handler processes the finished job."""

    # Pass initialization; finished jobs are handled by TestManager.handle_finished_init_job().
    INIT = auto()
    # Transform plus interestingness test; finished jobs are handled by TestManager.handle_finished_transform_job().
    TRANSFORM = auto()


@dataclass
class Job:
    """A unit of work submitted to the worker pool, together with its scheduling metadata."""

    type: JobType
    # Future of the submitted task; polled for completion, cancellation and exceptions by TestManager.
    future: Future
    # Value of the TestManager's incremental job counter assigned to this job.
    order: int

    # If this job executes a method of a pass, these store pointers to it; None otherwise.
    pass_: AbstractPass | None
    # Index into TestManager.pass_contexts.
    pass_id: int | None
    # Name used in user-facing log messages (e.g., the timeout warning).
    pass_user_visible_name: str
    pass_job_counter: int | None

    # Compared against time.monotonic() when double-checking timeouts.
    start_time: float
    timeout: float
    # Folder owned by the job; deleted when the job is released unless temporary files are to be kept.
    temporary_folder: Path | None


@dataclass
class SuccessCandidate:
    """A successful transform result that may become the next version of the test case."""

    order: int
    pass_: AbstractPass | None
    pass_id: int | None
    pass_state: Any
    size_delta: int
    tmp_dir: Path | None = None
    test_case_path: Path | None = None

    def take_file_ownership(self, test_case_path: Path) -> None:
        """Move the given test case into a freshly created temporary dir owned by this candidate."""
        assert self.tmp_dir is None
        assert self.test_case_path is None
        owned_dir = Path(tempfile.mkdtemp(prefix=f'{TestManager.TEMP_PREFIX}candidate-'))
        self.tmp_dir = owned_dir
        self.test_case_path = owned_dir / test_case_path.name
        shutil.move(test_case_path, self.test_case_path)

    def release(self) -> None:
        """Delete the owned temporary dir, if any, and forget the stored paths."""
        owned_dir = self.tmp_dir
        if owned_dir is not None:
            rmfolder(owned_dir)
        self.tmp_dir = None
        self.test_case_path = None

    def better_than(self, other: SuccessCandidate) -> bool:
        """Whether this candidate is strictly preferable to the other one (smaller comparison key wins)."""
        return self._comparison_key() < other._comparison_key()

    def _comparison_key(self) -> tuple:
        """Sort key of the candidate; lower tuples denote better candidates."""
        # Folds are preferred over single-pass reductions: they transform the test case more diversely, and every
        # single-pass success eventually ends up as part of a fold anyway.
        fold_rank = 0 if isinstance(self.pass_state, FoldingStateOut) else 1
        # Bigger reductions are better; "nothing reduced" and "size grew" both count as zero, leaving the
        # disambiguation to the remaining criteria.
        negated_reduction = min(self.size_delta, 0)
        # The more "instances" (e.g., hints) the attempt took the better; states of some legacy passes lack this
        # information and count as a single instance.
        if hasattr(self.pass_state, 'real_chunk'):
            instances_taken = self.pass_state.real_chunk()
        else:
            instances_taken = 1
        return (fold_rank, negated_reduction, -instances_taken)


class TestManager:
    # A pass is declared stuck (and marked defunct) once more than this many of its jobs were started since its last
    # success, unless giving up is disabled.
    GIVEUP_CONSTANT = 50000
    # A pass is marked defunct once its jobs timed out this many times.
    MAX_TIMEOUTS = 20
    # Upper bound on the index of bug report dirs (BUG_DIR_PREFIX) looked for by get_extra_dir().
    MAX_CRASH_DIRS = 10
    # Upper bound on the index of "extra" dirs (EXTRA_DIR_PREFIX) looked for by get_extra_dir().
    MAX_EXTRA_DIRS = 25000
    # Prefix of the temporary directories created by C-Vise (note that rmfolder() requires 'cvise' in the path).
    TEMP_PREFIX = 'cvise-'
    BUG_DIR_PREFIX = 'cvise_bug_'
    EXTRA_DIR_PREFIX = 'cvise_extra_'
    # How often passes should be restarted (see maybe_schedule_job()). Chosen at 1% to not slow down the overall
    # reduction in case restarts don't lead to new discoveries.
    RESTART_JOB_INTERVAL = 100
    # Used for setting up timeouts on pass init jobs - the regular timeout is multiplied by this factor.
    INIT_TIMEOUT_FACTOR = 10

    def __init__(
        self,
        pass_statistic,
        test_script: Path,
        timeout,
        save_temps,
        test_cases: list[Path],
        parallel_tests,
        no_cache,
        skip_key_off,
        silent_pass_bug,
        die_on_pass_bug,
        print_diff,
        max_improvement,
        no_give_up,
        also_interesting,
        start_with_pass,
        skip_after_n_transforms,
        stopping_threshold,
    ):
        """Validate the inputs and set up the state shared by all reduction rounds.

        The configuration parameters are stored as same-named attributes. Every test case is validated before it's
        registered, and the interestingness test must exist and be executable.

        Raises:
            InvalidTestCaseError: a test case is missing, unreadable or unwritable.
            AbsolutePathTestCaseError: a test case is given with an absolute path.
            ScriptInsideTestCaseError: the test script is located inside a test case.
            InvalidInterestingnessTestError: the test script is missing or not executable.
        """
        self.test_script: Path = test_script.absolute()
        self.timeout = timeout
        self.save_temps = save_temps
        self.pass_statistic = pass_statistic
        self.test_cases: set[Path] = set()
        self.test_cases_modes: dict[Path, int] = {}
        self.parallel_tests = parallel_tests
        self.no_cache = no_cache
        self.skip_key_off = skip_key_off
        self.silent_pass_bug = silent_pass_bug
        self.die_on_pass_bug = die_on_pass_bug
        self.print_diff = print_diff
        self.max_improvement = max_improvement
        self.no_give_up = no_give_up
        self.also_interesting = also_interesting
        self.start_with_pass = start_with_pass
        self.skip_after_n_transforms = skip_after_n_transforms
        self.stopping_threshold = stopping_threshold
        self.exit_stack = contextlib.ExitStack()

        for test_case in test_cases:
            test_case = Path(test_case)
            # Validate before calling stat(): for a missing file stat() would raise FileNotFoundError and mask the
            # intended InvalidTestCaseError.
            self.check_file_permissions(test_case, [os.F_OK, os.R_OK, os.W_OK], InvalidTestCaseError)
            if test_case.parent.is_absolute():
                raise AbsolutePathTestCaseError(test_case)
            if test_case.resolve() in self.test_script.resolve().parents:
                raise ScriptInsideTestCaseError(test_case, self.test_script)
            # Remember the original mode so that restore_mode() can reapply it later.
            self.test_cases_modes[test_case] = test_case.stat().st_mode
            self.test_cases.add(test_case)

        self.orig_total_file_size = self.total_file_size
        self.cache = None if self.no_cache else cache.Cache(f'{self.TEMP_PREFIX}cache-')
        self.pass_contexts: list[PassContext] = []
        self.interleaving: bool = False
        if not self.is_valid_test(self.test_script):
            raise InvalidInterestingnessTestError(self.test_script)
        self.current_test_case: Path = Path()
        self.jobs: list[Job] = []
        # The "order" is an incremental counter for numbering jobs.
        self.order: int = 0
        # Remembers the "order" that the first job in the current batch (run_parallel_tests()) got.
        self.current_batch_start_order: int = 0
        # Identifies the most recent pass restart job (whether in the current batch or not).
        self.last_restart_job_order: int | None = None
        self.success_candidate: SuccessCandidate | None = None
        self.folding_manager: FoldingManager | None = None
        # Ids of passes that are eligible for the restart, in FIFO order.
        self.pass_restart_queue: list[int] = []

        # Colored diffs are only used on a terminal and when the colordiff tool can be executed.
        self.use_colordiff = (
            sys.stdout.isatty()
            and subprocess.run(
                'colordiff --version',
                shell=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            == 0
        )

        self.key_logger = None if self.skip_key_off else KeyLogger()
        self.mpmanager = multiprocessing.Manager()
        self.process_monitor = ProcessMonitor(self.mpmanager, self.parallel_tests)
        self.mplogger = mplogging.MPLogger(self.parallel_tests)
        # Created in __enter__() and reset in __exit__().
        self.worker_pool: pebble.ProcessPool | None = None
        self.mp_task_loss_workaround = MPTaskLossWorkaround(self.parallel_tests)

    def __enter__(self):
        """Enter all helper contexts and start the worker pool; returns self.

        Everything is registered on self.exit_stack, which __exit__() unwinds.
        """
        # The cache and the key logger are optional (None when disabled in the constructor).
        if self.cache:
            self.exit_stack.enter_context(self.cache)

        if self.key_logger:
            self.exit_stack.enter_context(self.key_logger)
        self.exit_stack.enter_context(self.process_monitor)

        # Callables handed over to _init_worker_process() via the pool's initargs.
        worker_initializers = [
            self.mplogger.worker_process_initializer(),
            self.mp_task_loss_workaround.initialize_in_worker,
        ]
        self.worker_pool = pebble.ProcessPool(
            max_workers=self.parallel_tests,
            initializer=_init_worker_process,
            initargs=[worker_initializers],
            context=MPContextHook(self.process_monitor),
        )

        self.exit_stack.enter_context(self.worker_pool)
        self.exit_stack.enter_context(self.mplogger)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the worker pool and unwind every context entered in __enter__()."""
        # The pool is stopped explicitly before the exit stack leaves its context.
        self.worker_pool.stop()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
        self.worker_pool = None

    def remove_roots(self):
        if self.save_temps:
            return
        for ctx in self.pass_contexts:
            if not ctx.temporary_root:
                continue
            rmfolder(ctx.temporary_root)
            ctx.temporary_root = None
        if self.success_candidate:
            self.success_candidate.release()
            self.success_candidate = None

    def restore_mode(self):
        for test_case in self.test_cases:
            test_case.chmod(self.test_cases_modes[test_case])

    @classmethod
    def is_valid_test(cls, test_script: Path):
        for mode in {os.F_OK, os.X_OK}:
            if not os.access(test_script, mode):
                return False
        return True

    @property
    def sorted_test_cases(self):
        return sorted(self.test_cases, key=lambda x: x.stat().st_size, reverse=True)

    @property
    def total_file_size(self) -> int:
        return sum(fileutil.get_file_size(p) for p in self.test_cases)

    @property
    def total_line_count(self) -> int:
        return sum(fileutil.get_line_count(p) for p in self.test_cases)

    @property
    def total_file_count(self) -> int:
        return sum(fileutil.get_file_count(p) for p in self.test_cases)

    @property
    def total_dir_count(self) -> int:
        return sum(fileutil.get_dir_count(p) for p in self.test_cases)

    def backup_test_cases(self):
        for f in self.test_cases:
            orig_file = Path(f'{f}.orig')

            if not orig_file.exists():
                # Copy file and preserve attributes
                shutil.copy2(f, orig_file)

    @staticmethod
    def check_file_permissions(path: Path, modes, error):
        for m in modes:
            if not os.access(path, m):
                if error is not None:
                    raise error(path, m)
                else:
                    return False

        return True

    @staticmethod
    def get_extra_dir(prefix, max_number) -> Path | None:
        extra_dir = None
        for i in range(0, max_number + 1):
            digits = int(round(math.log10(max_number), 0))
            extra_dir = Path(('{0}{1:0' + str(digits) + 'd}').format(prefix, i))

            if not extra_dir.exists():
                break

        # just bail if we've already created enough of these dirs, no need to
        # clutter things up even more...
        if not extra_dir or extra_dir.exists():
            return None

        return extra_dir

    def report_pass_bug(self, job: Job, problem: str):
        """Dump the job's files and a bug description into a fresh bug report directory.

        Returns False when no directory was created because no bug report directory name is available anymore, and
        True after the report was written. With die_on_pass_bug set, PassBugError is raised once the report is
        written instead of returning True.
        """

        if not self.die_on_pass_bug:
            logging.warning(f'{job.pass_} has encountered a non fatal bug: {problem}')

        crash_dir = self.get_extra_dir(self.BUG_DIR_PREFIX, self.MAX_CRASH_DIRS)

        if crash_dir is None:
            return False

        crash_dir.mkdir()
        # Save the test cases from the job's folder together with the test script.
        test_env: TestEnvironment = job.future.result()
        test_env.dump(crash_dir)

        if not self.die_on_pass_bug:
            logging.debug(
                f'Please consider tarring up {crash_dir} and creating an issue at https://github.com/marxin/cvise/issues and we will try to fix the bug.'
            )

        # Describe the tool versions, the system and the failing pass state alongside the dumped files.
        (crash_dir / 'PASS_BUG_INFO.TXT').write_text(
            f'Package: {CVise.Info.PACKAGE_STRING}\n'
            + f'Git version: {CVise.Info.GIT_VERSION}\n'
            + f'LLVM version: {CVise.Info.LLVM_VERSION}\n'
            + f'System: {str(platform.uname())}\n'
            + PassBugError.MSG.format(job.pass_, problem, test_env.state, crash_dir)
        )

        if self.die_on_pass_bug:
            raise PassBugError(job.pass_, problem, test_env.state, crash_dir)
        else:
            return True

    def diff_files(self, orig_test_case: Path, changed_test_case: Path) -> str:
        """Return the textual diff between two versions of a test case, colorized when colordiff is usable."""
        raw_diff = fileutil.diff_test_cases(orig_test_case, changed_test_case)

        if self.use_colordiff:
            try:
                raw_diff = subprocess.check_output('colordiff', input=raw_diff)
            except Exception as e:
                # Keep the plain diff if the colorizing failed.
                logging.warning('Failed to generate color diff: %s', e)

        # Invalid UTF sequences, if any, are dropped to keep the diff easy to log.
        return raw_diff.decode('utf-8', errors='ignore')

    def check_sanity(self):
        """Run the interestingness test on the unmodified test cases.

        Raises InsaneTestCaseError when the test exits with non-zero; in that case the temporary folder is kept if
        save_temps is set.
        """
        logging.debug('perform sanity check... ')

        sanity_dir = Path(tempfile.mkdtemp(prefix=f'{self.TEMP_PREFIX}sanity-'))
        first_test_case = list(self.test_cases)[0]
        test_env = TestEnvironment(
            state=None,
            order=0,
            test_script=self.test_script,
            folder=sanity_dir,
            test_case=first_test_case,
            all_test_cases=self.test_cases,
            should_copy_test_cases=True,
            transform=None,
        )
        logging.debug(f'sanity check tmpdir = {test_env.folder}')

        test_env.copy_test_cases()
        if test_env.run_test(verbose=True) != 0:
            if not self.save_temps:
                rmfolder(sanity_dir)
            raise InsaneTestCaseError(self.test_cases, self.test_script)

        rmfolder(sanity_dir)
        logging.debug('sanity check successful')

    @classmethod
    def log_key_event(cls, event):
        """Log the given event description at info level, framed by asterisks to make it stand out."""
        logging.info(f'****** {event} ******')

    def release_job(self, job: Job) -> None:
        if not self.save_temps and job.temporary_folder is not None:
            rmfolder(job.temporary_folder)
        self.jobs.remove(job)

    def release_all_jobs(self) -> None:
        while self.jobs:
            self.release_job(self.jobs[0])

    def save_extra_dir(self, test_case_path: Path):
        extra_dir = self.get_extra_dir(self.EXTRA_DIR_PREFIX, self.MAX_EXTRA_DIRS)
        if extra_dir is not None:
            try:
                os.mkdir(extra_dir)
                shutil.move(test_case_path, extra_dir)
            except OSError as e:
                logging.warning('Failed to create extra directory %s: %s', extra_dir, e)
                # Gracefully handle exceptions here - storing "extra" dirs is not critical for the reduction use case,
                # and an exception can occur simply due to child processes of the interestingness test creating/deleting
                # files in its work dir. Just make sure to delete the half-created extra dir.
                rmfolder(extra_dir)
            else:
                logging.info(f'Created extra directory {extra_dir} for you to look at later')

    def workaround_missing_timeouts(self) -> None:
        """Workaround for Pebble sometimes losing a task, with its future neither resolving nor timing out.

        To avoid hanging C-Vise waiting for such never-ending tasks, we double-check timeouts of each task ourselves
        and force-cancel violating tasks.
        """
        THRESHOLD = 10  # usually this factor is around 1, but durations can grow on a heavily loaded machine
        now = time.monotonic()
        for job in self.jobs:
            if not job.future.done() and now - job.start_time >= THRESHOLD * job.timeout:
                self.cancel_job(job)

    def process_done_futures(self) -> None:
        """Dispatch every finished job to the matching handler and release it afterwards.

        Cancelled and timed-out jobs go to handle_timed_out_job(); any other exception stored in a future is
        re-raised; the remaining jobs are routed by their type. Raises ValueError for an unknown job type.
        """
        # Releasing is postponed until after the loop, since release_job() removes items from self.jobs.
        jobs_to_remove = []
        for job in self.jobs:
            if not job.future.done():
                continue
            jobs_to_remove.append(job)
            if job.future.cancelled():
                # Within the task loop, we only cancel jobs in workaround_missing_timeouts().
                self.handle_timed_out_job(job)
                continue
            if exc := job.future.exception():
                # starting with Python 3.11: concurrent.futures.TimeoutError == TimeoutError
                if type(exc) in (TimeoutError, concurrent.futures.TimeoutError):
                    self.handle_timed_out_job(job)
                    continue
                # Note that in this case the finished jobs collected so far aren't released by this method.
                raise exc
            if job.type == JobType.INIT:
                self.handle_finished_init_job(job)
            elif job.type == JobType.TRANSFORM:
                self.handle_finished_transform_job(job)
            else:
                raise ValueError(f'Unexpected job type {job.type}')

        for job in jobs_to_remove:
            self.release_job(job)

    def handle_timed_out_job(self, job: Job) -> None:
        logging.warning('Test timed out for %s.', job.pass_user_visible_name)
        if job.temporary_folder:
            self.save_extra_dir(job.temporary_folder)
        if job.pass_id is None:
            # The logic of disabling a pass after repeated timeouts isn't applicable to folding jobs.
            return
        ctx = self.pass_contexts[job.pass_id]
        ctx.timeout_count += 1
        if job.type == JobType.TRANSFORM:
            ctx.running_transform_order_to_state.pop(job.order)
        if ctx.timeout_count < self.MAX_TIMEOUTS or ctx.defunct:
            return
        logging.warning(
            'Maximum number of timeout were reached for %s: %s', job.pass_user_visible_name, self.MAX_TIMEOUTS
        )
        ctx.defunct = True

    def handle_finished_init_job(self, job: Job) -> None:
        """Store the outcome of a finished init job in the pass context and switch the pass to enumeration."""
        assert job.pass_id is not None
        ctx: PassContext = self.pass_contexts[job.pass_id]
        assert ctx.stage == PassStage.IN_INIT
        ctx.stage = PassStage.ENUMERATING
        ctx.state = job.future.result()

        # The job's folder is handed over to the pass context for future transform jobs; the folder of the previous
        # initialization, if any, is cleaned up first.
        previous_root = ctx.temporary_root
        if previous_root is not None:
            rmfolder(previous_root)
        ctx.temporary_root, job.temporary_folder = job.temporary_folder, None

        self.pass_statistic.add_initialized(job.pass_, job.start_time)
        if isinstance(ctx.pass_, HintBasedPass):
            ctx.hint_bundle_paths = ctx.state.hint_bundle_paths() if ctx.state is not None else {}

    def handle_finished_transform_job(self, job: Job) -> None:
        """Update the statistics, the pass context and the success candidate after a transform job finished."""
        env: TestEnvironment = job.future.result()
        self.pass_statistic.add_executed(job.pass_, job.start_time, self.parallel_tests)

        # Jobs without a pass id (e.g., folding jobs) have no pass context.
        ctx = self.pass_contexts[job.pass_id] if job.pass_id is not None else None
        if ctx:
            ctx.running_transform_order_to_state.pop(job.order)

        outcome = self.check_pass_result(job)
        if outcome == PassCheckingOutcome.STOP:
            assert ctx is not None
            # Resetting the state ends the enumeration of the pass.
            ctx.state = None
            return
        if outcome == PassCheckingOutcome.IGNORE:
            self.pass_statistic.add_failure(job.pass_)
            if self.interleaving:
                self.folding_manager.on_transform_job_failure(env.state)
            return
        assert outcome == PassCheckingOutcome.ACCEPT
        self.pass_statistic.add_success(job.pass_)
        self.maybe_update_success_candidate(job.order, job.pass_, job.pass_id, env)
        if self.interleaving:
            self.folding_manager.on_transform_job_success(env.state)
        if ctx:
            ctx.current_batch_succeeded_states.append(env.state)
            assert job.pass_job_counter is not None
            # Jobs may finish out of order, hence the max().
            ctx.last_success_pass_job_counter = max(ctx.last_success_pass_job_counter, job.pass_job_counter)

    def check_pass_result(self, job: Job):
        """Classify the result of a finished transform job as a PassCheckingOutcome.

        ACCEPT is returned for a successful job that changed the test case (and didn't exceed max_improvement);
        STOP when the pass asked to stop, hit a reported error, or got stuck; IGNORE otherwise.
        """
        test_env: TestEnvironment = job.future.result()
        if test_env.success:
            if self.max_improvement is not None and test_env.size_improvement > self.max_improvement:
                logging.debug(f'Too large improvement: {test_env.size_improvement} B')
                return PassCheckingOutcome.IGNORE
            # Report bug if transform did not change the file
            if filecmp.cmp(self.current_test_case, test_env.test_case_path):
                if not self.silent_pass_bug:
                    # The pass is only stopped when the bug report couldn't be created.
                    if not self.report_pass_bug(job, 'pass failed to modify the variant'):
                        return PassCheckingOutcome.STOP
                return PassCheckingOutcome.IGNORE
            return PassCheckingOutcome.ACCEPT

        if test_env.result == PassResult.OK:
            # NOTE(review): TestEnvironment.run() swallows unexpected exceptions, so when the interestingness test
            # run itself raised, result is OK while exitcode is still None and this assert fires - confirm whether
            # crashing is intended in that case.
            assert test_env.exitcode
            if self.also_interesting is not None and test_env.exitcode == self.also_interesting:
                self.save_extra_dir(test_env.test_case_path)
        elif test_env.result == PassResult.STOP:
            return PassCheckingOutcome.STOP
        elif test_env.result == PassResult.ERROR:
            if not self.silent_pass_bug:
                self.report_pass_bug(job, 'pass error')
                return PassCheckingOutcome.STOP

        # Give up on a pass that has been running for too many jobs without a single success.
        if not self.no_give_up and job.pass_id is not None:
            ctx = self.pass_contexts[job.pass_id]
            if not ctx.defunct and ctx.jobs_since_last_success() > self.GIVEUP_CONSTANT:
                self.report_pass_bug(job, 'pass got stuck')
                ctx.defunct = True
                return PassCheckingOutcome.STOP
        return PassCheckingOutcome.IGNORE

    def maybe_update_success_candidate(
        self, order: int, pass_: AbstractPass | None, pass_id: int | None, env: TestEnvironment
    ) -> None:
        """Remember the given successful job as the success candidate, unless the current one is at least as good.

        When the new candidate wins, the previous one (if any) is released and the new one takes ownership of the
        job's resulting test case file.
        """
        assert env.success
        challenger = SuccessCandidate(
            order=order,
            pass_=pass_,
            pass_id=pass_id,
            pass_state=env.state,
            size_delta=-env.size_improvement,
        )
        incumbent = self.success_candidate
        if incumbent:
            if not challenger.better_than(incumbent):
                return
            # Make sure to clean up old temporary files.
            incumbent.release()
        challenger.take_file_ownership(env.test_case_path)
        self.success_candidate = challenger

    def terminate_all(self) -> None:
        """Cancel every tracked job, stop the worker pool and release all jobs."""
        for pending in self.jobs:
            self.cancel_job(pending)
        # Stopping the pool will also stop tasks not tracked in self.jobs.
        pool = self.worker_pool
        pool.stop()
        self.release_all_jobs()

    def cancel_job(self, job: Job) -> None:
        """Cancel the job's future, after telling the log collector to ignore logs from this job."""
        job_order = job.order
        self.mplogger.ignore_logs_from_job(job_order)
        pending_future = job.future
        pending_future.cancel()

    def run_parallel_tests(self) -> None:
        """Run one batch of jobs against the current test case.

        The batch keeps scheduling and processing jobs until either no jobs are running and no pass can start one, or
        a success candidate was found and should_proceed_with_success_candidate() says to stop looking for better
        ones. Jobs still running at that point are canceled and released.

        Expects no jobs to be tracked and no success candidate to be present on entry.
        """
        assert self.worker_pool
        assert not self.jobs
        self.current_batch_start_order = self.order
        assert self.success_candidate is None
        if self.interleaving:
            self.folding_manager = FoldingManager()

        for pass_id, ctx in enumerate(self.pass_contexts):
            # Clean up the information about previously running jobs.
            ctx.running_transform_order_to_state = {}
            ctx.current_batch_succeeded_states = []
            ctx.current_batch_start_job_counter = ctx.pass_job_counter
            # Unfinished initializations from the last run will need to be restarted.
            if ctx.stage == PassStage.IN_INIT:
                ctx.stage = PassStage.BEFORE_INIT
            # Previously finished passes are eligible for restart (used for "interleaving" mode only - in the old
            # single-pass mode we're expected to return to let subsequent passes work).
            if self.interleaving and pass_id not in self.pass_restart_queue and ctx.can_schedule_for_restart():
                self.pass_restart_queue.append(pass_id)

        while True:
            if not self.jobs:
                # The set of ready hint types depends on the passes' stages, which change as init jobs get scheduled
                # and finish - so it has to be recomputed here rather than taken once before the loop. A stale
                # snapshot could end the batch while a dependent pass has just become initializable.
                ready_hint_types = self.get_fully_initialized_hint_types()
                if not any(c.can_start_job_now(ready_hint_types) for c in self.pass_contexts):
                    break

            sigmonitor.maybe_retrigger_action()

            # schedule new jobs, as long as there are free workers
            while len(self.jobs) < self.parallel_tests and self.maybe_schedule_job():
                pass

            # no more jobs could be scheduled at the moment - wait for some results
            wait([j.future for j in self.jobs] + [sigmonitor.get_future()], return_when=FIRST_COMPLETED)
            sigmonitor.maybe_retrigger_action()

            self.workaround_missing_timeouts()
            self.process_done_futures()

            # exit if we found successful transformation(s) and don't want to try better ones
            if self.success_candidate and self.should_proceed_with_success_candidate():
                break

        if self.jobs:
            for job in self.jobs:
                self.cancel_job(job)
            self.mp_task_loss_workaround.execute(self.worker_pool)  # only do it if at least one job canceled
            self.release_all_jobs()

    def run_passes(self, passes: Sequence[AbstractPass], interleaving: bool):
        """Run the given passes over every test case until they stop making progress.

        Args:
            passes: the passes to run; subordinate passes they request are added automatically. More than one pass
                (after adding subordinates) is only allowed when interleaving is True.
            interleaving: whether the passes are executed in the "interleaving" mode.

        Raises:
            ZeroSizeError: if the total size of the test cases is zero.

        On KeyboardInterrupt/SystemExit all jobs are terminated, temporary roots are removed and the process exits
        with status 1.
        """
        # Automatically add subordinate (helper) passes where needed. Prioritize them early in the list, so that they
        # get executed as early as possible and unblock their main pass' execution.
        extra_passes = []
        for p in passes:
            extra_passes += p.create_subordinate_passes()
        passes = extra_passes + list(passes)

        assert len(passes) == 1 or interleaving

        if self.start_with_pass:
            # Compare against the passes we're about to run. self.pass_contexts still describes the previous
            # invocation at this point (it's only rebuilt below), so consulting it would skip the very invocation
            # that contains the requested pass.
            current_pass_names = [str(p) for p in passes]
            if self.start_with_pass in current_pass_names:
                self.start_with_pass = None
            else:
                return

        self.order = 1
        self.last_restart_job_order = None
        self.pass_restart_queue = []
        self.pass_contexts = []
        for pass_ in passes:
            self.pass_contexts.append(PassContext.create(pass_))
        self.interleaving = interleaving
        self.jobs = []

        pass_titles = [c.pass_.user_visible_name() for c in self.pass_contexts]
        pass_titles_str = ', '.join(sorted(set(pass_titles)))
        logging.info(f'===< {pass_titles_str} >===')

        if self.total_file_size == 0:
            raise ZeroSizeError(self.test_cases)

        try:
            for test_case in self.sorted_test_cases:
                self.current_test_case = test_case
                starting_test_case_size = fileutil.get_file_size(test_case)
                success_count = 0

                if starting_test_case_size == 0:
                    continue

                if not self.no_cache:
                    hash_before_pass = fileutil.hash_test_case(test_case)
                    if cached_path := self.cache.lookup(passes, hash_before_pass):
                        fileutil.replace_test_case_atomically(cached_path, test_case, move=False)
                        logging.info(f'cache hit for {test_case}')
                        continue
                else:
                    hash_before_pass = None

                is_dir = test_case.is_dir()
                for ctx in self.pass_contexts:
                    ctx.enabled = not is_dir or ctx.pass_.supports_dir_test_cases()

                self.skip = False
                while not self.skip:
                    # Pass stages change from batch to batch, so the ready hint types must be re-queried every time
                    # instead of reusing the value from before the first batch.
                    ready_hint_types = self.get_fully_initialized_hint_types()
                    if not any(c.can_start_job_now(ready_hint_types) for c in self.pass_contexts):
                        break

                    # Ignore more key presses after skip has been detected
                    if not self.skip_key_off and not self.skip:
                        key = self.key_logger.pressed_key()
                        if key == 's':
                            self.skip = True
                            self.log_key_event('skipping the rest of this pass')
                        elif key == 'd':
                            self.log_key_event('toggle print diff')
                            self.print_diff = not self.print_diff

                    self.run_parallel_tests()

                    is_success = self.success_candidate is not None
                    if is_success:
                        self.process_result()
                        success_count += 1

                    # if the file increases significantly, bail out the current pass
                    test_case_size = fileutil.get_file_size(self.current_test_case)
                    if test_case_size >= MAX_PASS_INCREASEMENT_THRESHOLD * starting_test_case_size:
                        logging.info(
                            f'skipping the rest of the pass (huge file increasement '
                            f'{MAX_PASS_INCREASEMENT_THRESHOLD * 100}%)'
                        )
                        break

                    if not is_success:
                        break

                    # skip after N transformations if requested
                    skip_rest = self.skip_after_n_transforms and success_count >= self.skip_after_n_transforms
                    if not self.interleaving:  # max-transforms is only supported for non-interleaving passes
                        assert len(self.pass_contexts) == 1
                        if (
                            self.pass_contexts[0].pass_.max_transforms
                            and success_count >= self.pass_contexts[0].pass_.max_transforms
                        ):
                            skip_rest = True
                    if skip_rest:
                        logging.info(f'skipping after {success_count} successful transformations')
                        break

                if not self.no_cache:
                    assert hash_before_pass is not None
                    self.cache.add(passes, hash_before_pass, test_case)

            self.restore_mode()
            self.remove_roots()
        except (KeyboardInterrupt, SystemExit):
            logging.info('Exiting now ...')
            # Clean temporary files for all jobs and passes.
            self.terminate_all()
            self.remove_roots()
            sys.exit(1)

    def process_result(self) -> None:
        """Commit the current success candidate as the new test case and update the state of all passes.

        Replaces the current test case with the candidate's file, records the committed size change in the
        statistics, rewinds or advances each pass's enumeration state, decides which passes need reinitialization,
        releases the candidate and logs the new test case metrics.

        Raises:
            RuntimeError: if replacing the test case fails with FileNotFoundError.
        """
        assert self.success_candidate
        new_test_case = self.success_candidate.test_case_path
        assert new_test_case
        if self.print_diff:
            logging.info('%s', self.diff_files(self.current_test_case, new_test_case))

        try:
            fileutil.replace_test_case_atomically(new_test_case, self.current_test_case)
        except FileNotFoundError:
            raise RuntimeError(
                f"Can't find {self.current_test_case} -- did your interestingness test move it?"
            ) from None

        # Update global stats.
        if isinstance(self.success_candidate.pass_state, FoldingStateOut):
            # A fold has no single pass: record the total under None, plus the per-pass breakdown.
            self.pass_statistic.add_committed_success(None, self.success_candidate.size_delta)
            for (
                pass_user_visible_name,
                size_delta,
            ) in self.success_candidate.pass_state.size_delta_per_pass.items():
                self.pass_statistic.add_committed_success(pass_user_visible_name, size_delta)
        else:
            self.pass_statistic.add_committed_success(
                self.success_candidate.pass_.user_visible_name(), self.success_candidate.size_delta
            )

        for pass_id, ctx in enumerate(self.pass_contexts):
            # If there's an earlier state whose check hasn't completed - rewind to this state.
            rewind_to = (
                min(ctx.running_transform_order_to_state.keys()) if ctx.running_transform_order_to_state else None
            )
            # The only exception is when the earliest job is the one that succeeded - in that case take the state that
            # its transform() returned.
            if self.success_candidate.pass_id == pass_id and (
                rewind_to is None or self.success_candidate.order <= rewind_to
            ):
                ctx.state = self.success_candidate.pass_state
            elif rewind_to is not None:
                ctx.state = ctx.running_transform_order_to_state[rewind_to]
            ctx.running_transform_order_to_state = {}

            # Also explicitly remember the state that succeeded - advance_on_success() expects it as a separate argument.
            ctx.taken_succeeded_state = (
                self.success_candidate.pass_state if pass_id == self.success_candidate.pass_id else None
            )

            # Decide which passes to reinit in the next round.
            if ctx.should_reinit_after_test_case_update([c for c in self.pass_contexts if c != ctx]):
                ctx.stage = PassStage.BEFORE_INIT

        # Only mention the responsible pass(es) when several passes are running together.
        if len(self.pass_contexts) > 1:
            if isinstance(self.success_candidate.pass_state, FoldingStateOut):
                pass_name = ' + '.join(self.success_candidate.pass_state.passes_ordered_by_delta)
            else:
                pass_name = self.success_candidate.pass_.user_visible_name()
            log_note = f'via {pass_name}'
        else:
            # None (not an empty string) means "no extra note"; an empty string would be logged as an empty item.
            log_note = None

        self.success_candidate.release()
        self.success_candidate = None

        self.log_test_case_metrics(log_note)

    def log_test_case_metrics(self, extra_note: str | None = None) -> None:
        """Log a one-line, parenthesized summary of the current test case metrics at INFO level.

        The summary lists the reduction percentage relative to the original total size, the total size in bytes, the
        line count (when non-zero), the current test case name (when there are several test cases), file/dir counts
        (when any test case is a directory), and finally the extra note.

        Args:
            extra_note: optional text appended as the last item; None or an empty string adds nothing.
        """
        pct = 100 - (float(self.total_file_size) * 100.0 / self.orig_total_file_size)
        notes = []
        notes.append(f'{round(pct, 1)}%')
        notes.append(f'{self.total_file_size} bytes')
        if self.total_line_count:
            notes.append(f'{self.total_line_count} lines')
        if len(self.test_cases) > 1 and self.current_test_case:
            notes.append(str(self.current_test_case.name))
        if any(p.is_dir() for p in self.test_cases):
            files = self.total_file_count
            dirs = self.total_dir_count
            notes.append(f'{files} file{"s" if files > 1 else ""} in {dirs} dir{"s" if dirs > 1 else ""}')
        # Skip empty notes too: appending '' would leave a dangling ", " before the closing parenthesis.
        if extra_note:
            notes.append(extra_note)
        logging.info('(' + ', '.join(notes) + ')')

    def should_proceed_with_success_candidate(self):
        """Tell whether the batch should stop and commit the current success candidate.

        Outside the interleaving mode the first candidate is always taken; in the interleaving mode the decision is
        delegated to the folding manager, which may ask to keep attempting folds.
        """
        assert self.success_candidate
        if self.interleaving:
            jobs_in_batch = self.order - self.current_batch_start_order
            keep_folding = self.folding_manager.continue_attempting_folds(
                jobs_in_batch, self.parallel_tests, len(self.pass_contexts)
            )
            return not keep_folding
        return True

    def maybe_schedule_job(self) -> bool:
        """Schedule at most one new job; return True if a job was scheduled and False otherwise.

        Job types are tried in priority order (highest first): regular pass initialization, restart of a finished
        pass, a fold attempt, and finally a regular transformation.
        """
        ready_hint_types = self.get_fully_initialized_hint_types()

        # 1. Initializing a pass regularly (at the beginning of the batch of jobs).
        for init_id, init_ctx in enumerate(self.pass_contexts):
            if init_ctx.can_init_now(ready_hint_types):
                self.schedule_init(init_id, ready_hint_types)
                return True

        # 2. Restarting a previously finished pass.
        # We throttle restarts (only once out of RESTART_JOB_INTERVAL jobs) because they're only occasionally useful:
        # for an unused code removal pass it's possible that more unused code after other passes made some deletions,
        # meanwhile for a comment removal pass there's nothing more to discover after all comments have been removed.
        # We use a FIFO queue, spanning across multiple job batches, to avoid repeatedly restarting some passes and
        # never getting to others due to throttling.
        if self.pass_restart_queue:
            last_restart = self.last_restart_job_order
            if last_restart is None or self.order - last_restart >= self.RESTART_JOB_INTERVAL:
                restart_id = self.pass_restart_queue.pop(0)
                restart_ctx = self.pass_contexts[restart_id]
                assert restart_ctx.stage == PassStage.ENUMERATING
                assert restart_ctx.state is None
                restart_ctx.stage = PassStage.BEFORE_INIT
                if restart_ctx.can_init_now(ready_hint_types):
                    self.last_restart_job_order = self.order
                    self.schedule_init(restart_id, ready_hint_types)
                    return True

        # 3. Attempting a fold (simultaneous application) of previously discovered successful transformations; only
        # supported in the "interleaving" pass execution mode.
        if self.interleaving:
            best_state = self.success_candidate.pass_state if self.success_candidate else None
            folding_state = self.folding_manager.maybe_prepare_folding_job(
                self.order - self.current_batch_start_order, best_state
            )
            if folding_state:
                self.schedule_fold(folding_state)
                return True

        # 4. Attempting a transformation using the next heuristic in the round-robin fashion: among the passes that
        # can transform now, pick the one with the fewest jobs in the current batch (the earliest one on ties).
        chosen_id = None
        for cand_id, cand_ctx in enumerate(self.pass_contexts):
            if not cand_ctx.can_transform_now():
                continue
            self.advance_while_subset_of_succeeded(cand_ctx)
            if not cand_ctx.can_transform_now():
                continue
            if (
                chosen_id is None
                or self.pass_contexts[chosen_id].jobs_in_current_batch() > cand_ctx.jobs_in_current_batch()
            ):
                chosen_id = cand_id
        if chosen_id is None:
            return False
        self.schedule_transform(chosen_id)
        return True

    def schedule_init(self, pass_id: int, ready_hint_types: set[bytes]) -> None:
        """Schedule an initialization job for the given pass and move the pass into the IN_INIT stage.

        If the pass has no state, its new() method is scheduled; otherwise advance_on_success() is scheduled with
        the previous state. The job gets a fresh temporary directory and a timeout of INIT_TIMEOUT_FACTOR times the
        regular timeout.
        """
        ctx = self.pass_contexts[pass_id]
        assert ctx.can_init_now(ready_hint_types)

        # Collect hint bundles, produced by already enumerating hint-based passes, that this pass takes as input.
        if isinstance(ctx.pass_, HintBasedPass):
            dependee_types = set(ctx.pass_.input_hint_types())
        else:
            dependee_types = set()
        dependee_bundle_paths = []
        for other in self.pass_contexts:
            if not isinstance(other.pass_, HintBasedPass) or other.stage != PassStage.ENUMERATING:
                continue
            dependee_bundle_paths.extend(
                bundle for hint_type, bundle in other.hint_bundle_paths.items() if hint_type in dependee_types
            )

        dir_prefix = f'{TestManager.TEMP_PREFIX}{fileutil.sanitize_for_file_name(str(ctx.pass_))}-'
        tmp_dir = Path(tempfile.mkdtemp(prefix=dir_prefix))
        logging.debug(f'Creating pass root folder: {tmp_dir}')

        # Either advance from the previous state, or initialize the pass from scratch.
        if ctx.state is not None:
            env = AdvanceOnSuccessEnvironment(
                pass_advance_on_success=ctx.pass_.advance_on_success,
                test_case=self.current_test_case,
                pass_previous_state=ctx.state,
                new_tmp_dir=tmp_dir,
                pass_succeeded_state=ctx.taken_succeeded_state,
                job_timeout=self.timeout,
                pid_queue=self.process_monitor.pid_queue,
                dependee_bundle_paths=dependee_bundle_paths,
            )
        else:
            env = InitEnvironment(
                pass_new=ctx.pass_.new,
                test_case=self.current_test_case,
                tmp_dir=tmp_dir,
                job_timeout=self.timeout,
                pid_queue=self.process_monitor.pid_queue,
                dependee_bundle_paths=dependee_bundle_paths,
            )

        job_order = self.order
        init_timeout = self.INIT_TIMEOUT_FACTOR * self.timeout
        future = self.worker_pool.schedule(
            _worker_process_job_wrapper, args=[job_order, env.run], timeout=init_timeout
        )
        init_job = Job(
            type=JobType.INIT,
            future=future,
            order=job_order,
            pass_=ctx.pass_,
            pass_id=pass_id,
            pass_user_visible_name=ctx.pass_.user_visible_name(),
            pass_job_counter=ctx.pass_job_counter,
            start_time=time.monotonic(),
            timeout=init_timeout,
            temporary_folder=tmp_dir,
        )
        self.jobs.append(init_job)

        ctx.pass_job_counter += 1
        ctx.stage = PassStage.IN_INIT
        self.order += 1

    def schedule_transform(self, pass_id: int) -> None:
        """Schedule a transform job for the pass's current state, then advance the pass to its next state.

        The job runs in a fresh folder under the pass's temporary root; the state it was started from is remembered
        under the job's order in running_transform_order_to_state.
        """
        ctx = self.pass_contexts[pass_id]
        assert ctx.can_transform_now()
        assert ctx.state is not None
        assert ctx.temporary_root is not None

        job_order = self.order
        job_state = ctx.state

        # Whether we should copy input files to the temporary work directory, or the pass does it itself. For now, we
        # simply hardcode that hint-based passes are capable of this (and they actually need the original files anyway).
        should_copy_test_cases = not isinstance(ctx.pass_, HintBasedPass)

        work_dir = Path(tempfile.mkdtemp(prefix=self.TEMP_PREFIX, dir=ctx.temporary_root))
        env = TestEnvironment(
            job_state,
            job_order,
            self.test_script,
            work_dir,
            self.current_test_case,
            self.test_cases,
            should_copy_test_cases,
            ctx.pass_.transform,
            self.process_monitor.pid_queue,
        )
        future = self.worker_pool.schedule(
            _worker_process_job_wrapper, args=[job_order, env.run], timeout=self.timeout
        )
        transform_job = Job(
            type=JobType.TRANSFORM,
            future=future,
            order=job_order,
            pass_=ctx.pass_,
            pass_id=pass_id,
            pass_user_visible_name=ctx.pass_.user_visible_name(),
            pass_job_counter=ctx.pass_job_counter,
            start_time=time.monotonic(),
            timeout=self.timeout,
            temporary_folder=work_dir,
        )
        self.jobs.append(transform_job)

        assert job_order not in ctx.running_transform_order_to_state
        ctx.running_transform_order_to_state[job_order] = job_state

        ctx.pass_job_counter += 1
        self.order += 1
        # Move on to the next state, so that the next job of this pass tries a different transformation.
        ctx.state = ctx.pass_.advance(self.current_test_case, job_state)

    def schedule_fold(self, folding_state: FoldingStateIn) -> None:
        """Schedule a job that runs FoldingManager.transform on the given folding state.

        The job is tracked as a TRANSFORM job that isn't attributed to any pass (pass_ and pass_id are None). Only
        valid in the interleaving mode.
        """
        assert self.interleaving

        job_order = self.order
        work_dir = Path(tempfile.mkdtemp(prefix=self.TEMP_PREFIX + 'folding-'))
        env = TestEnvironment(
            folding_state,
            job_order,
            self.test_script,
            work_dir,
            self.current_test_case,
            self.test_cases,
            False,  # should_copy_test_cases: the fold transform creates the files itself
            FoldingManager.transform,
            self.process_monitor.pid_queue,
        )
        future = self.worker_pool.schedule(
            _worker_process_job_wrapper, args=[job_order, env.run], timeout=self.timeout
        )
        fold_job = Job(
            type=JobType.TRANSFORM,
            future=future,
            order=job_order,
            pass_=None,
            pass_id=None,
            pass_user_visible_name='Folding',
            pass_job_counter=None,
            start_time=time.monotonic(),
            timeout=self.timeout,
            temporary_folder=work_dir,
        )
        self.jobs.append(fold_job)

        self.order += 1

    def advance_while_subset_of_succeeded(self, ctx: PassContext) -> None:
        """Advance the pass past states that are subsets of states which already succeeded in the current batch.

        Only applies when the pass's state is a HintState; the loop stops when the state becomes None or is no
        longer a subset of any succeeded state.
        """
        if not isinstance(ctx.state, HintState):
            return
        while ctx.state is not None:
            if not any(ctx.state.subset_of(done) for done in ctx.current_batch_succeeded_states):
                break
            ctx.state = ctx.pass_.advance(self.current_test_case, ctx.state)

    def get_fully_initialized_hint_types(self) -> set[bytes]:
        """Return the hint types whose every producing hint-based pass is in the ENUMERATING stage.

        A type that is output by at least one hint-based pass in any other stage is excluded.
        """
        produced_by_ready = set()
        produced_by_pending = set()
        for ctx in self.pass_contexts:
            if isinstance(ctx.pass_, HintBasedPass):
                output_types = ctx.pass_.output_hint_types()
                if ctx.stage == PassStage.ENUMERATING:
                    produced_by_ready.update(output_types)
                else:
                    produced_by_pending.update(output_types)
        return produced_by_ready - produced_by_pending


def override_tmpdir_env(old_env: Mapping[str, str], tmp_override: Path) -> Mapping[str, str]:
    """Return a copy of old_env with TMPDIR, TEMP and TMP all set to tmp_override; old_env is left untouched."""
    override_value = str(tmp_override)
    return {**old_env, 'TMPDIR': override_value, 'TEMP': override_value, 'TMP': override_value}


def _init_worker_process(initializers: list[Callable]) -> None:
    """Switch signal handling to the quick-exit mode, then run the given initializers in order."""
    # By default (when not executing a job), terminate a worker immediately on relevant signals. Raising an exception at
    # unexpected times, especially inside multiprocessing internals, can put the worker into a bad state.
    sigmonitor.init(sigmonitor.Mode.QUICK_EXIT)
    for initializer in initializers:
        initializer()


def _worker_process_job_wrapper(job_order: int, func: Callable) -> Any:
    """Run func() with job-scoped signal handling and log annotation, returning its result."""
    # Handle signals as exceptions within the job, to let the code do proper resource deallocation (like terminating
    # subprocesses), but once the func returns after a signal was triggered, terminate the worker.
    signals_as_exceptions = sigmonitor.scoped_mode(sigmonitor.Mode.RAISE_EXCEPTION)
    with signals_as_exceptions:
        # Annotate each log message with the job order, for the log recipient in the main process to discard logs coming
        # from canceled jobs.
        with mplogging.worker_process_job_wrapper(job_order):
            result = func()
            return result
