Shortcuts

Source code for torch.utils.benchmark.utils.valgrind_wrapper.timer_interface

"""Intermediate layer between `Timer` and `valgrind`."""
import collections
import enum
import dataclasses
import itertools as it
import os
import pickle
import re
import shutil
import subprocess
import sys
import textwrap
from typing import (
    cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple,
    Optional, Tuple, Union, TYPE_CHECKING)

import torch
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import CallgrindModuleType


# Public API of this module; everything else is an implementation detail of
# `Timer.collect_callgrind`.
__all__ = [
    "FunctionCount",
    "FunctionCounts",
    "CallgrindStats",
    "CopyIfCallgrind",
]


# Type checkers see the parameterized alias; at runtime the bare class is
# used, since subscripting `CompletedProcess` at runtime requires Python 3.9+.
if not TYPE_CHECKING:
    CompletedProcessType = subprocess.CompletedProcess
else:
    CompletedProcessType = subprocess.CompletedProcess[str]


# One (instruction count, "file:function") row of Callgrind output.
class FunctionCount(NamedTuple):
    # `count` shadows `tuple.count`; kept for backwards compatibility.
    count: int  # type: ignore[assignment]
    function: str


@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class FunctionCounts:
    """Container for manipulating Callgrind results.

    It supports:
        1) Addition and subtraction to combine or diff results.
        2) Tuple-like indexing.
        3) A `denoise` function which strips CPython calls which are known to
           be non-deterministic and quite noisy.
        4) Two higher order methods (`filter` and `transform`) for custom
           manipulation.
    """
    _data: Tuple[FunctionCount, ...]
    inclusive: bool
    truncate_rows: bool = True

    # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
    # the print settings. This is simply to allow hermetic unit tests.
    _linewidth: Optional[int] = None

    def __iter__(self) -> Generator[FunctionCount, None, None]:
        yield from self._data

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
        """Index like a tuple.

        Slicing returns a new `FunctionCounts` (with row truncation disabled);
        any other index returns the single entry stored at that position.
        """
        data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item]

        # Dispatch on the index rather than on `isinstance(data, tuple)`:
        # `FunctionCount` is a NamedTuple and therefore itself a `tuple`, so
        # that check would also wrap a single entry in a `FunctionCounts`.
        if isinstance(item, slice):
            return FunctionCounts(
                cast(Tuple[FunctionCount, ...], data),
                self.inclusive,
                truncate_rows=False,
                _linewidth=self._linewidth,
            )
        return data

    def __repr__(self) -> str:
        count_len = 0
        for c, _ in self:
            # Account for sign in string length.
            # NOTE(review): `str(c)` already includes "-", so negative counts
            # get one extra column of padding.
            count_len = max(count_len, len(str(c)) + int(c < 0))

        lines = []
        linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth

        # Each row is rendered as "  {count}  {fn}": 4 columns of padding.
        fn_str_len = max(linewidth - count_len - 4, 40)
        for c, fn in self:
            if len(fn) > fn_str_len:
                # Elide the middle of long names; " ... " is 5 characters.
                left_len = int((fn_str_len - 5) // 2)
                fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
            lines.append(f"  {c:>{count_len}}  {fn}")

        if self.truncate_rows and len(lines) > 18:
            # Keep the first and last 9 rows. The marker is right-aligned
            # with the count column (2 leading spaces + count width).
            lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]

        if not self.inclusive:
            lines.extend(["", f"Total: {self.sum()}"])

        return "\n".join([super().__repr__()] + lines)

    def __add__(
        self,
        other: "FunctionCounts",
    ) -> "FunctionCounts":
        return self._merge(other, lambda c: c)

    def __sub__(
        self,
        other: "FunctionCounts",
    ) -> "FunctionCounts":
        return self._merge(other, lambda c: -c)

    def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
        """Scale every count by `other`, truncating the results to `int`."""
        return self._from_dict({
            fn: int(c * other) for c, fn in self._data
        }, self.inclusive)

    def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
        """Apply `map_fn` to all of the function names.

        This can be used to regularize function names (e.g. stripping irrelevant
        parts of the file path), coalesce entries by mapping multiple functions
        to the same name (in which case the counts are added together), etc.
        """
        counts: DefaultDict[str, int] = collections.defaultdict(int)
        for c, fn in self._data:
            counts[map_fn(fn)] += c

        return self._from_dict(counts, self.inclusive)

    def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
        """Keep only the elements where `filter_fn` applied to function name returns True."""
        return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)

    def sum(self) -> int:
        """Total instruction count across all entries."""
        return sum(c for c, _ in self)

    def denoise(self) -> "FunctionCounts":
        """Remove known noisy instructions.

        Several instructions in the CPython interpreter are rather noisy. These
        instructions involve unicode to dictionary lookups which Python uses to
        map variable names. FunctionCounts is generally a content agnostic
        container, however this is sufficiently important for obtaining
        reliable results to warrant an exception."""
        return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)

    def _merge(
        self,
        second: "FunctionCounts",
        merge_fn: Callable[[int], int]
    ) -> "FunctionCounts":
        """Combine counts by function name, applying `merge_fn` to `second`'s counts."""
        assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
        counts: DefaultDict[str, int] = collections.defaultdict(int)
        for c, fn in self:
            counts[fn] += c

        for c, fn in second:
            counts[fn] += merge_fn(c)

        return self._from_dict(counts, self.inclusive)

    @staticmethod
    def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts":
        """Build a `FunctionCounts` from {name: count}.

        Zero counts are dropped, and rows are sorted in descending order.
        """
        flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
        return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)
@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class CallgrindStats:
    """Top level container for Callgrind results collected by Timer.

    Manipulation is generally done using the FunctionCounts class, which is
    obtained by calling `CallgrindStats.stats(...)`. Several convenience
    methods are provided as well; the most significant is
    `CallgrindStats.as_standardized()`.
    """
    task_spec: common.TaskSpec
    number_per_run: int
    built_with_debug_symbols: bool
    baseline_inclusive_stats: FunctionCounts
    baseline_exclusive_stats: FunctionCounts
    stmt_inclusive_stats: FunctionCounts
    stmt_exclusive_stats: FunctionCounts
    stmt_callgrind_out: Optional[str]

    def __repr__(self) -> str:
        base_stats = self.baseline_exclusive_stats
        # The "Instructions:" and "Baseline:" labels are padded to the same
        # width so that their number columns line up under "All".
        output = f"""
{super().__repr__()}
{self.task_spec.summarize()}
  {'':>25}All{'':>10}Noisy symbols removed
    Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
    Baseline:     {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
""".strip()

        if not self.built_with_debug_symbols:
            output += textwrap.dedent("""
            Warning: PyTorch was not built with debug symbols.
                     Source information may be limited. Rebuild with
                     REL_WITH_DEB_INFO=1 for more detailed results.""")

        return output

    def stats(self, inclusive: bool = False) -> FunctionCounts:
        """Returns detailed function counts.

        Conceptually, the FunctionCounts returned can be thought of as a tuple
        of (count, path_and_function_name) tuples.

        `inclusive` matches the semantics of callgrind. If True, the counts
        include instructions executed by children. `inclusive=True` is useful
        for identifying hot spots in code; `inclusive=False` is useful for
        reducing noise when diffing counts from two different runs. (See
        CallgrindStats.delta(...) for more details)
        """
        return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats

    def counts(self, *, denoise: bool = False) -> int:
        """Returns the total number of instructions executed.

        See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
        """
        stats = self.stmt_exclusive_stats
        return (stats.denoise() if denoise else stats).sum()

    # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
    def delta(
        self,
        other: "CallgrindStats",
        inclusive: bool = False,
    ) -> FunctionCounts:
        """Diff two sets of counts.

        One common reason to collect instruction counts is to determine the
        effect that a particular change will have on the number of instructions
        needed to perform some unit of work. If a change increases that number,
        the next logical question is "why". This generally involves looking at
        what part of the code increased in instruction count. This function
        automates that process so that one can easily diff counts on both an
        inclusive and exclusive basis.
        """
        return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)

    def as_standardized(self) -> "CallgrindStats":
        """Strip library names and some prefixes from function strings.

        When comparing two different sets of instruction counts, one stumbling
        block can be path prefixes. Callgrind includes the full filepath
        when reporting a function (as it should). However, this can cause
        issues when diffing profiles. If a key component such as Python
        or PyTorch was built in separate locations in the two profiles, the
        diff can end up resembling::

            23234231 /tmp/first_build_dir/thing.c:foo(...)
             9823794 /tmp/first_build_dir/thing.c:bar(...)
              ...
               53453 .../aten/src/Aten/...:function_that_actually_changed(...)
              ...
            -9823794 /tmp/second_build_dir/thing.c:bar(...)
           -23234231 /tmp/second_build_dir/thing.c:foo(...)

        Stripping prefixes can ameliorate this issue by regularizing the
        strings and causing better cancellation of equivalent call sites
        when diffing.
        """
        def strip(stats: FunctionCounts) -> FunctionCounts:
            transforms = (
                # PyTorch may have been built in different locations.
                (r"^.+build/\.\./", "build/../"),
                (r"^.+/" + re.escape("build/aten/"), "build/aten/"),

                # "Python" and "Objects" come from CPython.
                (r"^.+/" + re.escape("Python/"), "Python/"),
                (r"^.+/" + re.escape("Objects/"), "Objects/"),

                # Strip library name. e.g. `libtorch.so`
                (r"\s\[.+\]$", ""),
            )

            for before, after in transforms:
                # `transform` applies the lambda immediately, so the loop
                # variables are bound to the current pattern when it runs.
                stats = stats.transform(lambda fn: re.sub(before, after, fn))

            return stats

        return CallgrindStats(
            task_spec=self.task_spec,
            number_per_run=self.number_per_run,
            built_with_debug_symbols=self.built_with_debug_symbols,
            baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
            baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
            stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
            stmt_exclusive_stats=strip(self.stmt_exclusive_stats),

            # `as_standardized` will change symbol names, so the contents will
            # no longer map directly to `callgrind.out`
            stmt_callgrind_out=None,
        )
class Serialization(enum.Enum):
    """How a `CopyIfCallgrind` value is written to disk for the subprocess."""
    PICKLE = 0
    TORCH = 1
    TORCH_JIT = 2


# Maps each serialization method to the value types it handles.
# `CopyIfCallgrind` takes the first entry (in dict order) whose types match.
# TORCH_JIT is listed before TORCH, presumably because ScriptModule is
# itself an nn.Module.
_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = {
    Serialization.PICKLE: (str, bytes, bool, int, float, complex),
    Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
    Serialization.TORCH: (torch.nn.Module,),
}


class CopyIfCallgrind:
    """Signal that a global may be replaced with a deserialized copy.

    See `GlobalsBridge` for why this matters.
    """
    def __init__(self, value: Any, *, setup: Optional[str] = None):
        """Wrap `value` for transfer to the Callgrind subprocess.

        Args:
            value: The global to copy. Must be an instance of one of the types
                listed in `_GLOBALS_ALLOWED_TYPES`.
            setup: Optional source that `GlobalsBridge.construct` emits
                (dedented) into the generated script before loading `value`.

        Raises:
            ValueError: If `value` is not an instance of a supported type.
        """
        for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
            if isinstance(value, supported_types):
                self._value: Any = value
                self._setup: Optional[str] = setup
                self._serialization: Serialization = method
                break
        else:
            # `chain.from_iterable` flattens the per-method tuples so each
            # supported type is listed on its own line. (`it.chain(values())`
            # would yield the tuples themselves.)
            supported_str = "\n".join([
                getattr(t, "__name__", repr(t))
                for t in it.chain.from_iterable(_GLOBALS_ALLOWED_TYPES.values())])

            raise ValueError(
                f"Unsupported type: {type(value)}\n"
                f"`collect_callgrind` restricts globals to the following types:\n"
                f"{textwrap.indent(supported_str, ' ')}"
            )

    @property
    def value(self) -> Any:
        return self._value

    @property
    def setup(self) -> Optional[str]:
        return self._setup

    @property
    def serialization(self) -> Serialization:
        return self._serialization

    @staticmethod
    def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of `globals` with every `CopyIfCallgrind` replaced by its value."""
        return {
            k: (v.value if isinstance(v, CopyIfCallgrind) else v)
            for k, v in globals.items()
        }


class GlobalsBridge:
    """Handle the transfer of (certain) globals when collecting Callgrind statistics.

    Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
                  work with `Timer.collect_callgrind`.

    Consider the following code snippet:
    ```
        import pickle
        import timeit

        class Counter:
            value = 0

            def __call__(self):
                self.value += 1

        counter = Counter()
        timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
        print(counter.value)  # 10

        timeit.Timer(
            "counter()",
            globals={"counter": pickle.loads(pickle.dumps(counter))}
        ).timeit(20)
        print(counter.value)  # Still 10
    ```

    In the first case, `stmt` is executed using the objects in `globals`;
    however, the addition of serialization and deserialization changes the
    semantics and may meaningfully change behavior.

    This is a practical consideration when collecting Callgrind statistics.
    Unlike `exec` based execution (which `timeit` uses under the hood) which
    can share in-memory data structures with the caller, Callgrind collection
    requires an entirely new process in order to run under Valgrind. This means
    that any data structures used for statement execution will have to be
    serialized and deserialized in the subprocess.

    In order to avoid surprising semantics from (user invisible) process
    boundaries, what can be passed through `globals` is severely restricted
    for `Timer.collect_callgrind`. It is expected that most setup should be
    achievable (albeit perhaps less ergonomically) by passing a `setup`
    string.

    There are, however, exceptions. One such class is TorchScripted functions.
    Because they require a concrete file with source code it is not possible
    to define them using a `setup` string. Another group is torch.nn.Modules,
    whose construction can be complex and prohibitively cumbersome to coerce
    into a `setup` string. Finally, most builtin types are sufficiently well
    behaved and sufficiently common to warrant allowing as well. (e.g.
    `globals={"n": 1}` is very convenient.)

    Fortunately, all have well defined serialization semantics. This class
    is responsible for enabling the Valgrind subprocess to use elements in
    `globals` so long as they are an allowed type.

    Caveats:
        The user is required to acknowledge this serialization by wrapping
        elements in `globals` with `CopyIfCallgrind`.

        While ScriptFunction and ScriptModule are expected to save and load
        quite robustly, it is up to the user to ensure that an nn.Module can
        un-pickle successfully.

        `torch.Tensor` and `np.ndarray` are deliberately excluded. The
        serialization/deserialization process perturbs the representation of a
        tensor in ways that could result in incorrect measurements. For example,
        if a tensor lives in pinned CPU memory, this fact would not be preserved
        by a dump, and that will in turn change the performance of certain CUDA
        operations.
    """

    def __init__(self, globals: Dict[str, Any], data_dir: str) -> None:
        """Validate `globals` and remember where serialized values will go.

        Raises:
            ValueError: If `globals["torch"]` is not the real `torch` module,
                or if any other entry (besides `__builtins__`) is not wrapped
                in `CopyIfCallgrind`.
        """
        self._globals: Dict[str, CopyIfCallgrind] = {}
        self._data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

        if globals.get("torch", torch) is not torch:
            raise ValueError("`collect_callgrind` does not support mocking out `torch`.")

        for name, value in globals.items():
            if name in ("torch", "__builtins__"):
                # Torch will be imported by the collection script, and
                # __builtins__ is added by Timer.
                continue

            if not isinstance(value, CopyIfCallgrind):
                raise ValueError(
                    "`collect_callgrind` requires that globals be wrapped in "
                    "`CopyIfCallgrind` so that serialization is explicit."
                )

            self._globals[name] = value

    def construct(self) -> str:
        """Serialize every global into `data_dir` and return the loading code.

        Side effect: writes one file per global (`<name>.pkl` or `<name>.pt`).
        The returned source re-binds each name in the subprocess, preceded by
        that global's (dedented) `setup` string if one was given.
        """
        load_lines = []
        for name, wrapped_value in self._globals.items():
            if wrapped_value.setup is not None:
                load_lines.append(textwrap.dedent(wrapped_value.setup))

            if wrapped_value.serialization == Serialization.PICKLE:
                path = os.path.join(self._data_dir, f"{name}.pkl")
                load_lines.append(
                    f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
                with open(path, "wb") as f:
                    pickle.dump(wrapped_value.value, f)

            elif wrapped_value.serialization == Serialization.TORCH:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.load({repr(path)})")
                torch.save(wrapped_value.value, path)

            elif wrapped_value.serialization == Serialization.TORCH_JIT:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
                with open(path, "wb") as f:
                    torch.jit.save(wrapped_value.value, f)

            else:
                raise NotImplementedError(
                    f"Unknown serialization method: {wrapped_value.serialization}")

        return "\n".join(load_lines)


class _ValgrindWrapper:
    """Run a statement under `valgrind --tool=callgrind` and parse the results."""

    def __init__(self) -> None:
        """Locate Callgrind bindings and probe for the required executables.

        Uses the bindings in `torch._C` when present; otherwise JIT-compiles
        compatible bindings via `cpp_jit.get_compat_bindings()`.
        """
        self._bindings_module: Optional[CallgrindModuleType] = None
        valgrind_symbols = (
            "_valgrind_supported_platform",
            "_valgrind_toggle",
            "_valgrind_toggle_and_dump_stats",
        )
        if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
            self._supported_platform: bool = torch._C._valgrind_supported_platform()

        else:
            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
            self._bindings_module = cpp_jit.get_compat_bindings()
            assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
            self._supported_platform = self._bindings_module._valgrind_supported_platform()

        self._commands_available: Dict[str, bool] = {}
        if self._supported_platform:
            # Only bother checking on supported platforms.
            for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
                # `shutil.which` searches PATH in-process, so this probe does
                # not itself require a `which` executable to be installed.
                self._commands_available[cmd] = shutil.which(cmd) is not None

        self._build_type: Optional[str] = None
        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())
        if build_search is not None:
            # The greedy match may run past the first comma; keep only the
            # text up to it.
            self._build_type = build_search.group(1).split(",")[0]

    def _validate(self) -> None:
        """Raise OSError if Callgrind collection cannot run on this machine."""
        if not self._supported_platform:
            raise OSError("Valgrind is not supported on this platform.")

        missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
        if missing_cmds:
            raise OSError("Missing: " + ", ".join(missing_cmds))

    def collect_callgrind(
        self,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[CallgrindStats, ...]:
        """Collect stats, and attach a reference run which can be used to filter interpreter overhead.

        Returns one `CallgrindStats` per repeat; each shares the same baseline.
        """
        self._validate()
        assert is_python or not collect_baseline

        *task_stats, baseline_stats = self._invoke(
            task_spec=task_spec,
            globals=globals,
            number=number,
            repeats=repeats,
            collect_baseline=collect_baseline,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )
        assert len(task_stats) == repeats

        return tuple(
            CallgrindStats(
                task_spec=task_spec,
                number_per_run=number,
                built_with_debug_symbols=self._build_type == "RelWithDebInfo",
                baseline_inclusive_stats=baseline_stats[0],
                baseline_exclusive_stats=baseline_stats[1],
                stmt_inclusive_stats=stmt_inclusive_stats,
                stmt_exclusive_stats=stmt_exclusive_stats,
                stmt_callgrind_out=out_contents,
            )
            for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
        )

    def _invoke(
        self,
        *,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
        """Core invocation method for Callgrind collection.

        Valgrind operates by effectively replacing the CPU with an emulated
        version which allows it to instrument any code at the cost of severe
        performance degradation. This has the practical effect that in order
        to collect Callgrind statistics, a new process has to be created
        running under `valgrind`. The steps for this process are:

        1) Create a scratch directory.
        2) Codegen a run script. (_ValgrindWrapper._construct_script)
            Inside the run script:
                * Validate that Python and torch match the parent process
                * Validate that it is indeed running under valgrind
                * Execute `setup` and warm up `stmt`
                * Begin collecting stats
                * Run the `stmt` loop
                * Stop collecting stats
        3) Parse the run results.
        4) Cleanup the scratch directory.

        Returns:
            `repeats + 1` tuples of (inclusive counts, exclusive counts,
            `callgrind.out` contents or None). The last tuple is the baseline;
            it holds empty counts when `collect_baseline` is False.

        Raises:
            OSError: If the valgrind invocation exits with a non-zero status.
        """
        working_dir = common._make_temp_dir(prefix="callgrind")
        data_dir = os.path.join(working_dir, "data")
        script_file = os.path.join(working_dir, "timer_callgrind.py")
        callgrind_out = os.path.join(working_dir, "callgrind.out")
        error_log = os.path.join(working_dir, "error.txt")
        stat_log = os.path.join(working_dir, "callgrind_stat.txt")
        stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")

        def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]:
            """Run `args`; return the completed process and its combined stdout/stderr."""
            # Output is sent to a file rather than a pipe. See:
            # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
            with open(stdout_stderr_log, "wb") as f_stdout_stderr:
                invocation = subprocess.run(
                    args,
                    stdout=f_stdout_stderr,
                    stderr=subprocess.STDOUT,
                    **kwargs,
                )
            with open(stdout_stderr_log, "rt") as f:
                return invocation, f.read()

        try:
            if is_python:
                if self._bindings_module is not None:
                    # Place the JIT-compiled bindings next to the script; the
                    # script's directory is on `sys.path` when it is run, so
                    # `import {bindings.__name__}` can find them.
                    shutil.copy(
                        self._bindings_module.__file__,
                        os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
                    )

                with open(script_file, "wt") as f:
                    f.write(self._construct_script(
                        task_spec,
                        globals=GlobalsBridge(globals, data_dir),
                        number=number,
                        repeats=repeats,
                        collect_baseline=collect_baseline,
                        error_log=error_log,
                        stat_log=stat_log,
                        bindings=self._bindings_module))

                # The generated script looks for "PID <pid>: python <script>"
                # in `callgrind_control --stat` output, so it must be started
                # as `python`.
                run_loop_cmd = ["python", script_file]
            else:
                assert not collect_baseline
                run_loop_exec = cpp_jit.compile_callgrind_template(
                    stmt=task_spec.stmt,
                    setup=task_spec.setup,
                    global_setup=task_spec.global_setup,
                )
                run_loop_cmd = [
                    run_loop_exec,
                    "--number", str(number),
                    "--number-warmup", str(min(number, 10)),
                    "--repeats", str(repeats),
                    "--number-threads", str(task_spec.num_threads),
                ]

            valgrind_invocation, valgrind_invocation_output = run([
                "valgrind",
                "--tool=callgrind",
                f"--callgrind-out-file={callgrind_out}",
                "--dump-line=yes",
                "--dump-instr=yes",
                "--instr-atstart=yes",
                "--collect-atstart=no",
            ] + run_loop_cmd)

            if valgrind_invocation.returncode:
                # Prefer the script's own diagnostic; fall back to raw output.
                error_report = ""
                if os.path.exists(error_log):
                    with open(error_log, "rt") as f:
                        error_report = f.read()
                if not error_report:
                    error_report = "Unknown error.\n" + valgrind_invocation_output

                raise OSError(f"Failed to collect callgrind profile:\n{error_report}")

            def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
                """Parse `callgrind_annotate` output for one dump file."""
                _, annotate_invocation_output = run([
                    "callgrind_annotate",
                    f"--inclusive={'yes' if inclusive else 'no'}",
                    "--threshold=100",
                    "--show-percs=no",
                    fpath
                ], check=True)

                total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
                begin_pattern = re.compile(r"Ir\s+file:function")
                function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")

                class ScanState(enum.Enum):
                    SCANNING_FOR_TOTAL = 0
                    SCANNING_FOR_START = 1
                    PARSING = 2

                scan_state = ScanState.SCANNING_FOR_TOTAL
                fn_counts = []
                for line in annotate_invocation_output.splitlines(keepends=False):
                    if scan_state == ScanState.SCANNING_FOR_TOTAL:
                        total_match = total_pattern.match(line)
                        if total_match:
                            program_totals = int(total_match.groups()[0].replace(",", ""))
                            scan_state = ScanState.SCANNING_FOR_START

                    elif scan_state == ScanState.SCANNING_FOR_START:
                        if begin_pattern.match(line):
                            scan_state = ScanState.PARSING

                    else:
                        assert scan_state == ScanState.PARSING
                        fn_match = function_pattern.match(line)
                        if fn_match:
                            ir_str, file_function = fn_match.groups()
                            ir = int(ir_str.replace(",", ""))
                            if ir == program_totals:
                                # Callgrind includes some top level red herring symbols when
                                # a program dumps multiple profiles.
                                continue
                            fn_counts.append(FunctionCount(ir, file_function))

                        elif re.match(r"-+", line):
                            # Ignore heading separator lines.
                            continue

                        else:
                            break

                assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
                return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)

            def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]:
                """Load dump `i` (0-based); index `repeats` is the baseline."""
                if i == repeats and not collect_baseline:
                    # Null baseline.
                    return (
                        FunctionCounts((), inclusive=True),
                        FunctionCounts((), inclusive=False),
                        None,
                    )

                fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
                callgrind_out_contents: Optional[str] = None
                if retain_out_file:
                    with open(fpath, "rt") as f:
                        callgrind_out_contents = f.read()

                return (
                    parse_output(fpath, inclusive=True),
                    parse_output(fpath, inclusive=False),
                    callgrind_out_contents
                )

            return tuple(read_results(i) for i in range(repeats + 1))
        finally:
            shutil.rmtree(working_dir)

    @staticmethod
    def _construct_script(
        task_spec: common.TaskSpec,
        globals: GlobalsBridge,
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        error_log: str,
        stat_log: str,
        bindings: Optional[CallgrindModuleType],
    ) -> str:
        """Generate the Python script that is run under valgrind.

        Side effect: `globals.construct()` writes the serialized globals.
        """
        def block_stmt(stmt: str, indent: int = 0) -> str:
            """Partially unroll benchmark loop.

            The naive template looks something like:
                "for _ in range({number}): {stmt}"

            However a loop in Python is surprisingly expensive, and significantly
            increases the number of background Python instructions. So instead we
            partially unroll the loops, with a block size of 100 chosen to keep
            the instruction overhead from `range` low while also not ballooning
            the size of the generated file.
            """
            block_size = 100
            loop_count = number // block_size
            if loop_count == 1:
                # There is no point in having `for _ in range(1): ...` rather
                # than just `...`, and this lets us save shave a few background
                # instructions.
                loop_count = 0
            remainder = number - block_size * loop_count
            blocked_stmt = ""

            if loop_count:
                unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
                blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"

            if remainder:
                blocked_stmt += "\n".join([stmt] * remainder)

            return textwrap.indent(blocked_stmt, " " * indent)

        pass_baseline = (
            "callgrind_bindings._valgrind_toggle()\n"
            f"{block_stmt('pass')}\n"
            "callgrind_bindings._valgrind_toggle_and_dump_stats()"
        )

        # NB: `{indented_stmt}` and `{blocked_stmt}` sit at the column of the
        # enclosing `for`, because their values already carry the 4-space
        # loop-body indent on every line (including multi-line statements).
        return textwrap.dedent(r"""
            import gc
            import os
            import pickle
            import subprocess
            import sys
            import time

            # Mitigate https://github.com/pytorch/pytorch/issues/37377
            # which can sometimes cause the subprocess call to fail.
            import numpy as np

            import torch
            torch.set_num_threads({num_threads})

            {bindings_import}

            PID = os.getpid()

            def log_failure(msg):
                with open({error_log_repr}, "wt") as f:
                    f.write(msg)
                sys.exit(1)

            def check_result(completed_process):
                if completed_process.returncode:
                    log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
                return completed_process

            # =============================================================================
            # == Check that subprocess matches parent =====================================
            # =============================================================================
            if os.path.realpath(sys.executable) != "{parent_interpreter}":
                log_failure(
                    "Interpreter mismatch:\n"
                    f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}"
                )

            if torch.__file__ != "{torch_file}":
                log_failure(
                    "PyTorch does not match expected file:\n"
                    f" {{torch.__file__}}\n vs.\n {torch_file}"
                )

            # =============================================================================
            # == User specified setup =====================================================
            # =============================================================================
            # Load serialized globals
            {load_globals}

            # User setup str
            {setup}

            for _ in range({warmup_number}):
            {indented_stmt}

            # =============================================================================
            # == Callgrind management =====================================================
            # =============================================================================
            with open("{stat_log}", "wb") as stat_file:
                # If many instances of callgrind are running at once, the output of
                # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
                # to deadlock. So instead we use a file.
                callgrind_stat = check_result(subprocess.run(
                    ["callgrind_control", "--stat"],
                    stdout=stat_file,
                    stderr=subprocess.STDOUT,
                ))

            with open("{stat_log}", "rt") as stat_file:
                stat_lines = stat_file.read().splitlines()

            if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
                log_failure("Process does not appear to be running callgrind.")

            gc.collect()
            time.sleep(0.01)

            # =============================================================================
            # == User code block ==========================================================
            # =============================================================================
            for _ in range({repeats}):
                callgrind_bindings._valgrind_toggle()
            {blocked_stmt}
                callgrind_bindings._valgrind_toggle_and_dump_stats()
                gc.collect()

            {baseline}
        """).strip().format(
            indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
            blocked_stmt=block_stmt(task_spec.stmt, indent=4),
            baseline=(pass_baseline if collect_baseline else ""),
            number=number,
            repeats=repeats,
            load_globals=globals.construct(),
            setup=task_spec.setup,
            warmup_number=min(number, 10),
            num_threads=task_spec.num_threads,
            error_log_repr=repr(error_log),
            stat_log=stat_log,
            parent_interpreter=os.path.realpath(sys.executable),
            torch_file=torch.__file__,
            bindings_import=(
                "import torch._C as callgrind_bindings" if bindings is None
                else f"import {bindings.__name__} as callgrind_bindings"),
        )


# Created lazily: construction may JIT-compile bindings and probes PATH.
CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None


def wrapper_singleton() -> _ValgrindWrapper:
    """Return the process-wide `_ValgrindWrapper`, creating it on first use."""
    global CALLGRIND_SINGLETON
    if CALLGRIND_SINGLETON is None:
        CALLGRIND_SINGLETON = _ValgrindWrapper()
    return CALLGRIND_SINGLETON

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources