Source code for crosshair.tracers

"""Provide access to and overrides for functions as they are called."""

import ctypes
import dataclasses
import dis
import sys
from collections import defaultdict
from sys import _getframe
from types import CodeType
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

from _crosshair_tracers import CTracer, TraceSwap

USE_C_TRACER = True

PyObjPtr = ctypes.POINTER(ctypes.py_object)
Py_IncRef = ctypes.pythonapi.Py_IncRef
Py_DecRef = ctypes.pythonapi.Py_DecRef


_debug_header: Tuple[Tuple[str, type], ...] = (
    (
        ("_ob_next", PyObjPtr),
        ("_ob_prev", PyObjPtr),
    )
    if sys.flags.debug
    else ()
)


from _crosshair_tracers import frame_stack_read, frame_stack_write

CALL_FUNCTION = dis.opmap.get("CALL_FUNCTION", 256)
CALL_FUNCTION_KW = dis.opmap.get("CALL_FUNCTION_KW", 256)
CALL_FUNCTION_EX = dis.opmap.get("CALL_FUNCTION_EX", 256)
CALL_METHOD = dis.opmap.get("CALL_METHOD", 256)
BUILD_TUPLE_UNPACK_WITH_CALL = dis.opmap.get("BUILD_TUPLE_UNPACK_WITH_CALL", 256)
CALL = dis.opmap.get("CALL", 256)
CALL_KW = dis.opmap.get("CALL_KW", 256)


class RawNullPointer:
    pass


NULL_POINTER = RawNullPointer()


def handle_build_tuple_unpack_with_call(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(
        frame.f_code.co_code[frame.f_lasti + 1] + 1
    )  # TODO: account for EXTENDED_ARG, here and elsewhere
    try:
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        return (idx, NULL_POINTER)  # type: ignore


def handle_call_3_11(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 1)
    try:
        ret = (idx - 1, frame_stack_read(frame, idx - 1))
    except ValueError:
        ret = (idx, frame_stack_read(frame, idx))
    return ret


def handle_call_3_13(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 2)
    return (idx, frame_stack_read(frame, idx))


def handle_call_function(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 1)
    try:
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        return (idx, NULL_POINTER)  # type: ignore


def handle_call_function_kw(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 2)
    try:
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        return (idx, NULL_POINTER)  # type: ignore


def handle_call_kw(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 3)
    return (idx, frame_stack_read(frame, idx))


def handle_call_function_ex_3_6(frame) -> Optional[Tuple[int, Callable]]:
    idx = -((frame.f_code.co_code[frame.f_lasti + 1] & 1) + 2)
    try:
        # print("fetch @ idx", idx)
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        return (idx, NULL_POINTER)  # type: ignore


def handle_call_function_ex_3_13(frame) -> Optional[Tuple[int, Callable]]:
    idx = -((frame.f_code.co_code[frame.f_lasti + 1] & 1) + 3)
    try:
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        return (idx, NULL_POINTER)  # type: ignore


def handle_call_method(frame) -> Optional[Tuple[int, Callable]]:
    idx = -(frame.f_code.co_code[frame.f_lasti + 1] + 2)
    try:
        return (idx, frame_stack_read(frame, idx))
    except ValueError:
        # not a sucessful method lookup; no call happens here
        idx += 1
        return (idx, frame_stack_read(frame, idx))


_CALL_HANDLERS: Dict[int, Callable[[object], Optional[Tuple[int, Callable]]]] = {
    BUILD_TUPLE_UNPACK_WITH_CALL: handle_build_tuple_unpack_with_call,
    CALL: handle_call_3_13 if sys.version_info >= (3, 13) else handle_call_3_11,
    CALL_KW: handle_call_kw,
    CALL_FUNCTION: handle_call_function,
    CALL_FUNCTION_KW: handle_call_function_kw,
    CALL_FUNCTION_EX: handle_call_function_ex_3_13
    if sys.version_info >= (3, 13)
    else handle_call_function_ex_3_6,
    CALL_METHOD: handle_call_method,
}


class Untracable:
    pass


class TraceException(BaseException):
    # We extend BaseException instead of Exception, because it won't be considered a
    # user-level exception by CrossHair. (this is for internal assertions)
    pass


class TracingModule:
    # override these!:
    opcodes_wanted = frozenset(_CALL_HANDLERS.keys())

    def __call__(self, frame, codeobj, opcodenum):
        return self.trace_op(frame, codeobj, opcodenum)

    def trace_op(self, frame, codeobj, opcodenum):
        if is_tracing():
            raise TraceException
        call_handler = _CALL_HANDLERS.get(opcodenum)
        if not call_handler:
            return None
        maybe_call_info = call_handler(frame)
        if maybe_call_info is None:
            # TODO: this cannot happen?
            return None
        (fn_idx, target) = maybe_call_info
        binding_target = None

        try:
            __self = object.__getattribute__(target, "__self__")
        except AttributeError:
            pass
        else:
            try:
                __func = object.__getattribute__(target, "__func__")
            except AttributeError:
                # The implementation is likely in C.
                # Attempt to get a function via the type:
                typelevel_target = getattr(type(__self), target.__name__, None)
                if typelevel_target is not None:
                    binding_target = __self
                    target = typelevel_target
            else:
                binding_target = __self
                target = __func

        if isinstance(target, Untracable):
            return None
        replacement = self.trace_call(frame, target, binding_target)
        if replacement is not None:
            target = replacement
            if binding_target is None:
                overwrite_target = target
            else:
                # re-bind a function object if it was originally a bound method
                # on the stack.
                overwrite_target = target.__get__(binding_target, binding_target.__class__)  # type: ignore
            frame_stack_write(frame, fn_idx, overwrite_target)
        return None

    def trace_call(
        self,
        frame: Any,
        fn: Callable,
        binding_target: object,
    ) -> Optional[Callable]:
        return None


TracerConfig = Tuple[Tuple[TracingModule, ...], DefaultDict[int, List[TracingModule]]]


class PatchingModule(TracingModule):
    """Hot-swap functions on the interpreter stack."""

    def __init__(
        self,
        overrides: Optional[Dict[Callable, Callable]] = None,
        fn_type_overrides: Optional[Dict[type, Callable]] = None,
    ):
        # NOTE: you might imagine that we should use an IdKeyedDict for self.overrides
        # However, some builtin bound methods have no way to get identity for their code:
        #
        # >>> float.fromhex is float.fromhex
        # False
        #
        self.overrides: Dict[Callable, Callable] = {}
        self.nextfn: Dict[object, Callable] = {}  # code object to next, lower layer
        if overrides:
            self.add(overrides)
        self.fn_type_overrides = fn_type_overrides or {}

    def add(self, new_overrides: Dict[Callable, Callable]):
        for orig, new_override in new_overrides.items():
            prev_override = self.overrides.get(orig, orig)
            self.nextfn[(new_override.__code__, orig)] = prev_override
            self.overrides[orig] = new_override

    def pop(self, overrides: Dict[Callable, Callable]):
        for orig, the_override in overrides.items():
            assert self.overrides[orig] is the_override
            self.overrides[orig] = self.nextfn.pop((the_override.__code__, orig))

    def __repr__(self):
        return f"PatchingModule({list(self.overrides.keys())})"

    def trace_call(
        self,
        frame: Any,
        fn: Callable,
        binding_target: object,
    ) -> Optional[Callable]:
        try:
            target = self.overrides.get(fn)
        except TypeError as exc:
            # The function is not hashable.
            # This can happen when attempting to invoke a non-function,
            # or possibly it is a method on a non-hashable object that was
            # not properly unbound by `TracingModule.trace_op`.
            return None
        if target is None:
            fn_type_override = self.fn_type_overrides.get(type(fn))
            if fn_type_override is None:
                return None
            else:
                target = fn_type_override(fn)
        caller_code = frame.f_code
        if caller_code.co_name == "_crosshair_wrapper":
            return None
        target_name = getattr(fn, "__name__", "")
        if target_name.endswith("_crosshair_wrapper"):
            return None
        nextfn = self.nextfn.get((caller_code, fn))
        if nextfn is not None:
            if nextfn is fn:
                return None
            return nextfn
        return target


class CompositeTracer:
    def __init__(self):
        self.ctracer = CTracer()
        self.patching_module = PatchingModule()

    def get_modules(self) -> List[TracingModule]:
        return self.ctracer.get_modules()

    def set_postop_callback(self, callback, frame):
        self.ctracer.push_postop_callback(frame, callback)

    if sys.version_info >= (3, 12):

        def push_module(self, module: TracingModule) -> None:
            sys.monitoring.restart_events()
            self.ctracer.push_module(module)

        def pop_config(self, module: TracingModule) -> None:
            self.ctracer.pop_module(module)

        def __enter__(self) -> object:
            self.ctracer.push_module(self.patching_module)
            tool_id = 4
            sys.monitoring.use_tool_id(tool_id, "CrossHair")
            sys.monitoring.register_callback(
                tool_id,
                sys.monitoring.events.INSTRUCTION,
                self.ctracer.instruction_monitor,
            )
            sys.monitoring.set_events(tool_id, sys.monitoring.events.INSTRUCTION)
            sys.monitoring.restart_events()
            self.ctracer.start()
            assert not self.ctracer.is_handling()
            assert self.ctracer.enabled()
            return self

        def __exit__(self, _etype, exc, _etb):
            tool_id = 4
            sys.monitoring.register_callback(
                tool_id, sys.monitoring.events.INSTRUCTION, None
            )
            sys.monitoring.free_tool_id(tool_id)
            self.ctracer.stop()
            self.ctracer.pop_module(self.patching_module)

        def trace_caller(self):
            pass

    else:

        def push_module(self, module: TracingModule) -> None:
            self.ctracer.push_module(module)

        def pop_config(self, module: TracingModule) -> None:
            self.ctracer.pop_module(module)

        def __enter__(self) -> object:
            self.old_traceobj = sys.gettrace()
            # Enable opcode tracing before setting trace function, since Python 3.12; see https://github.com/python/cpython/issues/103615
            sys._getframe().f_trace_opcodes = True
            self.ctracer.push_module(self.patching_module)
            self.ctracer.start()
            assert not self.ctracer.is_handling()
            assert self.ctracer.enabled()
            return self

        def __exit__(self, _etype, exc, _etb):
            self.ctracer.stop()
            self.ctracer.pop_module(self.patching_module)
            sys.settrace(self.old_traceobj)

        def trace_caller(self):
            # Frame 0 is the trace_caller method itself
            # Frame 1 is the frame requesting its caller be traced
            # Frame 2 is the caller that we're targeting
            frame = _getframe(2)
            frame.f_trace_opcodes = True
            frame.f_trace = self.ctracer


# We expect the composite tracer to be used like a singleton.
# (you can only have one tracer active at a time anyway)
# TODO: Thread-unsafe global. Make this a thread local?
COMPOSITE_TRACER = CompositeTracer()


@dataclasses.dataclass
class CoverageResult:
    offsets_covered: Set[int]
    all_offsets: Set[int]
    opcode_coverage: float


class CoverageTracingModule(TracingModule):
    opcodes_wanted = frozenset(i for i in range(256))

    def __init__(self, *fns: Callable):
        assert not is_tracing()
        self.fns = fns
        self.codeobjects = set(fn.__code__ for fn in fns)
        self.opcode_offsets = {
            code: set(i.offset for i in dis.get_instructions(code))
            for code in self.codeobjects
        }
        self.offsets_seen: Dict[CodeType, Set[int]] = defaultdict(set)

    def trace_op(self, frame, codeobj, opcodenum):
        code = frame.f_code
        if code not in self.codeobjects:
            return
        lasti = frame.f_lasti
        assert lasti in self.opcode_offsets[code]
        self.offsets_seen[code].add(lasti)

    def get_results(self, fn: Optional[Callable] = None):
        if fn is None:
            assert len(self.fns) == 1
            fn = self.fns[0]
        possible = self.opcode_offsets[fn.__code__]
        seen = self.offsets_seen[fn.__code__]
        return CoverageResult(
            offsets_covered=seen,
            all_offsets=possible,
            opcode_coverage=len(seen) / len(possible),
        )


class PushedModule:
    def __init__(self, module: TracingModule):
        self.module = module

    def __enter__(self):
        COMPOSITE_TRACER.push_module(self.module)

    def __exit__(self, *a):
        COMPOSITE_TRACER.pop_config(self.module)
        return False


def is_tracing():
    return COMPOSITE_TRACER.ctracer.enabled()


[docs] def NoTracing(): return TraceSwap(COMPOSITE_TRACER.ctracer, True)
[docs] def ResumedTracing(): return TraceSwap(COMPOSITE_TRACER.ctracer, False)
_T = TypeVar("_T") def tracing_iter(itr: Iterable[_T]) -> Iterable[_T]: """Selectively re-enable tracing only during iteration.""" assert not is_tracing() # TODO: should we protect his line with ResumedTracing() too?: itr = iter(itr) while True: try: with ResumedTracing(): value = next(itr) except StopIteration: return yield value