Source code for torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
"""Intermediate layer between `Timer` and `valgrind`."""
import collections
import enum
import dataclasses
import itertools as it
import os
import pickle
import re
import shutil
import subprocess
import sys
import textwrap
from typing import (
cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple,
Optional, Tuple, Union, TYPE_CHECKING)
import torch
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import CallgrindModuleType
__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
if TYPE_CHECKING:
CompletedProcessType = subprocess.CompletedProcess[str]
else:
CompletedProcessType = subprocess.CompletedProcess
FunctionCount = NamedTuple("FunctionCount", [("count", int), ("function", str)])
[docs]@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class FunctionCounts:
"""Container for manipulating Callgrind results.
It supports:
1) Addition and subtraction to combine or diff results.
2) Tuple-like indexing.
3) A `denoise` function which strips CPython calls which are known to
be non-deterministic and quite noisy.
4) Two higher order methods (`filter` and `transform`) for custom
manipulation.
"""
_data: Tuple[FunctionCount, ...]
inclusive: bool
truncate_rows: bool = True
# For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
# the print settings. This is simply to allow hermetic unit tests.
_linewidth: Optional[int] = None
def __iter__(self) -> Generator[FunctionCount, None, None]:
yield from self._data
def __len__(self) -> int:
return len(self._data)
def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item]
return (
FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
if isinstance(data, tuple) else data
)
def __repr__(self) -> str:
count_len = 0
for c, _ in self:
# Account for sign in string length.
count_len = max(count_len, len(str(c)) + int(c < 0))
lines = []
linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
fn_str_len = max(linewidth - count_len - 4, 40)
for c, fn in self:
if len(fn) > fn_str_len:
left_len = int((fn_str_len - 5) // 2)
fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
lines.append(f" {c:>{count_len}} {fn}")
if self.truncate_rows and len(lines) > 18:
lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]
if not self.inclusive:
lines.extend(["", f"Total: {self.sum()}"])
return "\n".join([super().__repr__()] + lines)
def __add__(
self,
other: "FunctionCounts",
) -> "FunctionCounts":
return self._merge(other, lambda c: c)
def __sub__(
self,
other: "FunctionCounts",
) -> "FunctionCounts":
return self._merge(other, lambda c: -c)
def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
return self._from_dict({
fn: int(c * other) for c, fn in self._data
}, self.inclusive)
[docs] def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
"""Apply `map_fn` to all of the function names.
This can be used to regularize function names (e.g. stripping irrelevant
parts of the file path), coalesce entries by mapping multiple functions
to the same name (in which case the counts are added together), etc.
"""
counts: DefaultDict[str, int] = collections.defaultdict(int)
for c, fn in self._data:
counts[map_fn(fn)] += c
return self._from_dict(counts, self.inclusive)
[docs] def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
"""Keep only the elements where `filter_fn` applied to function name returns True."""
return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)
def sum(self) -> int:
return sum(c for c, _ in self)
[docs] def denoise(self) -> "FunctionCounts":
"""Remove known noisy instructions.
Several instructions in the CPython interpreter are rather noisy. These
instructions involve unicode to dictionary lookups which Python uses to
map variable names. FunctionCounts is generally a content agnostic
container, however this is sufficiently important for obtaining
reliable results to warrant an exception."""
return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)
def _merge(
self,
second: "FunctionCounts",
merge_fn: Callable[[int], int]
) -> "FunctionCounts":
assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
counts: DefaultDict[str, int] = collections.defaultdict(int)
for c, fn in self:
counts[fn] += c
for c, fn in second:
counts[fn] += merge_fn(c)
return self._from_dict(counts, self.inclusive)
@staticmethod
def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts":
flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)
[docs]@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class CallgrindStats:
"""Top level container for Callgrind results collected by Timer.
Manipulation is generally done using the FunctionCounts class, which is
obtained by calling `CallgrindStats.stats(...)`. Several convenience
methods are provided as well; the most significant is
`CallgrindStats.as_standardized()`.
"""
task_spec: common.TaskSpec
number_per_run: int
built_with_debug_symbols: bool
baseline_inclusive_stats: FunctionCounts
baseline_exclusive_stats: FunctionCounts
stmt_inclusive_stats: FunctionCounts
stmt_exclusive_stats: FunctionCounts
stmt_callgrind_out: Optional[str]
def __repr__(self) -> str:
newline = "\n" # `\` cannot appear in fstring code section.
base_stats = self.baseline_exclusive_stats
output = f"""
{super().__repr__()}
{self.task_spec.summarize()}
{'':>25}All{'':>10}Noisy symbols removed
Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
""".strip()
if not self.built_with_debug_symbols:
output += textwrap.dedent("""
Warning: PyTorch was not built with debug symbols.
Source information may be limited. Rebuild with
REL_WITH_DEB_INFO=1 for more detailed results.""")
return output
[docs] def stats(self, inclusive: bool = False) -> FunctionCounts:
"""Returns detailed function counts.
Conceptually, the FunctionCounts returned can be thought of as a tuple
of (count, path_and_function_name) tuples.
`inclusive` matches the semantics of callgrind. If True, the counts
include instructions executed by children. `inclusive=True` is useful
for identifying hot spots in code; `inclusive=False` is useful for
reducing noise when diffing counts from two different runs. (See
CallgrindStats.delta(...) for more details)
"""
return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats
[docs] def counts(self, *, denoise: bool = False) -> int:
"""Returns the total number of instructions executed.
See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
"""
stats = self.stmt_exclusive_stats
return (stats.denoise() if denoise else stats).sum()
# FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
[docs] def delta(
self,
other: "CallgrindStats",
inclusive: bool = False,
) -> FunctionCounts:
"""Diff two sets of counts.
One common reason to collect instruction counts is to determine the
the effect that a particular change will have on the number of instructions
needed to perform some unit of work. If a change increases that number, the
next logical question is "why". This generally involves looking at what part
if the code increased in instruction count. This function automates that
process so that one can easily diff counts on both an inclusive and
exclusive basis.
"""
return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)
[docs] def as_standardized(self) -> "CallgrindStats":
"""Strip library names and some prefixes from function strings.
When comparing two different sets of instruction counts, on stumbling
block can be path prefixes. Callgrind includes the full filepath
when reporting a function (as it should). However, this can cause
issues when diffing profiles. If a key component such as Python
or PyTorch was built in separate locations in the two profiles, which
can result in something resembling::
23234231 /tmp/first_build_dir/thing.c:foo(...)
9823794 /tmp/first_build_dir/thing.c:bar(...)
...
53453 .../aten/src/Aten/...:function_that_actually_changed(...)
...
-9823794 /tmp/second_build_dir/thing.c:bar(...)
-23234231 /tmp/second_build_dir/thing.c:foo(...)
Stripping prefixes can ameliorate this issue by regularizing the
strings and causing better cancellation of equivalent call sites
when diffing.
"""
def strip(stats: FunctionCounts) -> FunctionCounts:
transforms = (
# PyTorch may have been built in different locations.
(r"^.+build/\.\./", "build/../"),
(r"^.+/" + re.escape("build/aten/"), "build/aten/"),
# "Python" and "Objects" come from CPython.
(r"^.+/" + re.escape("Python/"), "Python/"),
(r"^.+/" + re.escape("Objects/"), "Objects/"),
# Strip library name. e.g. `libtorch.so`
(r"\s\[.+\]$", ""),
)
for before, after in transforms:
stats = stats.transform(lambda fn: re.sub(before, after, fn))
return stats
return CallgrindStats(
task_spec=self.task_spec,
number_per_run=self.number_per_run,
built_with_debug_symbols=self.built_with_debug_symbols,
baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
stmt_exclusive_stats=strip(self.stmt_exclusive_stats),
# `as_standardized` will change symbol names, so the contents will
# no longer map directly to `callgrind.out`
stmt_callgrind_out=None,
)
class Serialization(enum.Enum):
PICKLE = 0
TORCH = 1
TORCH_JIT = 2
_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = {
Serialization.PICKLE: (str, bytes, bool, int, float, complex),
Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
Serialization.TORCH: (torch.nn.Module,),
}
class CopyIfCallgrind:
"""Signal that a global may be replaced with a deserialized copy.
See `GlobalsBridge` for why this matters.
"""
def __init__(self, value: Any, *, setup: Optional[str] = None):
for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
if any(isinstance(value, t) for t in supported_types):
self._value: Any = value
self._setup: Optional[str] = setup
self._serialization: Serialization = method
break
else:
supported_str = "\n".join([
getattr(t, "__name__", repr(t))
for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())])
raise ValueError(
f"Unsupported type: {type(value)}\n"
f"`collect_callgrind` restricts globals to the following types:\n"
f"{textwrap.indent(supported_str, ' ')}"
)
@property
def value(self) -> Any:
return self._value
@property
def setup(self) -> Optional[str]:
return self._setup
@property
def serialization(self) -> Serialization:
return self._serialization
@staticmethod
def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]:
return {
k: (v.value if isinstance(v, CopyIfCallgrind) else v)
for k, v in globals.items()
}
class GlobalsBridge:
"""Handle the transfer of (certain) globals when collecting Callgrind statistics.
Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
work with `Timer.collect_callgrind`.
Consider the following code snippet:
```
import pickle
import timeit
class Counter:
value = 0
def __call__(self):
self.value += 1
counter = Counter()
timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
print(counter.value) # 10
timeit.Timer(
"counter()",
globals={"counter": pickle.loads(pickle.dumps(counter))}
).timeit(20)
print(counter.value) # Still 10
```
In the first case, `stmt` is executed using the objects in `globals`;
however, the addition of serialization and deserialization changes the
semantics and may meaningfully change behavior.
This is a practical consideration when collecting Callgrind statistics.
Unlike `exec` based execution (which `timeit` uses under the hood) which
can share in-memory data structures with the caller, Callgrind collection
requires an entirely new process in order to run under Valgrind. This means
that any data structures used for statement execution will have to be
serialized and deserialized in the subprocess.
In order to avoid surprising semantics from (user invisible) process
boundaries, what can be passed through `globals` is severely restricted
for `Timer.collect_callgrind`. It is expected that most setup should be
achievable (albeit perhaps less ergonomically) by passing a `setup`
string.
There are, however, exceptions. One such class are TorchScripted functions.
Because they require a concrete file with source code it is not possible
to define them using a `setup` string. Another group are torch.nn.Modules,
whose construction can be complex and prohibitively cumbersome to coerce
into a `setup` string. Finally, most builtin types are sufficiently well
behaved and sufficiently common to warrant allowing as well. (e.g.
`globals={"n": 1}` is very convenient.)
Fortunately, all have well defined serialization semantics. This class
is responsible for enabling the Valgrind subprocess to use elements in
`globals` so long as they are an allowed type.
Caveats:
The user is required to acknowledge this serialization by wrapping
elements in `globals` with `CopyIfCallgrind`.
While ScriptFunction and ScriptModule are expected to save and load
quite robustly, it is up to the user to ensure that an nn.Module can
un-pickle successfully.
`torch.Tensor` and `np.ndarray` are deliberately excluded. The
serialization/deserialization process perturbs the representation of a
tensor in ways that could result in incorrect measurements. For example,
if a tensor lives in pinned CPU memory, this fact would not be preserved
by a dump, and that will in turn change the performance of certain CUDA
operations.
"""
def __init__(self, globals: Dict[str, Any], data_dir: str) -> None:
self._globals: Dict[str, CopyIfCallgrind] = {}
self._data_dir = data_dir
if not os.path.exists(data_dir):
os.mkdir(data_dir)
if globals.get("torch", torch) is not torch:
raise ValueError("`collect_callgrind` does not support mocking out `torch`.")
for name, value in globals.items():
if name in ("torch", "__builtins__"):
# Torch will be imported by the collection script, and
# __builtins__ is added by Timer.
continue
if not isinstance(value, CopyIfCallgrind):
raise ValueError(
"`collect_callgrind` requires that globals be wrapped in "
"`CopyIfCallgrind` so that serialization is explicit."
)
self._globals[name] = value
def construct(self) -> str:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
load_lines.append(textwrap.dedent(wrapped_value.setup))
if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)
elif wrapped_value.serialization == Serialization.TORCH:
path = os.path.join(self._data_dir, f"{name}.pt")
load_lines.append(f"{name} = torch.load({repr(path)})")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f)
else:
raise NotImplementedError(
f"Unknown serialization method: {wrapped_value.serialization}")
return "\n".join(load_lines)
class _ValgrindWrapper:
def __init__(self) -> None:
self._bindings_module: Optional[CallgrindModuleType] = None
valgrind_symbols = (
"_valgrind_supported_platform",
"_valgrind_toggle",
"_valgrind_toggle_and_dump_stats",
)
if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
self._supported_platform: bool = torch._C._valgrind_supported_platform()
else:
print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
self._bindings_module = cpp_jit.get_compat_bindings()
assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
self._supported_platform = self._bindings_module._valgrind_supported_platform()
self._commands_available: Dict[str, bool] = {}
if self._supported_platform:
# Only bother checking on supported platforms.
for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
self._commands_available[cmd] = not subprocess.run(
["which", cmd],
capture_output=True,
).returncode
self._build_type: Optional[str] = None
build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())
if build_search is not None:
self._build_type = build_search.groups()[0].split(",")[0]
def _validate(self) -> None:
if not self._supported_platform:
raise OSError("Valgrind is not supported on this platform.")
missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
if missing_cmds:
raise OSError("Missing: " + ", ".join(missing_cmds))
def collect_callgrind(
self,
task_spec: common.TaskSpec,
globals: Dict[str, Any],
*,
number: int,
repeats: int,
collect_baseline: bool,
is_python: bool,
retain_out_file: bool,
) -> Tuple[CallgrindStats, ...]:
"""Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
self._validate()
assert is_python or not collect_baseline
*task_stats, baseline_stats = self._invoke(
task_spec=task_spec,
globals=globals,
number=number,
repeats=repeats,
collect_baseline=collect_baseline,
is_python=is_python,
retain_out_file=retain_out_file,
)
assert len(task_stats) == repeats
return tuple(
CallgrindStats(
task_spec=task_spec,
number_per_run=number,
built_with_debug_symbols=self._build_type == "RelWithDebInfo",
baseline_inclusive_stats=baseline_stats[0],
baseline_exclusive_stats=baseline_stats[1],
stmt_inclusive_stats=stmt_inclusive_stats,
stmt_exclusive_stats=stmt_exclusive_stats,
stmt_callgrind_out=out_contents,
)
for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
)
def _invoke(
self,
*,
task_spec: common.TaskSpec,
globals: Dict[str, Any],
number: int,
repeats: int,
collect_baseline: bool,
is_python: bool,
retain_out_file: bool,
) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
"""Core invocation method for Callgrind collection.
Valgrind operates by effectively replacing the CPU with an emulated
version which allows it to instrument any code at the cost of severe
performance degradation. This has the practical effect that in order
to collect Callgrind statistics, a new process has to be created
running under `valgrind`. The steps for this process are:
1) Create a scratch directory.
2) Codegen a run script. (_ValgrindWrapper._construct_script)
Inside the run script:
* Validate that Python and torch match the parent process
* Validate that it is indeed running under valgrind
* Execute `setup` and warm up `stmt`
* Begin collecting stats
* Run the `stmt` loop
* Stop collecting stats
3) Parse the run results.
4) Cleanup the scratch directory.
"""
working_dir = common._make_temp_dir(prefix="callgrind")
data_dir = os.path.join(working_dir, "data")
script_file = os.path.join(working_dir, "timer_callgrind.py")
callgrind_out = os.path.join(working_dir, "callgrind.out")
error_log = os.path.join(working_dir, "error.txt")
stat_log = os.path.join(working_dir, "callgrind_stat.txt")
stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")
def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]:
# https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
f_stdout_stderr = open(stdout_stderr_log, "wb")
try:
invocation = subprocess.run(
args,
stdout=f_stdout_stderr,
stderr=subprocess.STDOUT,
**kwargs,
)
with open(stdout_stderr_log, "rt") as f:
return invocation, f.read()
finally:
f_stdout_stderr.close()
try:
if is_python:
if self._bindings_module is not None:
shutil.copy(
self._bindings_module.__file__,
os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
)
script_file = os.path.join(working_dir, "timer_callgrind.py")
with open(script_file, "wt") as f:
f.write(self._construct_script(
task_spec,
globals=GlobalsBridge(globals, data_dir),
number=number,
repeats=repeats,
collect_baseline=collect_baseline,
error_log=error_log,
stat_log=stat_log,
bindings=self._bindings_module))
run_loop_cmd = ["python", script_file]
else:
assert not collect_baseline
run_loop_exec = cpp_jit.compile_callgrind_template(
stmt=task_spec.stmt,
setup=task_spec.setup,
global_setup=task_spec.global_setup,
)
run_loop_cmd = [
run_loop_exec,
"--number", str(number),
"--number-warmup", str(min(number, 10)),
"--repeats", str(repeats),
"--number-threads", str(task_spec.num_threads),
]
valgrind_invocation, valgrind_invocation_output = run([
"valgrind",
"--tool=callgrind",
f"--callgrind-out-file={callgrind_out}",
"--dump-line=yes",
"--dump-instr=yes",
"--instr-atstart=yes",
"--collect-atstart=no",
] + run_loop_cmd)
if valgrind_invocation.returncode:
error_report = ""
if os.path.exists(error_log):
with open(error_log, "rt") as f:
error_report = f.read()
if not error_report:
error_report = "Unknown error.\n" + valgrind_invocation_output
raise OSError(f"Failed to collect callgrind profile:\n{error_report}")
def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
annotate_invocation, annotate_invocation_output = run([
"callgrind_annotate",
f"--inclusive={'yes' if inclusive else 'no'}",
"--threshold=100",
"--show-percs=no",
fpath
], check=True)
total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
begin_pattern = re.compile(r"Ir\s+file:function")
function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")
class ScanState(enum.Enum):
SCANNING_FOR_TOTAL = 0
SCANNING_FOR_START = 1
PARSING = 2
scan_state = ScanState.SCANNING_FOR_TOTAL
fn_counts = []
for l in annotate_invocation_output.splitlines(keepends=False):
if scan_state == ScanState.SCANNING_FOR_TOTAL:
total_match = total_pattern.match(l)
if total_match:
program_totals = int(total_match.groups()[0].replace(",", ""))
scan_state = ScanState.SCANNING_FOR_START
elif scan_state == ScanState.SCANNING_FOR_START:
if begin_pattern.match(l):
scan_state = ScanState.PARSING
else:
assert scan_state == ScanState.PARSING
fn_match = function_pattern.match(l)
if fn_match:
ir_str, file_function = fn_match.groups()
ir = int(ir_str.replace(",", ""))
if ir == program_totals:
# Callgrind includes some top level red herring symbols when
# a program dumps multiple profiles.
continue
fn_counts.append(FunctionCount(ir, file_function))
elif re.match(r"-+", l):
# Ignore heading separator lines.
continue
else:
break
assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)
def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]:
if i == repeats and not collect_baseline:
# Null baseline.
return (
FunctionCounts((), inclusive=True),
FunctionCounts((), inclusive=False),
None,
)
fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files.
callgrind_out_contents: Optional[str] = None
if retain_out_file:
with open(fpath, "rt") as f:
callgrind_out_contents = f.read()
return (
parse_output(fpath, inclusive=True),
parse_output(fpath, inclusive=False),
callgrind_out_contents
)
return tuple(read_results(i) for i in range(repeats + 1))
finally:
shutil.rmtree(working_dir)
@staticmethod
def _construct_script(
task_spec: common.TaskSpec,
globals: GlobalsBridge,
*,
number: int,
repeats: int,
collect_baseline: bool,
error_log: str,
stat_log: str,
bindings: Optional[CallgrindModuleType],
) -> str:
def block_stmt(stmt: str, indent: int = 0) -> str:
"""Partially unroll benchmark loop.
The naive template looks something like:
"for _ in range({number}): {stmt}"
However a loop in Python is surprisingly expensive, and significantly
increases the number of background Python instructions. So instead we
partially unroll the loops, with a block size of 100 chosen to keep
the instruction overhead from `range` low while also not ballooning
the size of the generated file.
"""
block_size = 100
loop_count = number // block_size
if loop_count == 1:
# There is no point in having `for _ in range(1): ...` rather
# than just `...`, and this lets us save shave a few background
# instructions.
loop_count = 0
remainder = number - block_size * loop_count
blocked_stmt = ""
if loop_count:
unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"
if remainder:
blocked_stmt += "\n".join([stmt] * remainder)
return textwrap.indent(blocked_stmt, " " * indent)
pass_baseline = (
"callgrind_bindings._valgrind_toggle()\n"
f"{block_stmt('pass')}\n"
"callgrind_bindings._valgrind_toggle_and_dump_stats()"
)
return textwrap.dedent(r"""
import gc
import os
import pickle
import subprocess
import sys
import time
# Mitigate https://github.com/pytorch/pytorch/issues/37377
# which can sometimes cause the subprocess call to fail.
import numpy as np
import torch
torch.set_num_threads({num_threads})
{bindings_import}
PID = os.getpid()
def log_failure(msg):
with open({error_log_repr}, "wt") as f:
f.write(msg)
sys.exit(1)
def check_result(completed_process):
if completed_process.returncode:
log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
return completed_process
# =============================================================================
# == Check that subprocess matches parent =====================================
# =============================================================================
if os.path.realpath(sys.executable) != "{parent_interpreter}":
log_failure(
"Interpreter mismatch:\n"
f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}"
)
if torch.__file__ != "{torch_file}":
log_failure(
"PyTorch does not match expected file:\n"
f" {{torch.__file__}}\n vs.\n {torch_file}"
)
# =============================================================================
# == User specified setup =====================================================
# =============================================================================
# Load serialized globals
{load_globals}
# User setup str
{setup}
for _ in range({warmup_number}):
{indented_stmt}
# =============================================================================
# == Callgrind management =====================================================
# =============================================================================
with open("{stat_log}", "wb") as stat_file:
# If many instances of callgrind are running at once, the output of
# `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
# to deadlock. So instead we use a file.
callgrind_stat = check_result(subprocess.run(
["callgrind_control", "--stat"],
stdout=stat_file,
stderr=subprocess.STDOUT,
))
with open("{stat_log}", "rt") as stat_file:
stat_lines = stat_file.read().splitlines()
if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
log_failure("Process does not appear to be running callgrind.")
gc.collect()
time.sleep(0.01)
# =============================================================================
# == User code block ==========================================================
# =============================================================================
for _ in range({repeats}):
callgrind_bindings._valgrind_toggle()
{blocked_stmt}
callgrind_bindings._valgrind_toggle_and_dump_stats()
gc.collect()
{baseline}
""").strip().format(
indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
blocked_stmt=block_stmt(task_spec.stmt, indent=4),
baseline=(pass_baseline if collect_baseline else ""),
number=number,
repeats=repeats,
load_globals=globals.construct(),
setup=task_spec.setup,
warmup_number=min(number, 10),
num_threads=task_spec.num_threads,
error_log_repr=repr(error_log),
stat_log=stat_log,
parent_interpreter=os.path.realpath(sys.executable),
torch_file=torch.__file__,
bindings_import=(
"import torch._C as callgrind_bindings" if bindings is None
else f"import {bindings.__name__} as callgrind_bindings"),
)
CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None
def wrapper_singleton() -> _ValgrindWrapper:
global CALLGRIND_SINGLETON
if CALLGRIND_SINGLETON is None:
CALLGRIND_SINGLETON = _ValgrindWrapper()
return CALLGRIND_SINGLETON