# Source code for crosshair.tracers

"""Provide access to and overrides for functions as they are called."""

import ctypes
import dataclasses
import dis
import os
import sys
import types
from collections import defaultdict
from sys import _getframe
from types import CodeType
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

from _crosshair_tracers import (  # type: ignore
    CTracer,
    TraceSwap,
    call_stack_info,
    normalize_call_target,
    supported_opcodes,
)

# Opt-in switch for extra internal consistency checks; true only when the
# CROSSHAIR_EXTRA_ASSERTS environment variable is exactly "1".
CROSSHAIR_EXTRA_ASSERTS = os.environ.get("CROSSHAIR_EXTRA_ASSERTS", "0") == "1"

# Tool id that CompositeTracer claims through sys.monitoring.use_tool_id().
SYS_MONITORING_TOOL_ID = 4
# NOTE(review): not referenced in the portion of the module reviewed here;
# confirm usage elsewhere before changing or removing.
USE_C_TRACER = True

# ctypes handles: a pointer-to-PyObject type and CPython's refcount functions.
PyObjPtr = ctypes.POINTER(ctypes.py_object)
Py_IncRef = ctypes.pythonapi.Py_IncRef
Py_DecRef = ctypes.pythonapi.Py_DecRef


# Extra (name, ctype) field pairs that are only present when sys.flags.debug
# is set; empty otherwise.
# NOTE(review): presumably prepended to ctypes struct layouts defined
# elsewhere -- no consumer is visible in the code reviewed here.
_debug_header: Tuple[Tuple[str, type], ...] = (
    (
        ("_ob_next", PyObjPtr),
        ("_ob_prev", PyObjPtr),
    )
    if sys.flags.debug
    else ()
)


from _crosshair_tracers import frame_stack_read, frame_stack_write

# Numeric opcodes for the call-like instructions. Which of these exist depends
# on the running interpreter version; a missing opcode falls back to 256.
# NOTE(review): 256 is presumably chosen as a placeholder that cannot collide
# with a real one-byte opcode -- confirm against pseudo-opcode numbering.
CALL_FUNCTION = dis.opmap.get("CALL_FUNCTION", 256)
CALL_FUNCTION_KW = dis.opmap.get("CALL_FUNCTION_KW", 256)  # Removed as of 3.11
CALL_FUNCTION_EX = dis.opmap.get("CALL_FUNCTION_EX", 256)
CALL_METHOD = dis.opmap.get("CALL_METHOD", 256)
BUILD_TUPLE_UNPACK_WITH_CALL = dis.opmap.get("BUILD_TUPLE_UNPACK_WITH_CALL", 256)
CALL = dis.opmap.get("CALL", 256)
CALL_KW = dis.opmap.get("CALL_KW", 256)  # New in 3.13


class RawNullPointer:
    """Marker type for a call target that the stack reports as missing.

    The module keeps one shared instance of this class and substitutes it
    when the tracer finds ``None`` where a call target was expected.
    """


# Shared sentinel standing in for a call target reported as None.
NULL_POINTER = RawNullPointer()

# Every call-like opcode number known to this module. Opcodes that do not
# exist on the running interpreter all share the same fallback number, so the
# resulting set may be smaller than this list.
_CALL_OPCODES = frozenset(
    (
        BUILD_TUPLE_UNPACK_WITH_CALL,
        CALL,
        CALL_KW,
        CALL_FUNCTION,
        CALL_FUNCTION_KW,
        CALL_FUNCTION_EX,
        CALL_METHOD,
    )
)


class Untracable:
    """Marker base class for call targets that tracing must leave alone.

    When the callable about to be invoked is an instance of this class, the
    call is not offered to ``trace_call`` for replacement.
    """


class TraceException(BaseException):
    """Signal a violated internal assertion in the tracing machinery.

    This derives from ``BaseException`` rather than ``Exception`` so that
    CrossHair does not treat it as a user-level exception.
    """


def check_opcode_support(opcodes: FrozenSet[int]):
    """Verify that the C-level tracer can handle every opcode in ``opcodes``.

    The check only runs on Python 3.12 and later; on earlier versions this
    function does nothing.

    Raises:
        TraceException: if any requested opcode is absent from
            ``supported_opcodes()``. The message lists the opcode names.
    """
    if sys.version_info >= (3, 12):
        unsupported = opcodes - set(supported_opcodes())
        if unsupported:
            names = ",".join(map(dis.opname.__getitem__, unsupported))
            raise TraceException(
                f"The C-level tracer does not support these opcodes: {names}"
            )


# Import-time sanity check: raise immediately if the C tracer cannot handle
# the call opcodes this module relies on.
check_opcode_support(_CALL_OPCODES)


# The type of slot wrappers such as ``int.__bool__``; the assert pins the
# expected type name.
wrapper_descriptor_type = type(int.__bool__)
assert str(wrapper_descriptor_type) == "<class 'wrapper_descriptor'>"

# Callable types handed to normalize_call_target() as its third argument.
# NOTE(review): presumably the types whose instances can be used as call
# targets without further unwrapping -- confirm against the C helper.
_NORMAL_CALLABLE_TYPES = (
    type,
    types.FunctionType,  # <class 'function'>
    types.MethodDescriptorType,  # <class 'method_descriptor'>
    types.MethodType,  # <class 'method'>
    types.MethodWrapperType,  # <class 'method-wrapper'>
    types.BuiltinFunctionType,  # <class 'builtin_function_or_method'>
    types.BuiltinMethodType,  # <class 'builtin_function_or_method'>
    types.ClassMethodDescriptorType,  # <class 'classmethod_descriptor'>
    wrapper_descriptor_type,
)
# Callable types handed to normalize_call_target() as its second argument.
# NOTE(review): the name suggests callables that carry no bound ``self``;
# confirm against the C helper.
_SELFLESS_CALLABLE_TYPES = (
    type,
    types.FunctionType,
    types.MethodDescriptorType,
    types.ClassMethodDescriptorType,
    wrapper_descriptor_type,
)


class TracingModule:
    """Base class for handlers that the tracer invokes on selected opcodes.

    The default ``trace_op`` handles call-like opcodes: it identifies the
    callable about to be invoked and gives ``trace_call`` the chance to put a
    different callable in its place on the frame's value stack. Subclasses
    normally override ``trace_call`` (and ``opcodes_wanted`` if they care
    about other opcodes).
    """

    # override these!:
    # NOTE(review): presumably read by the tracer to decide which opcodes
    # this module is invoked for -- the consumer is not visible here.
    opcodes_wanted = _CALL_OPCODES

    def __call__(self, frame, codeobj, opcodenum):
        """Forward to ``trace_op`` so that an instance is itself callable."""
        return self.trace_op(frame, codeobj, opcodenum)

    def trace_op(self, frame, codeobj, opcodenum):
        """Examine the pending call in ``frame`` and optionally replace its target.

        Args:
            frame: The frame that is about to execute the opcode.
            codeobj: Unused by this implementation.
            opcodenum: Numeric opcode about to execute.

        Returns:
            Always ``None``; the effect is made by writing to the frame's stack.

        Raises:
            TraceException: if tracing is enabled when this handler runs.
            TypeError: if a keyword-argument key is neither a ``str`` nor a
                symbolic string.
        """
        # Internal assertion: handlers are expected to run with tracing off.
        if is_tracing():
            raise TraceException
        info = call_stack_info(frame, opcodenum)
        if info is None:
            return None
        # Stack index of the callable, the callable itself, and (optionally)
        # the stack index of a keyword-arguments dict.
        (fn_idx, target, kwargs_idx) = info
        if target is None:
            target = NULL_POINTER
        # ``binding_target`` is the object the callable was bound to, or None.
        target, binding_target = normalize_call_target(
            target,
            _SELFLESS_CALLABLE_TYPES,
            _NORMAL_CALLABLE_TYPES,
        )

        if kwargs_idx is not None:
            try:
                kwargs_dict = frame_stack_read(frame, kwargs_idx)
            except ValueError:
                # NOTE(review): the keyword dict is left untouched when the
                # stack slot cannot be read -- confirm this is intentional.
                pass
            else:
                # Rebuild the keyword dict with plain ``str`` keys, realizing
                # any symbolic string keys.
                replacement_kwargs = {}
                for key, val in kwargs_dict.items():
                    if isinstance(key, str):
                        replacement_kwargs[key] = val
                        continue
                    # circular import:
                    from crosshair.libimpl.builtinslib import AnySymbolicStr

                    if isinstance(key, AnySymbolicStr):
                        # NOTE: We need to ensure symbolic strings don't need tracing for realization
                        replacement_kwargs[key.__ch_realize__()] = val
                    else:
                        raise TypeError("keywords must be strings")
                frame_stack_write(frame, kwargs_idx, replacement_kwargs)

        # Untracable targets are never offered to trace_call.
        if isinstance(target, Untracable):
            return None
        replacement = self.trace_call(frame, target, binding_target)
        if replacement is not None:
            target = replacement
            if binding_target is None:
                overwrite_target = target
            else:
                # re-bind a function object if it was originally a bound method
                # on the stack.
                overwrite_target = target.__get__(binding_target, binding_target.__class__)  # type: ignore
            frame_stack_write(frame, fn_idx, overwrite_target)
        return None

    def trace_call(
        self,
        frame: Any,
        fn: Callable,
        binding_target: object,
    ) -> Optional[Callable]:
        """Decide whether the call to ``fn`` should be redirected.

        Args:
            frame: The frame making the call.
            fn: The normalized callable about to be invoked.
            binding_target: The object ``fn`` was bound to, or ``None``.

        Returns:
            A replacement callable, or ``None`` to leave the call unchanged.
            This base implementation always returns ``None``.
        """
        return None


# Type alias: a tuple of tracing modules paired with a mapping from int keys
# to lists of modules.
# NOTE(review): the int keys are presumably opcode numbers; no use of this
# alias is visible in the code reviewed here.
TracerConfig = Tuple[Tuple[TracingModule, ...], DefaultDict[int, List[TracingModule]]]


class PatchingModule(TracingModule):
    """Hot-swap functions on the interpreter stack.

    Overrides are layered: registering a second override for the same
    original function stacks it on top of the first, and an override that
    calls the original is routed to the layer beneath it.
    """

    def __init__(
        self,
        overrides: Optional[Dict[Callable, Callable]] = None,
    ):
        # NOTE: an identity-keyed dict may look like the natural choice for
        # self.overrides, but some builtin bound methods offer no stable
        # identity to key on:
        #
        # >>> float.fromhex is float.fromhex
        # False
        #
        self.overrides: Dict[Callable, Callable] = {}
        # Maps (override's code object, original function) to the next,
        # lower layer for that original.
        self.nextfn: Dict[object, Callable] = {}
        if overrides:
            self.add(overrides)

    def add(self, new_overrides: Dict[Callable, Callable]):
        """Stack each replacement in ``new_overrides`` on top of its original."""
        for original, replacement in new_overrides.items():
            # With nothing registered yet, the layer below is the original.
            beneath = self.overrides.get(original, original)
            assert (
                beneath is not replacement
            ), f"Function patch {replacement} has already been applied"
            self.nextfn[(replacement.__code__, original)] = beneath
            self.overrides[original] = replacement

    def pop(self, overrides: Dict[Callable, Callable]):
        """Remove the top layer for each original; it must match ``overrides``."""
        for original, expected_top in overrides.items():
            assert self.overrides[original] is expected_top
            layer_key = (expected_top.__code__, original)
            self.overrides[original] = self.nextfn.pop(layer_key)

    def __repr__(self):
        return f"PatchingModule({list(self.overrides.keys())})"

    def trace_call(
        self,
        frame: Any,
        fn: Callable,
        binding_target: object,
    ) -> Optional[Callable]:
        """Return the callable that should run in place of ``fn``, if any.

        Returns ``None`` (meaning "leave the call alone") when ``fn`` has no
        override, when ``fn`` is unhashable, or when either the caller or
        ``fn`` is a ``_crosshair_wrapper``.
        """
        try:
            override = self.overrides.get(fn)
        except TypeError:
            # The function is not hashable.
            # Either a non-function is being invoked, or this is a method on
            # a non-hashable object that `TracingModule.trace_op` did not
            # manage to unbind.
            return None
        if override is None:
            return None
        calling_code = frame.f_code
        if calling_code.co_name == "_crosshair_wrapper":
            return None
        if getattr(fn, "__name__", "").endswith("_crosshair_wrapper"):
            return None
        # When the caller is itself one of fn's overrides, hand back the
        # layer beneath it rather than the topmost override.
        lower_layer = self.nextfn.get((calling_code, fn))
        if lower_layer is None:
            return override
        return None if lower_layer is fn else lower_layer


class CompositeTracer:
    """Own the C-level tracer and switch it on and off around a ``with`` block.

    The class body picks one of two implementations at import time: on
    Python 3.12+ it drives the tracer through ``sys.monitoring`` INSTRUCTION
    events; on older versions it relies on per-frame opcode tracing.
    """

    def __init__(self):
        self.ctracer = CTracer()
        # Always-present module, pushed on enter and popped on exit.
        self.patching_module = PatchingModule()

    def get_modules(self) -> List[TracingModule]:
        """Return the modules that the C tracer reports via ``get_modules()``."""
        return self.ctracer.get_modules()

    def set_postop_callback(self, callback, frame):
        """Push ``callback`` onto the C tracer as a post-op callback for ``frame``."""
        self.ctracer.push_postop_callback(frame, callback)

    if sys.version_info >= (3, 12):

        def push_module(self, module: TracingModule) -> None:
            """Register ``module`` with the C tracer, restarting monitoring events first."""
            sys.monitoring.restart_events()
            self.ctracer.push_module(module)

        def pop_config(self, module: TracingModule) -> None:
            """Unregister ``module`` from the C tracer."""
            self.ctracer.pop_module(module)

        def __enter__(self) -> object:
            """Claim the monitoring tool id, enable INSTRUCTION events, and start tracing."""
            self.ctracer.push_module(self.patching_module)
            tool_id = SYS_MONITORING_TOOL_ID
            sys.monitoring.use_tool_id(tool_id, "CrossHair")
            sys.monitoring.register_callback(
                tool_id,
                sys.monitoring.events.INSTRUCTION,
                self.ctracer.instruction_monitor,
            )
            sys.monitoring.set_events(tool_id, sys.monitoring.events.INSTRUCTION)
            sys.monitoring.restart_events()
            self.ctracer.start()
            assert not self.ctracer.is_handling()
            assert self.ctracer.enabled()
            return self

        def __exit__(self, _etype, exc, _etb):
            """Undo ``__enter__``: detach from sys.monitoring and stop tracing."""
            tool_id = SYS_MONITORING_TOOL_ID
            sys.monitoring.register_callback(
                tool_id, sys.monitoring.events.INSTRUCTION, None
            )
            # On Python 3.12/3.13, free_tool_id() does not disable the events
            # enabled for the tool, so INSTRUCTION instrumentation would stay
            # active after tracing ends. Turn the events off explicitly; this
            # must happen while the tool id is still in use.
            sys.monitoring.set_events(tool_id, sys.monitoring.events.NO_EVENTS)
            sys.monitoring.free_tool_id(tool_id)
            self.ctracer.stop()
            self.ctracer.pop_module(self.patching_module)

        def trace_caller(self):
            """Do nothing; this implementation needs no per-frame setup."""
            pass

    else:

        def push_module(self, module: TracingModule) -> None:
            """Register ``module`` with the C tracer."""
            self.ctracer.push_module(module)

        def pop_config(self, module: TracingModule) -> None:
            """Unregister ``module`` from the C tracer."""
            self.ctracer.pop_module(module)

        def __enter__(self) -> object:
            """Remember the current trace function and start the C tracer."""
            self.old_traceobj = sys.gettrace()
            # Enable opcode tracing before setting trace function, since Python 3.12; see https://github.com/python/cpython/issues/103615
            sys._getframe().f_trace_opcodes = True
            self.ctracer.push_module(self.patching_module)
            self.ctracer.start()
            assert not self.ctracer.is_handling()
            assert self.ctracer.enabled()
            return self

        def __exit__(self, _etype, exc, _etb):
            """Stop the C tracer and restore the trace function saved on enter."""
            self.ctracer.stop()
            self.ctracer.pop_module(self.patching_module)
            sys.settrace(self.old_traceobj)

        def trace_caller(self):
            """Enable opcode tracing on the frame two levels above this call."""
            # Frame 0 is the trace_caller method itself
            # Frame 1 is the frame requesting its caller be traced
            # Frame 2 is the caller that we're targeting
            frame = _getframe(2)
            frame.f_trace_opcodes = True
            frame.f_trace = self.ctracer


# The composite tracer is meant to be used as a singleton: only one tracer
# can be active at a time anyway. Module-level helpers such as is_tracing()
# read this shared instance.
# TODO: Thread-unsafe global. Make this a thread local?
COMPOSITE_TRACER = CompositeTracer()


@dataclasses.dataclass
class CoverageResult:
    """Opcode-level coverage figures for a single function."""

    # Instruction offsets that were observed executing.
    offsets_covered: Set[int]
    # Every instruction offset present in the function's code object.
    all_offsets: Set[int]
    # Covered fraction: len(offsets_covered) / len(all_offsets).
    opcode_coverage: float


class CoverageTracingModule(TracingModule):
    """Record which bytecode offsets execute inside a chosen set of functions."""

    # Subscribe to every opcode number from 0 through 255.
    opcodes_wanted = frozenset(range(256))

    # TODO: this needs to be moved into a separate kind of monitor to
    # support threading (sys.monitoring probes are global)

    def __init__(self, *fns: Callable):
        """Prepare coverage bookkeeping for ``fns``; must be built with tracing off."""
        assert not is_tracing()
        self.fns = fns
        self.codeobjects = {fn.__code__ for fn in fns}
        # For each watched code object, the full set of instruction offsets.
        self.opcode_offsets = {
            code: {instr.offset for instr in dis.get_instructions(code)}
            for code in self.codeobjects
        }
        self.offsets_seen: Dict[CodeType, Set[int]] = defaultdict(set)

    def trace_op(self, frame, codeobj, opcodenum):
        """Note the current instruction offset if ``frame`` runs a watched function."""
        code = frame.f_code
        if code in self.codeobjects:
            offset = frame.f_lasti
            assert offset in self.opcode_offsets[code]
            self.offsets_seen[code].add(offset)

    def get_results(self, fn: Optional[Callable] = None):
        """Return a ``CoverageResult`` for ``fn``.

        ``fn`` may be omitted only when exactly one function is being watched.
        """
        if fn is None:
            assert len(self.fns) == 1
            fn = self.fns[0]
        all_offsets = self.opcode_offsets[fn.__code__]
        covered = self.offsets_seen[fn.__code__]
        ratio = len(covered) / len(all_offsets)
        return CoverageResult(
            offsets_covered=covered,
            all_offsets=all_offsets,
            opcode_coverage=ratio,
        )


class PushedModule:
    """Context manager that keeps a tracing module pushed for a ``with`` block.

    The module is pushed onto the shared composite tracer on entry and popped
    on exit. Exceptions raised in the block are not suppressed.
    """

    def __init__(self, module: TracingModule):
        self.module = module

    def __enter__(self):
        COMPOSITE_TRACER.push_module(self.module)

    def __exit__(self, *exc_info):
        COMPOSITE_TRACER.pop_config(self.module)
        # False lets any in-flight exception propagate.
        return False


def is_tracing():
    """Return whether the shared C tracer currently reports itself enabled."""
    tracer = COMPOSITE_TRACER.ctracer
    return tracer.enabled()


def NoTracing():
    """Return a ``TraceSwap`` over the shared C tracer, built with flag ``True``.

    NOTE(review): presumably a context manager that suspends tracing inside
    its ``with`` block -- confirm against the ``TraceSwap`` implementation.
    """
    return TraceSwap(COMPOSITE_TRACER.ctracer, True)
if CROSSHAIR_EXTRA_ASSERTS:

    def ResumedTracing():
        """Return a ``TraceSwap`` over the shared C tracer, built with flag ``False``.

        This variant is used when extra assertions are enabled and first
        checks that the tracer is not in the middle of handling an opcode.

        Raises:
            TraceException: if the C tracer reports that it is handling.
        """
        if COMPOSITE_TRACER.ctracer.is_handling():
            raise TraceException("Cannot resume tracing while opcode handling")
        return TraceSwap(COMPOSITE_TRACER.ctracer, False)

else:

    def ResumedTracing():
        """Return a ``TraceSwap`` over the shared C tracer, built with flag ``False``.

        NOTE(review): presumably a context manager that re-enables tracing
        inside its ``with`` block -- confirm against ``TraceSwap``.
        """
        return TraceSwap(COMPOSITE_TRACER.ctracer, False)
_T = TypeVar("_T")


def tracing_iter(itr: Iterable[_T]) -> Iterable[_T]:
    """Selectively re-enable tracing only during iteration.

    Each item is fetched from ``itr`` inside a ``ResumedTracing()`` block and
    then yielded outside of it. Must be called while tracing is off.
    """
    assert not is_tracing()
    # TODO: should we protect this line with ResumedTracing() too?:
    itr = iter(itr)
    while True:
        try:
            with ResumedTracing():
                value = next(itr)
        except StopIteration:
            return
        yield value