Source code for crosshair.core

# TODO: drop to PDB option
# TODO: detect problems with backslashes in docstrings

# *** Not prioritized for v0 ***
# TODO: increase test coverage: TypeVar('T', int, str) vs bounded type vars
# TODO: consider raises conditions (guaranteed to raise, guaranteed to not raise?)
# TODO: precondition strengthening ban (Subclass constraint rule)
# TODO: mutating symbolic Callables?
# TODO: contracts on the contracts of function and object inputs/outputs?

import enum
import functools
import inspect
import itertools
import linecache
import os.path
import sys
import time
import traceback
import types
import typing
from collections import ChainMap, defaultdict, deque
from contextlib import ExitStack
from dataclasses import dataclass, replace
from inspect import BoundArguments, Signature, isabstract
from typing import (

import typing_inspect  # type: ignore
import z3  # type: ignore

from crosshair import dynamic_typing
from crosshair.codeconfig import collect_options
from crosshair.condition_parser import (
from crosshair.copyext import CopyMode, deepcopyext
from crosshair.enforce import (
from crosshair.fnutil import (
from crosshair.options import DEFAULT_OPTIONS, AnalysisOptions, AnalysisOptionSet
from crosshair.register_contract import get_contract
from crosshair.statespace import (
from crosshair.tracers import (
from crosshair.type_repo import get_subclass_map
from crosshair.util import (

# Module-private sentinel object.
# NOTE(review): no use of it is visible in this part of the file; presumably it
# marks "no value supplied" where None is a legitimate value - confirm.
_MISSING = object()

# Tracing modules; `Patched.__enter__` refuses to run while this list is empty.
_OPCODE_PATCHES: List[TracingModule] = []

# Maps a patched callable to its replacement; filled in by `register_patch`.
_PATCH_REGISTRATIONS: Dict[Callable, Callable] = {}
# Maps a function type to its patch; filled in by `register_fn_type_patch`.
_PATCH_FN_TYPE_REGISTRATIONS: Dict[type, Callable] = {}

class Patched(TracingModule):
    # Context manager built around the registered opcode patches.
    # NOTE(review): several lines of this class are missing from this copy (an
    # unterminated assignment and two `for` loops without bodies). The comments
    # below describe only what is still visible; restore the rest from upstream.
    def __enter__(self):
        # NOTE(review): the right-hand side of this assignment is missing.
        COMPOSITE_TRACER.patching_module.fn_type_overrides = (
        # Opcode patches must have been registered before entering.
        if len(_OPCODE_PATCHES) == 0:
            raise CrosshairInternal("Opcode patches haven't been loaded yet.")
        # NOTE(review): loop body missing; presumably each module gets activated here.
        for module in _OPCODE_PATCHES:
        # Keep a snapshot so __exit__ iterates exactly what was seen on entry.
        self.pushed = _OPCODE_PATCHES[:]
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE(review): loop body missing; presumably undoes __enter__ in reverse order.
        for module in reversed(self.pushed):
        # Reset the function-type overrides to an empty mapping.
        COMPOSITE_TRACER.patching_module.fn_type_overrides = {}
        # Returning False lets any in-flight exception propagate.
        return False

class _StandaloneStatespace(ExitStack):
    """Context manager whose ``__enter__`` hands back a fresh ``SimpleStateSpace``."""

    def __enter__(self) -> StateSpace:  # type: ignore
        # We explicitly don't set up contexts to enforce conditions - that's because
        # conditions involve a choice, and standalone_statespace is for testing that
        # does not require making any choices.
        return SimpleStateSpace()

# Shared module-level instance.
standalone_statespace = _StandaloneStatespace()

class ExceptionFilter:
    """
    Context manager that classifies an exception raised inside its ``with`` block.

    After the block exits, the outcome is available through these attributes:

    * ``ignore``: True when the attempt should be disregarded (an internal
      postcondition failure, an ``IgnoreAttempt``, or an expected exception).
    * ``user_exc``: the user-facing exception with its extracted traceback, or None.
    * ``analysis``: the ``CallAnalysis`` for the outcome; only assigned when an
      exception was handled here.
    """

    analysis: CallAnalysis
    ignore: bool = False
    ignore_with_confirmation: bool = False
    user_exc: Optional[Tuple[BaseException, traceback.StackSummary]] = None
    expected_exceptions: Tuple[Type[BaseException], ...]

    def __init__(
        self, expected_exceptions: FrozenSet[Type[BaseException]] = frozenset()
    ):
        """
        :param expected_exceptions: Exception types that count as acceptable
            outcomes. ``NotImplementedError`` is always included.
        """
        self.expected_exceptions = (NotImplementedError,) + tuple(expected_exceptions)

    def has_user_exception(self) -> bool:
        """Return True if a user-facing exception was captured."""
        return self.user_exc is not None

    def __enter__(self) -> "ExceptionFilter":
        return self

    def __exit__(self, exc_type, exc_value, tb) -> bool:
        """
        Classify ``exc_value`` and decide whether to suppress it.

        :returns: True to suppress the exception, False to let it propagate.
        :raises CrosshairUnsupported: when a ``TypeError`` message indicates that
            some code could not tolerate a symbolic proxy value.
        """
        with NoTracing():
            if isinstance(exc_value, (PostconditionFailed, IgnoreAttempt)):
                if isinstance(exc_value, PostconditionFailed):
                    # Postcondition : although this indicates a problem, it's with a
                    # subroutine; not this function.
                    # Usually we want to ignore this because it will be surfaced more
                    # locally in the subroutine.
                    debug(
                        f"Ignoring based on internal failed post condition: {exc_value}"
                    )
                self.ignore = True
                self.analysis = CallAnalysis()
                return True
            if isinstance(exc_value, self.expected_exceptions):
                exc_type_name = type(exc_value).__name__
                debug(f"Hit expected exception: {exc_type_name}: {exc_value}")
                self.ignore = True
                self.analysis = CallAnalysis(VerificationStatus.CONFIRMED)
                return True
            if isinstance(exc_value, TypeError):
                exc_str = str(exc_value)
                if (
                    "SymbolicStr" in exc_str
                    or "SymbolicInt" in exc_str
                    or "SymbolicFloat" in exc_str
                    or "__hash__ method should return an integer" in exc_str
                    or "expected string or bytes-like object" in exc_str
                    # Ideally we'd attempt literal strings after encountering this.
                ):
                    debug("Proxy intolerace at: ", traceback.format_exc())
                    raise CrosshairUnsupported("Detected proxy intolerance: " + exc_str)
            if isinstance(exc_value, (Exception, PreconditionFailed)):
                if isinstance(exc_value, z3.Z3Exception):
                    return False  # internal issue: re-raise
                # Most other issues are assumed to be user-facing exceptions:
                self.user_exc = (exc_value, traceback.extract_tb(sys.exc_info()[2]))
                self.analysis = CallAnalysis(VerificationStatus.REFUTED)
                return True  # suppress user-level exception
            return False  # re-raise resource and system issues

# Generic type variable used in this module's signatures to tie a return type to
# an argument type (e.g. `deep_realize(value: _T) -> _T`).
_T = TypeVar("_T")

def realize(value: Any) -> Any:
    """
    Return the concrete form of ``value``.

    If the type of ``value`` defines ``__ch_realize__``, that method is called
    (with tracing disabled) and its result is returned. Any other value is
    returned unchanged.
    """
    with NoTracing():
        if hasattr(type(value), "__ch_realize__"):
            return value.__ch_realize__()  # type: ignore
        # Not a realizable value: hand it back as-is (previously this return was
        # unreachable, so plain values came back as None).
        return value

def deep_realize(value: _T) -> _T:
    """Deep-copy ``value`` in ``CopyMode.REALIZE`` mode while tracing is disabled."""
    with NoTracing():
        realized_copy = deepcopyext(value, CopyMode.REALIZE, {})
    return realized_copy

class CrossHairValue:
    """
    Marker base class.

    ``with_symbolic_self`` checks call arguments against this class with
    ``isinstance`` to decide whether to dispatch to the symbolic implementation.
    """

def normalize_pytype(typ: Type) -> Type:
    """
    Map a typing annotation onto the type used for analysis.

    * A bound type variable is replaced by its (normalized) bound.
    * A constrained type variable raises ``CrosshairUnsupported``.
    * An unbound, unconstrained type variable and ``Any`` both become ``object``.
    * A bare ``Type`` becomes ``type``.
    * Everything else is returned unchanged.
    """
    if not typing_inspect.is_typevar(typ):
        if typ is Any:
            # The distinction between any and object is for type checking,
            # crosshair treats them the same
            return object
        return type if typ is Type else typ

    # we treat type vars in the most general way possible (the bound, or as 'object')
    upper_bound = typing_inspect.get_bound(typ)
    if upper_bound is not None:
        return normalize_pytype(upper_bound)
    if typing_inspect.get_constraints(typ):
        # TODO: not easy; interpreting as a Union allows the type to be
        # instantiated differently in different places. So, this doesn't work:
        # return Union.__getitem__(tuple(map(normalize_pytype, constraints)))
        raise CrosshairUnsupported
    return object

def python_type(o: object) -> Type:
    """
    Return the Python type that ``o`` stands for.

    Objects whose type defines ``__ch_pytype__`` report their own type through
    that method; a parameterized result is reduced to its ``__origin__``. Any
    other object simply yields ``type(o)``.

    :raises CrosshairInternal: if called while tracing is active.
    """
    if is_tracing():
        raise CrosshairInternal("should not be tracing while getting pytype")
    if hasattr(type(o), "__ch_pytype__"):
        obj_type = o.__ch_pytype__()  # type: ignore
        if hasattr(obj_type, "__origin__"):
            obj_type = obj_type.__origin__
        return obj_type
    # Ordinary object (previously this return was unreachable, so such objects
    # produced None).
    return type(o)

def with_realized_args(fn: Callable) -> Callable:
    """
    Wrap ``fn`` so that every positional and keyword argument is passed through
    ``realize`` before the call; the call itself runs with tracing disabled.
    """

    @functools.wraps(fn)
    def realizer(*a, **kw):
        with NoTracing():
            # Keyword values are realized eagerly; positional ones are realized
            # lazily as the map is unpacked into the call.
            realized_kw = {name: realize(val) for (name, val) in kw.items()}
            return fn(*map(realize, a), **realized_kw)

    return realizer

def with_symbolic_self(symbolic_cls: Type, fn: Callable):
    """
    Wrap the method ``fn`` so that it dispatches to ``symbolic_cls`` when needed.

    The wrapper picks its target as follows:

    * ``self`` is already a ``symbolic_cls``: use the same-named method on
      ``symbolic_cls``.
    * any positional or keyword argument is a ``CrossHairValue``: promote
      ``self`` with ``symbolic_cls._smt_promote_literal`` and use the same-named
      method on ``symbolic_cls``.
    * otherwise: realize all arguments and call the original ``fn``.

    :param symbolic_cls: Class providing the symbolic implementation.
    :param fn: The original method being wrapped.
    :returns: The wrapping function, carrying ``fn``'s metadata.
    """

    def call_with_symbolic_self(self, *args, **kwargs):
        with NoTracing():
            if isinstance(self, symbolic_cls):
                # Handles (unlikely!) cases like str.isspace(<symbolic string>)
                target_fn = getattr(symbolic_cls, fn.__name__)
            elif any(isinstance(a, CrossHairValue) for a in args) or (
                kwargs and any(isinstance(a, CrossHairValue) for a in kwargs.values())
            ):
                self = symbolic_cls._smt_promote_literal(self)
                target_fn = getattr(symbolic_cls, fn.__name__)
            else:
                # Nothing symbolic involved: fall back to the original function.
                args = map(realize, args)
                kwargs = {k: realize(v) for (k, v) in kwargs.items()}
                target_fn = fn
        return target_fn(self, *args, **kwargs)

    functools.update_wrapper(call_with_symbolic_self, fn)
    return call_with_symbolic_self

def with_uniform_probabilities(
    collection: Collection[_T],
) -> List[Tuple[_T, float]]:
    """
    Pair each item with ``1 / (number of items not yet passed)``.

    The first of ``n`` items gets ``1/n``, the next ``1/(n-1)``, and so on; the
    final item always gets ``1.0``. An empty collection yields an empty list.
    """
    remaining = len(collection)
    weighted: List[Tuple[_T, float]] = []
    for element in collection:
        weighted.append((element, 1.0 / remaining))
        remaining -= 1
    return weighted

def iter_types(from_type: Type, include_abstract: bool) -> List[Tuple[Type, float]]:
    # Builds a list of (type, probability) pairs starting from `from_type`; the
    # probabilities come from `with_uniform_probabilities`, with the first entry
    # boosted when it is `from_type` itself.
    # NOTE(review): the body of the `if` inside the loop is missing from this
    # copy. Presumably it records `cur` in `types` and enqueues its subclasses
    # from `subclassmap` - confirm against the upstream source.
    types = []
    queue = deque([from_type])
    subclassmap = get_subclass_map()
    while queue:
        cur = queue.popleft()
        # Abstract classes are only considered when `include_abstract` is set.
        if include_abstract or not isabstract(cur):
    ret = with_uniform_probabilities(types)
    if ret and ret[0][0] is from_type:
        # Bias a little extra for the base type;
        # e.g. pick `int` more readily than the subclasses of int:
        first_probability = ret[0][1]
        ret[0] = (from_type, (first_probability + 3.0) / 4.0)
    return ret

def choose_type(space: StateSpace, from_type: Type, varname: str) -> Optional[Type]:
    # Picks one type from the non-abstract candidates produced by `iter_types`.
    # Returns None when there are no candidates at all.
    # NOTE(review): the arguments of the `space.smt_fork(` call below are missing
    # from this copy, so what the fork is conditioned on cannot be documented.
    pairs = iter_types(from_type, include_abstract=False)
    if not pairs:
        return None
    for typ, probability_true in pairs:
        # true_probability=1.0 does not guarantee selection
        # (in particular, when the true path is exhausted)
        if probability_true == 1.0:
            return typ
        if space.smt_fork(
            return typ
    # Reached only if the loop finishes without returning a type.
    raise CrosshairInternal

def get_constructor_signature(cls: Type) -> Optional[inspect.Signature]:
    """
    Work out the signature to use when constructing an instance of ``cls``.

    Lookup order:

    1. ``cls.__signature__`` (via ``resolve_signature(cls)``), when present and
       resolvable; it is returned as-is.
    2. ``cls.__new__``.
    3. ``cls.__init__``, when ``__new__`` could not be resolved or carries no
       annotations. If ``__init__`` is just ``object.__init__``, an empty
       signature is returned.

    For cases 2 and 3 the first parameter is stripped from the result.

    :returns: The constructor signature, or None if none could be resolved.
    """
    # pydantic sets __signature__ on the class, so we look for that as well as on
    # __init__.
    if hasattr(cls, "__signature__"):
        sig = resolve_signature(cls)
        if isinstance(sig, inspect.Signature):
            return sig
    new_fn = cls.__new__
    sig = resolve_signature(new_fn)
    # TODO: merge the type signatures of __init__ and __new__, pulling the
    # most specific types from each.
    # Fall back to __init__ if we don't have types:
    if isinstance(sig, str) or all(
        p.annotation is Signature.empty for p in sig.parameters.values()
    ):
        init_fn = cls.__init__
        if init_fn is not object.__init__:
            sig = resolve_signature(init_fn)
        else:
            # Neither __new__ nor __init__ tells us anything: no arguments.
            return inspect.Signature([])
    if isinstance(sig, inspect.Signature):
        # strip first argument
        newparams = list(sig.parameters.values())[1:]
        return sig.replace(parameters=newparams)
    return None

# Cache of `get_type_hints` results, keyed by class.
_TYPE_HINTS = IdKeyedDict()

def proxy_for_class(typ: Type, varname: str) -> object:
    """
    Build an instance of ``typ`` from symbolic constructor arguments.

    ``TypedDict`` classes (Python 3.8+) are handled specially: a plain dict is
    returned whose values are proxies, and each optional key is included or
    omitted based on a fork in the current state space.

    For any other class, arguments are generated from the constructor signature
    and the class is instantiated with enforcement and tracing resumed. A lazy
    repr callback for the new object is registered with ``LazyCreationRepr``.

    :param typ: The class to instantiate.
    :param varname: Base name for the symbolic variables that get created.
    :raises CrosshairUnsupported: if no constructor signature can be found, or
        the constructor raises an unexpected exception.
    :raises IgnoreAttempt: if constructing the instance violates a precondition
        or postcondition.
    """
    data_members = _TYPE_HINTS.get(typ, None)
    if data_members is None:
        data_members = get_type_hints(typ)
        _TYPE_HINTS[typ] = data_members

    if sys.version_info >= (3, 8) and type(typ) is typing._TypedDictMeta:  # type: ignore
        # Handling for TypedDict
        optional_keys = getattr(typ, "__optional_keys__", ())
        keys = (
            k
            for k in data_members.keys()
            if k not in optional_keys or context_statespace().smt_fork()
        )
        return {k: proxy_for_type(data_members[k], varname + "." + k) for k in keys}

    constructor_sig = get_constructor_signature(typ)
    if constructor_sig is None:
        raise CrosshairUnsupported(
            f"unable to create concrete instance of {typ} due to bad constructor"
        )
    # TODO: use dynamic_typing.get_bindings_from_type_arguments(typ) to instantiate
    # type variables in `constructor_sig`
    args = gen_args(constructor_sig)
    typename = name_of_type(typ)
    try:
        with ResumedTracing():
            obj = WithEnforcement(typ)(*args.args, **args.kwargs)
    except (PreconditionFailed, PostconditionFailed):
        # preconditions can be invalidated when the __init__ method has preconditions.
        # postconditions can be invalidated when the class has invariants.
        raise IgnoreAttempt
    except Exception as e:
        debug("Root-cause type construction traceback:", test_stack(e.__traceback__))
        raise CrosshairUnsupported(
            f"error constructing {typename} instance: {name_of_type(type(e))}: {e}",
        ) from e

    debug("Proxy as a concrete instance of", typename)
    reprer = context_statespace().extra(LazyCreationRepr)

    def regenerate_construction_string(_):
        # Realize the (possibly symbolic) constructor arguments only when a repr
        # is actually requested.
        with NoTracing():
            realized_args = reprer.deep_realize(args)

        return f"{repr(typ)}({format_boundargs(realized_args)})"

    reprer.reprs[obj] = regenerate_construction_string

    debug("repr register lazy", hex(id(obj)), typename)
    return obj

def register_patch(entity: Callable, patch_value: Callable):
    """
    Record ``patch_value`` as the patch for ``entity``.

    :raises CrosshairInternal: if ``entity`` already has a registered patch.
    """
    if entity not in _PATCH_REGISTRATIONS:
        _PATCH_REGISTRATIONS[entity] = patch_value
        return
    raise CrosshairInternal(f"Doubly registered patch: {entity}")

def register_fn_type_patch(typ: type, patch_value: Callable[[Callable], Callable]):
    """
    Record ``patch_value`` as the patch for functions of type ``typ``.

    :param typ: The function type to patch.
    :param patch_value: Takes a callable and returns its replacement.
    :raises CrosshairInternal: if ``typ`` already has a registered patch.
    """
    # Same duplicate-registration guard as `register_patch`.
    if typ in _PATCH_FN_TYPE_REGISTRATIONS:
        raise CrosshairInternal(f"Doubly registered fn type patch: {typ}")
    _PATCH_FN_TYPE_REGISTRATIONS[typ] = patch_value

def register_opcode_patch(module: TracingModule) -> None:
    # NOTE(review): the body of this function is missing from this copy.
    # Presumably it adds `module` to `_OPCODE_PATCHES` (the list that
    # `Patched.__enter__` requires to be non-empty) - restore from upstream.

class SymbolicFactory:
    """
    A callable object that creates symbolic values.

    .. automethod:: __call__
    """

    def __init__(self, space: StateSpace, pytype: object, varname: str):
        self.space = space
        self.pytype: Any = pytype
        self.varname = varname

    @overload
    def __call__(
        self, typ: Callable[..., _T], suffix: str = "", allow_subtypes: bool = True
    ) -> _T:
        ...

    @overload
    def __call__(self, typ: Any, suffix: str = "", allow_subtypes: bool = True) -> Any:
        ...

    def __call__(self, typ, suffix: str = "", allow_subtypes: bool = True):
        """
        Create a new symbolic value.

        :param typ: The corresponding Python type for the returned symbolic.
        :type typ: type
        :param suffix: A descriptive suffix used to name variable(s) in the solver.
        :type suffix: str
        :param allow_subtypes: Whether it's ok to return a subtype of given type.
        :type allow_subtypes: bool
        :returns: A new symbolic value.
        """
        return proxy_for_type(
            typ,
            # NOTE(review): the last operand was lost in this copy; a unique
            # suffix from the state space is used, mirroring how `gen_args`
            # builds its solver names - confirm against upstream.
            self.varname + suffix + self.space.uniq(),
            allow_subtypes=allow_subtypes,
        )
# Registry of symbolic-value creators; `register_type` adds entries and
# `proxy_for_type` looks them up by origin type.
_SIMPLE_PROXIES: MutableMapping[object, Callable] = {}

# Signature accepted for a creator: the factory, then zero to four type parameters.
SymbolicCreationCallback = Union[
    # Sadly Callable[] doesn't support variable arguments. Just enumerate:
    Callable[[SymbolicFactory], object],
    Callable[[SymbolicFactory, Type], object],
    Callable[[SymbolicFactory, Type, Type], object],
    Callable[[SymbolicFactory, Type, Type, Type], object],
    Callable[[SymbolicFactory, Type, Type, Type, Type], object],
]
def register_type(typ: Type, creator: SymbolicCreationCallback) -> None:
    """
    Register a custom creation function to create symbolic values for a type.

    :param typ: The Python type (or typing annotation) to handle.
    :param creator: A function that takes a :class:`SymbolicFactory` instance and
      returns a symbolic value. When creating a parameterized type (e.g. List[int]),
      type parameters will be given to `creator` as additional arguments following
      the factory.
    """
    # Only unparameterized (origin) types may be used as registry keys.
    assert typ is origin_of(
        typ
    ), f'Only origin types may be registered, not "{typ}": try "{origin_of(typ)}" instead.'
    if typ in _SIMPLE_PROXIES:
        raise CrosshairInternal(f'Duplicate type "{typ}" registered')
    _SIMPLE_PROXIES[typ] = creator
# NOTE(review): this physical line holds the whole LazyCreationRepr class followed
# by the two @overload stubs of proxy_for_type. The original line breaks and
# indentation are lost, so the nesting of the `with` blocks in
# eval_friendly_format cannot be recovered from this copy; restore the layout
# from upstream before use.
# Visible behavior: deep_realize() asserts tracing is off, deep-copies a value in
# CopyMode.REALIZE, and re-registers any repr callback of an original object under
# the id of its copy. eval_friendly_format() asserts tracing is on, realizes `obj`,
# formats it with `result_formatter` inside an EvalFriendlyReprContext, merges
# ctx.repr_references into self.repr_references, and returns ctx.cleanup(...).
@dataclass class LazyCreationRepr: def __init__(self, *a) -> None: self.reprs = IdKeyedDict() self.repr_references: Set[ReferencedIdentifier] = set() def deep_realize(self, symbolic_val: object) -> Any: assert not is_tracing() reprs = self.reprs arg_memo: dict = {} realized_val = deepcopyext(symbolic_val, CopyMode.REALIZE, arg_memo) for orig_id, new_obj in arg_memo.items(): old_repr = reprs.inner.get(orig_id, None) if old_repr: reprs.inner[id(new_obj)] = old_repr return realized_val def eval_friendly_format( self, obj: _T, result_formatter: Callable[[_T], str] ) -> str: assert is_tracing() with NoTracing(): obj = self.deep_realize(obj) with EvalFriendlyReprContext(self.reprs) as ctx: args_string = result_formatter(obj) self.repr_references |= ctx.repr_references return ctx.cleanup(args_string) @overload def proxy_for_type( typ: Callable[..., _T], varname: str, allow_subtypes: bool = False, ) -> _T: ... @overload def proxy_for_type( typ: Any, varname: str, allow_subtypes: bool = False, ) -> Any: ... 
def proxy_for_type(
    typ: Any,
    varname: str,
    allow_subtypes: bool = False,
) -> Any:
    """
    Create a symbolic value of the given type in the current state space.

    :param typ: The type (or typing annotation) to create a value for.
    :param varname: Name given to the created symbolic value.
    :param allow_subtypes: When true (and ``typ`` is not ``object``), a type is
      first chosen via ``choose_type`` before the proxy is created.
    :raises IgnoreAttempt: If ``typ`` is an enum without members, or if
      ``choose_type`` yields no type.
    """
    space = context_statespace()
    with NoTracing():
        typ = normalize_pytype(typ)
        origin = origin_of(typ)
        type_args = type_args_of(typ)
        # special cases
        if isinstance(typ, type) and issubclass(typ, enum.Enum):
            enum_values = list(typ)  # type:ignore
            if not enum_values:
                raise IgnoreAttempt("No values for enum")
            # Fork once per member; the last member is the fall-through.
            for enum_value in enum_values[:-1]:
                if space.smt_fork(desc="choose_enum_" + str(enum_value)):
                    return enum_value
            return enum_values[-1]
        # It's easy to forget to import crosshair.core_and_libs; check:
        assert _SIMPLE_PROXIES, "No proxy type registrations exist"
        proxy_factory = _SIMPLE_PROXIES.get(origin)
        if proxy_factory:
            recursive_proxy_factory = SymbolicFactory(space, typ, varname)
            return proxy_factory(recursive_proxy_factory, *type_args)
        if hasattr(typ, "__supertype__") and typing_inspect.is_new_type(typ):
            return proxy_for_type(typ.__supertype__, varname, allow_subtypes)  # type: ignore
        if allow_subtypes and typ is not object:
            typ = choose_type(space, typ, varname)
            if typ is None:  # (happens if typ and all subtypes are abstract)
                raise IgnoreAttempt
        return proxy_for_class(typ, varname)


# Cache of per-parameter-name renamed variants of `proxy_for_type`.
_ARG_GENERATION_RENAMES: Dict[str, Callable] = {}


def gen_args(sig: inspect.Signature) -> inspect.BoundArguments:
    """
    Create symbolic values for every parameter of ``sig``.

    :raises CrosshairInternal: If called while tracing is enabled.
    """
    if is_tracing():
        raise CrosshairInternal
    args = sig.bind_partial()
    space = context_statespace()
    for param in sig.parameters.values():
        smt_name = param.name + space.uniq()
        allow_subtypes = True

        # For each argument, we call a special version of `proxy_for_type` that
        # includes the argument name in the function name.
        # This is nice while debugging stack traces, but also helps (e.g.)
        # `CoveragePathingOracle` distinguish the decisions for each argument.
        proxy_maker = _ARG_GENERATION_RENAMES.get(param.name)
        if not proxy_maker:
            if sys.version_info < (3, 8):
                proxy_maker = proxy_for_type
            else:
                proxy_maker = renamed_function(proxy_for_type, "proxy_arg_" + param.name)  # type: ignore
            _ARG_GENERATION_RENAMES[param.name] = proxy_maker

        has_annotation = param.annotation != inspect.Parameter.empty
        value: object
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            if has_annotation:
                varargs_type = List[param.annotation]  # type: ignore
                value = proxy_maker(varargs_type, smt_name, allow_subtypes)
            else:
                value = proxy_maker(List[Any], smt_name, allow_subtypes)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            if has_annotation:
                varargs_type = Dict[str, param.annotation]  # type: ignore
                value = cast(dict, proxy_maker(varargs_type, smt_name, allow_subtypes))
                # Using ** on a dict requires concrete string keys. Force
                # instantiation of keys here:
                value = {k.__str__(): v for (k, v) in value.items()}
            else:
                value = proxy_maker(Dict[str, Any], smt_name, allow_subtypes)
        else:
            is_self = param.name == "self"
            # Object parameters can be any valid subtype iff they are not the
            # class under test ("self").
            allow_subtypes = not is_self
            if has_annotation:
                value = proxy_maker(param.annotation, smt_name, allow_subtypes)
            else:
                value = proxy_maker(cast(type, Any), smt_name, allow_subtypes)
        if in_debug():
            debug(
                "created proxy for",
                param.name,
                "as type:",
                name_of_type(type(value)),
                hex(id(value)),
            )
        args.arguments[param.name] = value
    return args


def message_sort_key(m: AnalysisMessage) -> tuple:
    """Key used to pick the preferred message among those at one position."""
    return (m.state, UNABLE_TO_REPR_TEXT not in m.message, -len(m.message))


class MessageCollector:
    """Keeps one message per (filename, line, column) position."""

    def __init__(self):
        self.by_pos = {}

    def extend(self, messages: Iterable[AnalysisMessage]) -> None:
        for message in messages:
            self.append(message)

    def append(self, message: AnalysisMessage) -> None:
        key = (message.filename, message.line, message.column)
        if key in self.by_pos:
            # On a collision, keep whichever message sorts highest.
            self.by_pos[key] = max(self.by_pos[key], message, key=message_sort_key)
        else:
            self.by_pos[key] = message

    def get(self) -> List[AnalysisMessage]:
        """Return the retained messages ordered by position."""
        return [m for (k, m) in sorted(self.by_pos.items())]


class Checkable:
    """Interface for a unit of analysis that produces messages."""

    def analyze(self) -> Iterable[AnalysisMessage]:
        raise NotImplementedError


@dataclass
class ConditionCheckable(Checkable):
    """Checks a single postcondition of a function."""

    ctxfn: FunctionInfo
    options: AnalysisOptions
    conditions: Conditions

    def analyze(self) -> Iterable[AnalysisMessage]:
        options = self.options
        conditions = self.conditions
        debug('Analyzing postcondition: "', conditions.post[0].expr_source, '"')
        debug(
            "assuming preconditions: ",
            ",".join([p.expr_source for p in conditions.pre]),
        )
        options.deadline = time.monotonic() + options.per_condition_timeout

        with condition_parser(options.analysis_kind):
            analysis = analyze_calltree(options, conditions)

        # Exactly one postcondition is expected here.
        (condition,) = conditions.post
        if analysis.verification_status is VerificationStatus.UNKNOWN:
            message = "Not confirmed."
            analysis.messages = [
                AnalysisMessage(
                    MessageType.CANNOT_CONFIRM,
                    message,
                    condition.filename,
                    condition.line,
                    0,
                    "",
                )
            ]
        elif analysis.verification_status is VerificationStatus.CONFIRMED:
            message = "Confirmed over all paths."
            analysis.messages = [
                AnalysisMessage(
                    MessageType.CONFIRMED,
                    message,
                    condition.filename,
                    condition.line,
                    0,
                    "",
                )
            ]

        return analysis.messages


class ClampedCheckable(Checkable):
    """
    Clamp messages for a class method to appear on the class itself.

    So, even if the method is defined on a superclass, or defined dynamically (via
    decorator etc), we report it on the class definition instead.
    """

    def __init__(self, checkable: Checkable, cls: type):
        self.checkable = checkable
        filename, start_line, _ = sourcelines(cls)
        self.cls_file = filename
        self.cls_start_line = start_line

    def analyze(self) -> Iterable[AnalysisMessage]:
        cls_file = self.cls_file
        ret = []
        for message in self.checkable.analyze():
            if not samefile(message.filename, cls_file):
                ret.append(
                    replace(message, filename=cls_file, line=self.cls_start_line)
                )
            else:
                ret.append(message)
        return ret


@dataclass
class SyntaxErrorCheckable(Checkable):
    """Checkable that simply reports pre-computed syntax error messages."""

    messages: List[AnalysisMessage]

    def analyze(self) -> Iterable[AnalysisMessage]:
        return self.messages


def run_checkables(checkables: Iterable[Checkable]) -> List[AnalysisMessage]:
    """Analyze every checkable and return the collected messages."""
    collector = MessageCollector()
    for checkable in checkables:
        collector.extend(checkable.analyze())
    return collector.get()


def analyze_any(
    entity: Union[types.ModuleType, type, FunctionInfo], options: AnalysisOptionSet
) -> Iterable[Checkable]:
    """
    Yield checkables for a class, a `FunctionInfo`, or a module.

    :raises CrosshairInternal: If ``entity`` is none of the above.
    """
    if inspect.isclass(entity):
        yield from analyze_class(cast(Type, entity), options)
    elif isinstance(entity, FunctionInfo):
        yield from analyze_function(entity, options)
    elif inspect.ismodule(entity):
        yield from analyze_module(cast(types.ModuleType, entity), options)
    else:
        raise CrosshairInternal("Entity type not analyzable: " + str(type(entity)))


def analyze_module(
    module: types.ModuleType, options: AnalysisOptionSet
) -> Iterable[Checkable]:
    """Analyze the classes and functions defined in a module."""
    for name, member in get_top_level_classes_and_functions(module):
        if isinstance(member, type):
            yield from analyze_class(member, options)
        else:
            yield from analyze_function(member, options)


def analyze_class(
    cls: type, options: AnalysisOptionSet = AnalysisOptionSet()
) -> Iterable[Checkable]:
    """Yield (class-clamped) checkables for each method of ``cls`` with conditions."""
    debug("Analyzing class ", cls.__name__)
    analysis_kinds = DEFAULT_OPTIONS.overlay(options).analysis_kind
    with condition_parser(analysis_kinds) as parser:
        class_conditions = parser.get_class_conditions(cls)
        for method_name, conditions in class_conditions.methods.items():
            if method_name == "__init__":
                # Don't check invariants on __init__.
                # (too often this just requires turning the invariant into a very
                # similar precondition)
                filtered_post = [
                    c
                    for c in conditions.post
                    if c.condition_type != ConditionExprType.INVARIANT
                ]
                conditions = replace(conditions, post=filtered_post)
            if conditions.has_any():
                # Note the use of getattr_static to check superclass contracts on
                # functions that the subclass doesn't define.
                ctxfn = FunctionInfo(
                    cls, method_name, inspect.getattr_static(cls, method_name)
                )
                for checkable in analyze_function(ctxfn, options=options):
                    yield ClampedCheckable(checkable, cls)


def analyze_function(
    ctxfn: Union[FunctionInfo, types.FunctionType, Callable],
    options: AnalysisOptionSet = AnalysisOptionSet(),
) -> List[Checkable]:
    """
    Build the checkables for one function.

    Returns an empty list when analysis is disabled for the function or it has
    no conditions, a single `SyntaxErrorCheckable` when its conditions have
    syntax errors, and otherwise one `ConditionCheckable` per evaluable
    postcondition.
    """
    if not isinstance(ctxfn, FunctionInfo):
        ctxfn = FunctionInfo.from_fn(ctxfn)
    debug("Analyzing ", ctxfn.name)
    pair = ctxfn.get_callable()
    fn_options = collect_options(pair[0]) if pair else AnalysisOptionSet()
    full_options = DEFAULT_OPTIONS.overlay(fn_options).overlay(options)
    if not full_options.enabled:
        debug("Skipping", ctxfn.name, " because CrossHair is not enabled")
        return []

    with condition_parser(full_options.analysis_kind) as parser:
        if not isinstance(ctxfn.context, type):
            conditions = parser.get_fn_conditions(ctxfn)
        else:
            class_conditions = parser.get_class_conditions(ctxfn.context)
            conditions = class_conditions.methods.get(ctxfn.name)

    if conditions is None:
        debug("Skipping", ctxfn.name, " because it has no conditions")
        return []
    syntax_messages = list(conditions.syntax_messages())
    if syntax_messages:
        debug("Syntax error(s): ", *(m.message for m in syntax_messages))
        messages = [
            AnalysisMessage(
                MessageType.SYNTAX_ERR,
                syntax_message.message,
                syntax_message.filename,
                syntax_message.line_num,
                0,
                "",
            )
            for syntax_message in syntax_messages
        ]
        return [SyntaxErrorCheckable(messages)]
    return [
        ConditionCheckable(
            ctxfn, full_options, replace(conditions, post=[post_condition])
        )
        for post_condition in conditions.post
        if post_condition.evaluate is not None
    ]


def fn_returning(values: list) -> Callable:
    """
    Return a callable that ignores its arguments and returns ``values`` in order.

    Once the values are used up, further calls raise `NotDeterministic`.
    """
    itr = iter(values)

    def patched_call(*a, **kw):
        try:
            return next(itr)
        except StopIteration:
            raise NotDeterministic

    return patched_call


class patch_to_return:
    """Context manager that patches callables to return pre-set value sequences."""

    def __init__(self, return_values: Dict[Callable, list]):
        self.patches = PatchingModule(
            {fn: fn_returning(values) for (fn, values) in return_values.items()}
        )

    def __enter__(self):
        COMPOSITE_TRACER.push_module(self.patches)
        return COMPOSITE_TRACER.__enter__()

    def __exit__(self, *a):
        ret = COMPOSITE_TRACER.__exit__(*a)
        COMPOSITE_TRACER.pop_config(self.patches)
        return ret


class FunctionInterps:
    """Records the return values handed out for skipped (uninterpreted) calls."""

    _interpretations: Dict[Callable, List[object]]

    def __init__(self, *a):
        self._interpretations = defaultdict(list)

    def append_return(self, callable: Callable, retval: object) -> None:
        self._interpretations[callable].append(retval)

    def patch_string(self) -> Optional[str]:
        """Return a ``crosshair.patch_to_return(...)`` expression, or None if empty."""
        if self._interpretations:
            patches = ",".join(
                f"{method_identifier(fn)}: {eval_friendly_repr(deep_realize(vals))}"
                for fn, vals in self._interpretations.items()
            )
            return f"crosshair.patch_to_return({{{patches}}})"
        return None


class ShortCircuitingContext:
    """Context manager controlling whether intercepted calls may be short-circuited."""

    def __init__(self):
        self.engaged = False

        # Note: this cache is not really for performance; it preserves
        # function identity so that contract enforcement can correctly detect
        # re-entrant contracts.
        self.interceptor_cache = {}

    def __enter__(self):
        assert not self.engaged
        self.engaged = True

    def __exit__(self, exc_type, exc_value, tb):
        assert self.engaged
        self.engaged = False
        return False

    def make_interceptor(self, original: Callable) -> Callable:
        """
        Return a (cached) wrapper for ``original`` that may skip its body.

        If ``original`` has no conditions, ``original`` itself is returned.
        """
        interceptor = self.interceptor_cache.get(original)
        if interceptor:
            return interceptor

        # TODO: calling from_fn is wrong here
        subconditions = get_current_parser().get_fn_conditions(
            FunctionInfo.from_fn(original)
        )
        original_name = original.__name__
        if subconditions is None:
            self.interceptor_cache[original] = original
            return original
        sig = subconditions.sig

        def _crosshair_wrapper(*a: object, **kw: Dict[str, object]) -> object:
            space = optional_context_statespace()
            if (not self.engaged) or (not space):
                debug("Not short-circuiting", original_name, "(not engaged)")
                return original(*a, **kw)

            with NoTracing():
                assert subconditions is not None
                # Skip function body if it has the option `specs_complete`.
                short_circuit = collect_options(original).specs_complete
                # Also skip if the function was manually registered to be skipped.
                contract = get_contract(original)
                if contract and contract.skip_body:
                    short_circuit = True

                # TODO: In the future, sig should be a list of sigs and the parser
                # would directly return contract.sigs, so no need to fetch it here.
                sigs = [sig]
                if contract and contract.sigs:
                    sigs = contract.sigs
                best_sig = sigs[0]
                # The function is overloaded, find the best signature.
                if len(sigs) > 1:
                    # Keyword arguments must be forwarded as keywords (`**kw`);
                    # `*kw` would pass the keyword *names* positionally.
                    new_sig = find_best_sig(sigs, *a, **kw)
                    if new_sig:
                        best_sig = new_sig
                    else:
                        # If no signature is valid, we cannot shortcircuit.
                        short_circuit = False
                        warn(
                            "No signature match with the given parameters for function",
                            original_name,
                        )
                bound = best_sig.bind(*a, **kw)
                return_type = consider_shortcircuit(
                    original,
                    best_sig,
                    bound,
                    subconditions,
                    allow_interpretation=not short_circuit,
                )
                if short_circuit:
                    assert return_type is not None
                    retval = proxy_for_type(return_type, "proxyreturn" + space.uniq())
                    space.extra(FunctionInterps).append_return(original, retval)
                    debug("short circuit: specs complete; skipping (as uninterpreted)")
                    return retval

            if return_type is not None:
                try:
                    # Disengage while short-circuiting; re-engaged in `finally`.
                    self.engaged = False
                    debug(
                        "short circuit: Short circuiting over a call to ", original_name
                    )
                    return shortcircuit(original, best_sig, bound, return_type)
                finally:
                    self.engaged = True
            else:
                debug("short circuit: Not short circuiting", original_name)
                return original(*a, **kw)

        functools.update_wrapper(_crosshair_wrapper, original)
        self.interceptor_cache[original] = _crosshair_wrapper
        return _crosshair_wrapper


@dataclass
class CallTreeAnalysis:
    """Result of `analyze_calltree`."""

    messages: Sequence[AnalysisMessage]
    verification_status: VerificationStatus
    num_confirmed_paths: int = 0


def analyze_calltree(
    options: AnalysisOptions, conditions: Conditions
) -> CallTreeAnalysis:
    """
    Repeatedly attempt calls to ``conditions.fn`` along different paths.

    Iteration stops at ``options.max_iterations``, when ``options.deadline``
    passes, when too many iterations go by without discovery, when the search
    space is exhausted, or when the overall status becomes REFUTED.
    """
    fn = conditions.fn
    debug("Begin analyze calltree ", fn.__name__)

    all_messages = MessageCollector()
    search_root = RootNode()
    space_exhausted = False
    # Tracks the furthest precondition that no attempt has yet satisfied.
    failing_precondition: Optional[ConditionExpr] = (
        conditions.pre[0] if conditions.pre else None
    )
    failing_precondition_reason: str = ""
    num_confirmed_paths = 0

    short_circuit = ShortCircuitingContext()
    top_analysis: Optional[CallAnalysis] = None
    enforced_conditions = EnforcedConditions(
        interceptor=short_circuit.make_interceptor,
    )
    max_uninteresting_iterations = options.get_max_uninteresting_iterations()
    patched = Patched()
    # TODO clean up how enforced conditions works here?
    with enforced_conditions, patched:
        for i in range(1, options.max_iterations + 1):
            start = time.monotonic()
            if start > options.deadline:
                debug("Exceeded condition timeout, stopping")
                break
            options.incr("num_paths")
            debug("Iteration ", i)
            per_path_timeout = options.get_per_path_timeout()
            space = StateSpace(
                execution_deadline=start + per_path_timeout,
                model_check_timeout=per_path_timeout / 2,
                search_root=search_root,
            )
            try:
                with StateSpaceContext(space), COMPOSITE_TRACER, NoTracing():
                    # The real work happens here!:
                    call_analysis = attempt_call(
                        conditions, short_circuit, enforced_conditions
                    )
                if failing_precondition is not None:
                    cur_precondition = call_analysis.failing_precondition
                    if cur_precondition is None:
                        if call_analysis.verification_status is not None:
                            # We escaped all the preconditions on this try:
                            failing_precondition = None
                    elif (
                        cur_precondition.line == failing_precondition.line
                        and call_analysis.failing_precondition_reason
                    ):
                        failing_precondition_reason = (
                            call_analysis.failing_precondition_reason
                        )
                    elif cur_precondition.line > failing_precondition.line:
                        failing_precondition = cur_precondition
                        failing_precondition_reason = (
                            call_analysis.failing_precondition_reason
                        )

            except UnexploredPath:
                call_analysis = CallAnalysis(VerificationStatus.UNKNOWN)
            except IgnoreAttempt:
                call_analysis = CallAnalysis()
            status = call_analysis.verification_status
            if status == VerificationStatus.CONFIRMED:
                num_confirmed_paths += 1
            top_analysis, space_exhausted = space.bubble_status(call_analysis)
            debug("Path tree stats", search_root.stats())
            overall_status = top_analysis.verification_status if top_analysis else None
            debug(
                "Iter complete. Worst status found so far:",
                overall_status.name if overall_status else "None",
            )
            iters_since_discovery = getattr(
                search_root.pathing_oracle, "iters_since_discovery"
            )
            assert isinstance(iters_since_discovery, int)
            if iters_since_discovery > max_uninteresting_iterations:
                break
            if space_exhausted or overall_status == VerificationStatus.REFUTED:
                break
    top_analysis = search_root.child.get_result()
    if top_analysis.messages:
        all_messages.extend(
            replace(
                m,
                test_fn=fn.__qualname__,
                condition_src=conditions.post[0].expr_source,
            )
            for m in top_analysis.messages
        )
    if top_analysis.verification_status is None:
        top_analysis.verification_status = VerificationStatus.UNKNOWN
    if failing_precondition:
        assert num_confirmed_paths == 0
        message = "Unable to meet precondition"
        if failing_precondition_reason:
            message += f" (possibly because {failing_precondition_reason}?)"
        all_messages.extend(
            [
                AnalysisMessage(
                    MessageType.PRE_UNSAT,
                    message + ".",
                    failing_precondition.filename,
                    failing_precondition.line,
                    0,
                    "",
                )
            ]
        )
        top_analysis = CallAnalysis(VerificationStatus.REFUTED)

    assert top_analysis.verification_status is not None
    debug(
        ("Exhausted" if space_exhausted else "Aborted"),
        "calltree search with",
        top_analysis.verification_status.name,
        "and",
        len(all_messages.get()),
        "messages.",
        "Number of iterations: ",
        i - 1,
    )
    return CallTreeAnalysis(
        messages=all_messages.get(),
        verification_status=top_analysis.verification_status,
        num_confirmed_paths=num_confirmed_paths,
    )


# Callback invoked by `explore_paths` after each completed path; a true return
# value stops the exploration.
PathCompeltionCallback = Callable[
    [
        StateSpace,
        BoundArguments,
        BoundArguments,
        Any,
        Optional[BaseException],
        Optional[traceback.StackSummary],
    ],
    bool,
]


def explore_paths(
    fn: Callable[[BoundArguments], Any],
    sig: Signature,
    options: AnalysisOptions,
    search_root: RootNode,
    on_path_complete: PathCompeltionCallback = (lambda *a: False),
) -> None:
    """
    Runs a path exploration for use cases beyond invariant checking.

    :param fn: Called with the generated (copied) arguments on each path.
    :param sig: Signature used to generate symbolic arguments.
    :param options: Supplies iteration limits and timeouts.
    :param search_root: Root of the path search tree, shared across iterations.
    :param on_path_complete: Called after each path; return true to stop.
    """
    condition_start = time.monotonic()
    breakout = False
    max_uninteresting_iterations = options.get_max_uninteresting_iterations()
    for i in range(1, options.max_iterations + 1):
        debug("Iteration ", i)
        itr_start = time.monotonic()
        if itr_start > condition_start + options.per_condition_timeout:
            debug(
                "Stopping due to --per_condition_timeout=",
                options.per_condition_timeout,
            )
            break
        per_path_timeout = options.get_per_path_timeout()
        space = StateSpace(
            execution_deadline=itr_start + per_path_timeout,
            model_check_timeout=per_path_timeout / 2,
            search_root=search_root,
        )
        with condition_parser(
            options.analysis_kind
        ), Patched(), COMPOSITE_TRACER, NoTracing(), StateSpaceContext(space):
            try:
                pre_args = gen_args(sig)
                args = deepcopyext(pre_args, CopyMode.REGULAR, {})
                ret: object = None
                user_exc: Optional[BaseException] = None
                user_exc_stack: Optional[traceback.StackSummary] = None
                with ExceptionFilter() as efilter, ResumedTracing():
                    ret = fn(args)
                if efilter.user_exc:
                    if isinstance(efilter.user_exc[0], NotDeterministic):
                        raise NotDeterministic
                    else:
                        user_exc, user_exc_stack = efilter.user_exc
                with ResumedTracing():
                    breakout = on_path_complete(
                        space, pre_args, args, ret, user_exc, user_exc_stack
                    )
                verification_status = VerificationStatus.CONFIRMED
            except IgnoreAttempt:
                verification_status = None
            except UnexploredPath:
                verification_status = VerificationStatus.UNKNOWN
            debug("Verification status:", verification_status)
            _analysis, exhausted = space.bubble_status(
                CallAnalysis(verification_status)
            )
            debug("Path tree stats", search_root.stats())
            if breakout:
                break
            if exhausted:
                debug("Stopping due to path exhaustion")
                break
            if max_uninteresting_iterations != sys.maxsize:
                iters_since_discovery = getattr(
                    search_root.pathing_oracle, "iters_since_discovery"
                )
                assert isinstance(iters_since_discovery, int)
                debug("iters_since_discovery", iters_since_discovery)
                if iters_since_discovery > max_uninteresting_iterations:
                    debug(
                        "Stopping due to --max_uninteresting_iterations=",
                        max_uninteresting_iterations,
                    )
                    break


class UnEqual:
    """Type of the `_UNEQUAL` filler used by `deep_eq`."""

    pass


_UNEQUAL = UnEqual()


def deep_eq(old_val: object, new_val: object, visiting: Set[Tuple[int, int]]) -> bool:
    """
    Structurally compare two values.

    :param visiting: Identity pairs currently being compared; a pair that is
      already in progress is treated as equal, which breaks reference cycles.
    """
    # TODO: test just about all of this
    if old_val is new_val:
        return True
    if type(old_val) != type(new_val):
        return False
    visit_key = (id(old_val), id(new_val))
    if visit_key in visiting:
        return True
    visiting.add(visit_key)
    try:
        with NoTracing():
            is_ch_value = isinstance(old_val, CrossHairValue)
        if is_ch_value:
            return old_val == new_val
        elif hasattr(old_val, "__dict__") and hasattr(new_val, "__dict__"):
            return deep_eq(old_val.__dict__, new_val.__dict__, visiting)
        elif isinstance(old_val, dict):
            assert isinstance(new_val, dict)
            # Visit the union of both key sets. (The key views themselves are
            # chained; unpacking the keys would wrongly iterate *into* each key.)
            for key in set(itertools.chain(old_val.keys(), new_val.keys())):
                if (key in old_val) ^ (key in new_val):
                    return False
                if not deep_eq(
                    old_val.get(key, _UNEQUAL), new_val.get(key, _UNEQUAL), visiting
                ):
                    return False
            return True
        elif isinstance(old_val, Iterable):
            assert isinstance(new_val, Sized)
            if isinstance(old_val, Sized):
                if len(old_val) != len(new_val):
                    return False
            assert isinstance(new_val, Iterable)
            return all(
                deep_eq(o, n, visiting)
                for (o, n) in itertools.zip_longest(
                    old_val, new_val, fillvalue=_UNEQUAL
                )
            )
        elif type(old_val) is object:
            # Plain object instances are close enough to equal for our purposes
            return True
        else:
            # hopefully this is just ints, bools, etc
            return old_val == new_val
    finally:
        visiting.remove(visit_key)


class MessageGenerator:
    """Builds `AnalysisMessage`s located within (or clamped to) a function."""

    def __init__(self, fn: Callable):
        self.filename = ""
        # Defaults so that `make` also works for callables without `__code__`.
        self.start_lineno = 0
        self.end_lineno = 0
        if hasattr(fn, "__code__"):
            code_obj = fn.__code__
            self.filename = code_obj.co_filename
            self.start_lineno = code_obj.co_firstlineno
            _, _, lines = sourcelines(fn)
            self.end_lineno = self.start_lineno + len(lines)

    def make(
        self,
        message_type: MessageType,
        detail: str,
        suggested_filename: Optional[str],
        suggested_lineno: int,
        tb: str,
    ) -> AnalysisMessage:
        """
        Create a message at the suggested location if it lies inside the function.

        Otherwise, the message is placed on the function's first line and the
        offending source line (when available) is quoted in the detail text.
        """
        if (
            suggested_filename is not None
            and (os.path.abspath(suggested_filename) == os.path.abspath(self.filename))
            and (self.start_lineno <= suggested_lineno <= self.end_lineno)
        ):
            return AnalysisMessage(
                message_type, detail, suggested_filename, suggested_lineno, 0, tb
            )
        else:
            exprline = "<unknown>"
            if suggested_filename is not None:
                lines = linecache.getlines(suggested_filename)
                try:
                    exprline = lines[suggested_lineno - 1].strip()
                except IndexError:
                    pass
            detail = f'"{exprline}" yields {detail}'
            return AnalysisMessage(
                message_type, detail, self.filename, self.start_lineno, 0, tb
            )


def make_counterexample_message(
    conditions: Conditions, args: BoundArguments, return_val: object = None
) -> str:
    """Describe a counterexample call (and its return value, unless "None")."""
    reprer = context_statespace().extra(LazyCreationRepr)

    with NoTracing():
        args = reprer.deep_realize(args)
        return_val = deep_realize(return_val)

    with NoTracing():
        invocation, retstring = conditions.format_counterexample(
            args, return_val, reprer.reprs
        )

    patch_expr = context_statespace().extra(FunctionInterps).patch_string()
    if patch_expr:
        invocation += f" with {patch_expr}"
    if retstring == "None":
        return f"when calling {invocation}"
    else:
        return f"when calling {invocation} (which returns {retstring})"


def attempt_call(
    conditions: Conditions,
    short_circuit: ShortCircuitingContext,
    enforced_conditions: EnforcedConditions,
) -> CallAnalysis:
    """
    Execute one path: generate arguments, check preconditions, run the function,
    check argument mutation, and evaluate the (single) postcondition.

    Must be called with tracing off.
    """
    assert not is_tracing()
    fn = conditions.fn
    space = context_statespace()
    msg_gen = MessageGenerator(conditions.src_fn)
    with enforced_conditions.enabled_enforcement():
        original_args = gen_args(conditions.sig)
    space.checkpoint()
    bound_args = deepcopyext(original_args, CopyMode.BEST_EFFORT, {})

    lcls: Mapping[str, object] = bound_args.arguments
    # In preconditions, __old__ exists but is just bound to the same args.
    # This lets people write class invariants using `__old__` to, for example,
    # demonstrate immutability.
    lcls = {"__old__": AttributeHolder(lcls), **lcls}
    expected_exceptions = conditions.raises
    for precondition in conditions.pre:
        if not precondition.evaluate:
            continue
        with ExceptionFilter(expected_exceptions) as efilter:
            with enforced_conditions.enabled_enforcement(), short_circuit:
                with ResumedTracing():
                    precondition_ok = precondition.evaluate(lcls)
                precondition_ok = realize(prefer_true(precondition_ok))
            if not precondition_ok:
                debug("Failed to meet precondition", precondition.expr_source)
                return CallAnalysis(failing_precondition=precondition)
        if efilter.ignore:
            debug("Ignored exception in precondition.", efilter.analysis)
            return efilter.analysis
        elif efilter.user_exc is not None:
            (user_exc, tb) = efilter.user_exc
            formatted_tb = tb.format()
            debug(
                "Exception attempting to meet precondition",
                precondition.expr_source,
                ":",
                user_exc,
                formatted_tb,
            )
            return CallAnalysis(
                failing_precondition=precondition,
                failing_precondition_reason=f'it raised "{repr(user_exc)} at {formatted_tb[-1]}"',
            )

    with ExceptionFilter(expected_exceptions) as efilter:
        unenforced_fn = NoEnforce(fn)
        bargs, bkwargs = bound_args.args, bound_args.kwargs
        debug("Starting function body")
        with enforced_conditions.enabled_enforcement(), short_circuit, ResumedTracing():
            __return__ = unenforced_fn(*bargs, **bkwargs)
        lcls = {
            **bound_args.arguments,
            "__return__": __return__,
            "_": __return__,
            "__old__": AttributeHolder(original_args.arguments),
            fn.__name__: fn,
        }

    if efilter.ignore:
        debug("Ignored exception in function.", efilter.analysis)
        return efilter.analysis
    elif efilter.user_exc is not None:
        (e, tb) = efilter.user_exc
        detail = name_of_type(type(e)) + ": " + str(e)
        tb_desc = tb.format()
        frame_filename, frame_lineno = frame_summary_for_fn(conditions.src_fn, tb)
        if not isinstance(e, NotDeterministic):
            with ResumedTracing():
                space.detach_path()
                detail += " " + make_counterexample_message(conditions, original_args)
        debug("exception while evaluating function body:", detail, tb_desc)
        return CallAnalysis(
            VerificationStatus.REFUTED,
            [
                msg_gen.make(
                    MessageType.EXEC_ERR,
                    detail,
                    frame_filename,
                    frame_lineno,
                    "".join(tb_desc),
                )
            ],
        )

    for argname, argval in bound_args.arguments.items():
        if (
            conditions.mutable_args is not None
            and argname not in conditions.mutable_args
        ):
            old_val, new_val = original_args.arguments[argname], argval
            # TODO: Do we really need custom equality here? Would love to drop that
            # `deep_eq` function.
            with ResumedTracing():
                if not deep_eq(old_val, new_val, set()):
                    space.detach_path()
                    detail = 'Argument "{}" is not marked as mutable, but changed from {} to {}'.format(
                        argname, old_val, new_val
                    )
                    debug("Mutablity problem:", detail)
                    return CallAnalysis(
                        VerificationStatus.REFUTED,
                        [msg_gen.make(MessageType.POST_ERR, detail, None, 0, "")],
                    )

    # Exactly one postcondition is expected here.
    (post_condition,) = conditions.post
    assert post_condition.evaluate is not None
    with ExceptionFilter(expected_exceptions) as efilter:
        # TODO: re-enable post-condition short circuiting. This will require refactoring how
        # enforced conditions and short circuiting interact, so that post-conditions are
        # selectively run when, and only when, performing a short circuit.
        # with enforced_conditions.enabled_enforcement(), short_circuit:
        debug("Starting postcondition")
        with ResumedTracing():
            isok = bool(post_condition.evaluate(lcls))
    if efilter.ignore:
        debug("Ignored exception in postcondition.", efilter.analysis)
        return efilter.analysis
    elif efilter.user_exc is not None:
        (e, tb) = efilter.user_exc
        detail = name_of_type(type(e)) + ": " + str(e)
        if not isinstance(e, NotDeterministic):
            with ResumedTracing():
                space.detach_path()
                detail += " " + make_counterexample_message(
                    conditions, original_args, __return__
                )
        debug("exception while calling postcondition:", detail)
        debug("exception traceback:", test_stack(tb))
        failures = [
            msg_gen.make(
                MessageType.POST_ERR,
                detail,
                post_condition.filename,
                post_condition.line,
                "".join(tb.format()),
            )
        ]
        return CallAnalysis(VerificationStatus.REFUTED, failures)
    if isok:
        debug("Postcondition confirmed.")
        return CallAnalysis(VerificationStatus.CONFIRMED)
    else:
        with ResumedTracing():
            space.detach_path()
            detail = "false " + make_counterexample_message(
                conditions, original_args, __return__
            )
        debug(detail)
        failures = [
            msg_gen.make(
                MessageType.POST_FAIL,
                detail,
                post_condition.filename,
                post_condition.line,
                "",
            )
        ]
        return CallAnalysis(VerificationStatus.REFUTED, failures)


def _mutability_testing_hash(o: object) -> int:
    """
    Stand-in for ``hash`` used by `is_deeply_immutable`.

    :raises TypeError: If ``o`` cannot be shown to be deeply immutable.
    """
    if isinstance(o, ATOMIC_IMMUTABLE_TYPES):
        return 0
    if hasattr(o, "__ch_is_deeply_immutable__"):
        if o.__ch_is_deeply_immutable__():  # type: ignore
            return 0
        else:
            raise TypeError
    typ = type(o)
    if not hasattr(typ, "__hash__"):
        # TODO: test for __hash__ = None (list has this)
        raise TypeError
    # We err on the side of mutability if this object is using the default hash:
    if typ.__hash__ is object.__hash__:
        raise TypeError
    return typ.__hash__(o)


def is_deeply_immutable(o: object) -> bool:
    """
    Report whether ``o`` hashes successfully under `_mutability_testing_hash`.

    The tracer's modules are temporarily swapped for one that intercepts
    ``hash``; they are restored afterwards.

    :raises CrosshairInternal: If called while tracing is disabled.
    """
    if not is_tracing():
        raise CrosshairInternal("is_deeply_immutable must be run with tracing enabled")
    orig_modules = COMPOSITE_TRACER.get_modules()
    hash_intercept_module = PatchingModule({hash: _mutability_testing_hash})
    for module in reversed(orig_modules):
        COMPOSITE_TRACER.pop_config(module)
    COMPOSITE_TRACER.push_module(hash_intercept_module)
    try:
        try:
            hash(o)
            return True
        except TypeError:
            return False
    finally:
        COMPOSITE_TRACER.pop_config(hash_intercept_module)
        for module in orig_modules:
            COMPOSITE_TRACER.push_module(module)


def find_best_sig(
    sigs: List[Signature],
    *args: object,
    **kwargs: Dict[str, object],
) -> Optional[Signature]:
    """
    Return the first signature which complies with the args.

    A signature complies when the arguments bind to it and every argument's
    type unifies with the corresponding parameter annotation. Returns None
    if no signature complies.
    """
    for sig in sigs:
        try:
            bound = sig.bind(*args, **kwargs)
        except TypeError:
            # The arguments do not even fit this overload; try the next one.
            continue
        bound.apply_defaults()
        bindings: typing.ChainMap[object, type] = ChainMap()
        is_valid = True
        for param in sig.parameters.values():
            argval = bound.arguments[param.name]
            value_type = python_type(argval)
            if not dynamic_typing.unify(value_type, param.annotation, bindings):
                is_valid = False
                break
        if is_valid:
            return sig
    return None


def consider_shortcircuit(
    fn: Callable,
    sig: Signature,
    bound: BoundArguments,
    subconditions: Conditions,
    allow_interpretation: bool,
) -> Optional[type]:
    """
    Consider the feasibility of short-circuiting (skipping) a function with the given arguments.

    :param allow_interpretation: When false, the return type is always produced
      (or `CrosshairUnsupported` is raised if type variables cannot be unified).
    :return: The type of a symbolic value that could be returned by ``fn``.
    :return: None if a short-circuiting should not be attempted.
    """
    return_type = sig.return_annotation
    if return_type == Signature.empty:
        return_type = object
    elif return_type is None:
        return_type = type(None)

    mutable_args = subconditions.mutable_args
    if allow_interpretation:
        if mutable_args is None or len(mutable_args) > 0:
            # we don't deal with mutation inside the skipped function yet.
            debug("aborting shortcircuit: function has matuable args")
            return None

    # Deduce type vars if necessary
    if len(typing_inspect.get_parameters(return_type)) > 0 or typing_inspect.is_typevar(
        return_type
    ):
        typevar_bindings: typing.ChainMap[object, type] = ChainMap()
        bound.apply_defaults()
        for param in sig.parameters.values():
            argval = bound.arguments[param.name]
            # We don't need all args to be symbolic, but we don't currently
            # short circuit in that case as a heuristic.
            if allow_interpretation and not isinstance(argval, CrossHairValue):
                debug("aborting shortcircuit:", param.name, "is not symbolic")
                return None
            value_type = python_type(argval)
            if not dynamic_typing.unify(value_type, param.annotation, typevar_bindings):
                if allow_interpretation:
                    debug("aborting shortcircuit", param.name, "fails unification")
                    return None
                else:
                    raise CrosshairUnsupported
        return_type = dynamic_typing.realize(sig.return_annotation, typevar_bindings)

    if not allow_interpretation:
        return return_type

    space = context_statespace()
    short_stats, callinto_stats = space.stats_lookahead()
    if callinto_stats.unknown_pct < short_stats.unknown_pct:
        callinto_probability = 1.0
    else:
        callinto_probability = 0.7

    debug("short circuit: call-into probability", callinto_probability)
    do_short_circuit = space.fork_parallel(
        callinto_probability, desc=f"shortcircuit {fn.__name__}"
    )
    return return_type if do_short_circuit else None


def shortcircuit(
    fn: Callable, sig: Signature, bound: BoundArguments, return_type: Type
) -> object:
    """
    Skip ``fn``: return a fresh symbolic value of ``return_type`` and defer an
    assumption that it equals what ``fn`` returns for (a copy of) the arguments.

    Returns None, without deferring anything, when ``return_type`` is NoneType.
    """
    space = context_statespace()
    debug("short circuit: Deduced return type was ", return_type)

    # Deep copy the arguments for reconciliation later.
    # (we know that this function won't mutate them, but not that others won't)
    argscopy = {}
    for name, val in bound.arguments.items():
        if is_deeply_immutable(val):
            argscopy[name] = val
        else:
            with NoTracing():
                argscopy[name] = deepcopyext(val, CopyMode.BEST_EFFORT, {})
    bound_copy = BoundArguments(sig, argscopy)  # type: ignore

    retval = None
    if return_type is not type(None):
        # note that the enforcement wrapper ensures postconditions for us, so
        # we can just return a free variable here.
        retval = proxy_for_type(return_type, "proxyreturn" + space.uniq())

        def reconciled() -> bool:
            return retval == fn(*bound_copy.args, **bound_copy.kwargs)

        space.defer_assumption("Reconcile short circuit", reconciled)

    return retval