# Source code for torch.utils.benchmark.utils.common
"""Base shared classes and utilities."""
import collections
import contextlib
import dataclasses
import math
import os
import shutil
import tempfile
import textwrap
import time
import uuid
from typing import cast, Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
__all__ = [
    "TaskSpec",
    "Measurement",
    "select_unit",
    "unit_to_english",
    "trim_sigfig",
    "ordered_unique",
    "set_torch_threads",
]

# Upper cap on the precision reported by `Measurement.significant_figures`.
_MAX_SIGNIFICANT_FIGURES = 4

# Lower bound on the estimated confidence interval used when estimating
# significant figures: 25 ns.
_MIN_CONFIDENCE_INTERVAL = 25e-9

# Measurement will include a warning if the distribution is suspect. All runs
# are expected to have some variation; these thresholds are compared against
# the interquartile range divided by the median.
_IQR_WARN_THRESHOLD = 0.1  # "This could indicate system fluctuation."
_IQR_GROSS_WARN_THRESHOLD = 0.25  # "This suggests significant environmental influence."
@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True)
class TaskSpec:
    """Everything that defines a Timer task, apart from its globals.

    Instances are frozen (and therefore hashable), so they can be used as
    dictionary keys, e.g. when grouping replicate measurements.
    """

    stmt: str
    setup: str
    global_setup: str = ""
    label: Optional[str] = None
    sub_label: Optional[str] = None
    description: Optional[str] = None
    env: Optional[str] = None
    num_threads: int = 1

    @property
    def title(self) -> str:
        """Best-effort human readable label for the task.

        Preference order: ``label``, then a single-line ``stmt``, and finally
        an indented rendering of a multi-line ``stmt``. ``sub_label``, when
        set, is attached in every case.
        """
        inline_suffix = f": {self.sub_label}" if self.sub_label else ""
        if self.label is not None:
            return self.label + inline_suffix
        if "\n" not in self.stmt:
            return self.stmt + inline_suffix

        # Multi-line statements get a header line followed by the indented body.
        header_suffix = f" ({self.sub_label})" if self.sub_label else ""
        return f"stmt:{header_suffix}\n" + textwrap.indent(self.stmt, ' ')

    def setup_str(self) -> str:
        """Render ``setup`` for display; empty when it is blank or ``"pass"``."""
        setup = self.setup
        if not setup or setup == "pass":
            return ""
        if "\n" in setup:
            return "setup:\n" + textwrap.indent(setup, ' ')
        return f"setup: {setup}"

    def summarize(self) -> str:
        """Assemble the TaskSpec portion of another container's repr.

        Empty sections are dropped, and each multi-line section gets an extra
        trailing newline so it is visually separated from what follows.
        """
        candidates = (self.title, self.description or "", self.setup_str())
        rendered = []
        for section in candidates:
            if not section:
                continue
            rendered.append(section + "\n" if "\n" in section else section)
        return "\n".join(rendered)


# Names of all TaskSpec fields (used for attribute forwarding elsewhere).
_TASKSPEC_FIELDS = tuple(f.name for f in dataclasses.fields(TaskSpec))
@dataclasses.dataclass(init=True, repr=False)
class Measurement:
    """The result of a Timer measurement.

    This class stores one or more measurements of a given statement. It is
    serializable and provides several convenience methods
    (including a detailed __repr__) for downstream consumers.
    """

    number_per_run: int
    raw_times: List[float]
    task_spec: TaskSpec
    metadata: Optional[Dict[Any, Any]] = None  # Reserved for user payloads.

    def __post_init__(self) -> None:
        # Placeholders until `_lazy_init` computes the statistics on first use;
        # an empty `_sorted_times` is what marks them as not yet computed.
        self._sorted_times: Tuple[float, ...] = ()
        self._warnings: Tuple[str, ...] = ()
        self._median: float = -1.0
        self._mean: float = -1.0
        self._p25: float = -1.0
        self._p75: float = -1.0

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails. Forward TaskSpec
        # fields (stmt, label, num_threads, ...) for convenience.
        if name in _TASKSPEC_FIELDS:
            return getattr(self.task_spec, name)
        return super().__getattribute__(name)

    # =========================================================================
    # == Convenience methods for statistics ===================================
    # =========================================================================
    #
    # These methods use raw time divided by number_per_run; this is an
    # extrapolation and hides the fact that different number_per_run will
    # result in different amortization of overheads, however if Timer has
    # selected an appropriate number_per_run then this is a non-issue, and
    # forcing users to handle that division would result in a poor experience.
    @property
    def times(self) -> List[float]:
        """Per-run times: each raw time divided by `number_per_run`."""
        return [t / self.number_per_run for t in self.raw_times]

    @property
    def median(self) -> float:
        self._lazy_init()
        return self._median

    @property
    def mean(self) -> float:
        self._lazy_init()
        return self._mean

    @property
    def iqr(self) -> float:
        """Interquartile range (75th minus 25th percentile) of `times`."""
        self._lazy_init()
        return self._p75 - self._p25

    @property
    def significant_figures(self) -> int:
        """Approximate significant figure estimate.

        This property is intended to give a convenient way to estimate the
        precision of a measurement. It only uses the interquartile region to
        estimate statistics to try to mitigate skew from the tails, and
        uses a static z value of 1.645 since it is not expected to be used
        for small values of `n`, so z can approximate `t`.

        The significant figure estimation used in conjunction with the
        `trim_sigfig` method to provide a more human interpretable data
        summary. __repr__ does not use this method; it simply displays raw
        values. Significant figure estimation is intended for `Compare`.
        """
        self._lazy_init()
        n_total = len(self._sorted_times)
        lower_bound = int(n_total // 4)
        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
        interquartile_points: Tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
        std = torch.tensor(interquartile_points).std(unbiased=False).item()
        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()

        # Rough estimates. These are by no means statistically rigorous.
        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
        num_significant_figures = int(torch.tensor(relative_ci).floor())
        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)

    @property
    def has_warnings(self) -> bool:
        self._lazy_init()
        return bool(self._warnings)

    def _lazy_init(self) -> None:
        """Compute sorted times, summary statistics and warnings once."""
        if self.raw_times and not self._sorted_times:
            self._sorted_times = tuple(sorted(self.times))
            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
            self._median = _sorted_times.quantile(.5).item()
            self._mean = _sorted_times.mean().item()
            self._p25 = _sorted_times.quantile(.25).item()
            self._p75 = _sorted_times.quantile(.75).item()

            def add_warning(msg: str) -> None:
                rel_iqr = self.iqr / self.median * 100
                self._warnings += (
                    f" WARNING: Interquartile range is {rel_iqr:.1f}% "
                    f"of the median measurement.\n {msg}",
                )

            # Only the most severe applicable warning is recorded.
            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
                add_warning("This suggests significant environmental influence.")
            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
                add_warning("This could indicate system fluctuation.")

    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
        """Return True when IQR / median is strictly below `threshold`."""
        return self.iqr / self.median < threshold

    @property
    def title(self) -> str:
        return self.task_spec.title

    @property
    def env(self) -> str:
        """The TaskSpec's ``env``, or ``"Unspecified env"`` when it is None."""
        return (
            "Unspecified env" if self.task_spec.env is None
            else cast(str, self.task_spec.env)
        )

    @property
    def as_row_name(self) -> str:
        return self.sub_label or self.stmt or "[Unknown]"

    def __repr__(self) -> str:
        """
        Example repr:
            <utils.common.Measurement object at 0x7f395b6ac110>
            Broadcasting add (4x8)
            Median: 5.73 us
            IQR: 2.25 us (4.01 to 6.26)
            372 measurements, 100 runs per measurement, 1 thread
            WARNING: Interquartile range is 39.4% of the median measurement.
            This suggests significant environmental influence.
        """
        self._lazy_init()
        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
        n = len(self._sorted_times)
        time_unit, time_scale = select_unit(self._median)
        # Lines tagged with `skip_line` are removed below; the IQR line is
        # only shown once there are at least four measurements.
        iqr_filter = '' if n >= 4 else skip_line

        repr_str = f"""
{super().__repr__()}
{self.task_spec.summarize()}
{'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
{iqr_filter}IQR: {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
{n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs{' per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
{newline.join(self._warnings)}""".strip()  # noqa: B950

        return "\n".join(line for line in repr_str.splitlines(keepends=False) if skip_line not in line)

    @staticmethod
    def merge(measurements: Iterable["Measurement"]) -> List["Measurement"]:
        """Convenience method for merging replicates.

        Merge will extrapolate times to `number_per_run=1` and will not
        transfer any metadata. (Since it might differ between replicates)
        """
        grouped_measurements: DefaultDict[TaskSpec, List[Measurement]] = collections.defaultdict(list)
        for m in measurements:
            grouped_measurements[m.task_spec].append(m)

        def merge_group(task_spec: TaskSpec, group: List["Measurement"]) -> "Measurement":
            times: List[float] = []
            for m in group:
                # Different measurements could have different `number_per_run`,
                # so we call `.times` which normalizes the results.
                times.extend(m.times)

            return Measurement(
                number_per_run=1,
                raw_times=times,
                task_spec=task_spec,
                metadata=None,
            )

        return [merge_group(t, g) for t, g in grouped_measurements.items()]
def select_unit(t: float) -> Tuple[str, float]:
    """Determine how to scale times for O(1) magnitude.

    This utility is used to format numbers for human consumption.

    Args:
        t: A duration in seconds.

    Returns:
        ``(unit, scale)`` where ``unit`` is one of ``"ns"``, ``"us"``, ``"ms"``
        or ``"s"`` and ``t / scale`` expresses ``t`` in that unit. Durations
        shorter than one nanosecond (including zero) use ``"ns"``; durations
        of one second or longer use ``"s"``.
    """
    magnitude = abs(t)
    if magnitude == 0:
        # log10(0) is -inf; zero reads best in the finest unit.
        exponent = -3
    elif math.isinf(magnitude):
        exponent = 0
    else:
        # Each unit spans three decades: floor(log10(t) / 3) is -3 for ns,
        # -2 for us, -1 for ms and >= 0 for s. Clamp from below so that
        # sub-nanosecond times are not misreported as seconds.
        exponent = max(-3, int(math.log10(magnitude) // 3))
    time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(exponent, "s")
    time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit]
    return time_unit, time_scale
def unit_to_english(u: str) -> str:
    """Spell out a time-unit abbreviation; raises KeyError for unknown units."""
    english_names = dict(
        ns="nanosecond",
        us="microsecond",
        ms="millisecond",
        s="second",
    )
    return english_names[u]
def trim_sigfig(x: float, n: int) -> float:
    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.1)

    Rounding is done in double precision with ``round``, so the result is the
    float nearest to the trimmed decimal value. Zero and non-finite inputs are
    returned unchanged because they have no leading digit to count from.
    """
    assert n == int(n)
    if x == 0 or not math.isfinite(x):
        return float(x)
    # Decimal exponent such that 10**(magnitude - 1) < |x| <= 10**magnitude.
    magnitude = math.ceil(math.log10(abs(x)))
    # `round` accepts a (possibly negative) digit count, e.g. for
    # (123.4, n=2): magnitude=3 -> round(123.4, -1) -> 120.0.
    return float(round(x, int(n) - magnitude))
def ordered_unique(elements: Iterable[Any]) -> List[Any]:
    """Return the distinct items of `elements` in order of first appearance.

    Items must be hashable.
    """
    # Plain dicts preserve insertion order (guaranteed since Python 3.7, which
    # `dataclasses` already requires), so no OrderedDict wrapper is needed.
    return list(dict.fromkeys(elements))
@contextlib.contextmanager
def set_torch_threads(n: int) -> Iterator[None]:
    """Temporarily set torch's thread count (``torch.set_num_threads``) to `n`.

    The count that was active on entry is reinstated on exit, whether the
    managed body finishes normally or raises.
    """
    with contextlib.ExitStack() as restore:
        # Register the restoration first so it still runs if setting `n` fails.
        restore.callback(torch.set_num_threads, torch.get_num_threads())
        torch.set_num_threads(n)
        yield
def _gc_orphaned_dev_shm_dirs(root: str) -> None:
    """Remove subdirectories of `root` whose recorded owner process is gone."""
    for entry in os.listdir(root):
        dir_path = os.path.join(root, entry)
        owner_file = os.path.join(dir_path, "owner.pid")
        if not os.path.exists(owner_file):
            continue

        with open(owner_file) as f:
            contents = f.read()
        try:
            owner_pid = int(contents)
        except ValueError:
            # The owner creates this file before writing its pid, so an empty
            # or partial file is not proof of an orphan; leave it alone.
            continue

        if owner_pid == os.getpid():
            continue

        try:
            # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
            os.kill(owner_pid, 0)
        except ProcessLookupError:
            # ESRCH: no such process, so the directory really is orphaned.
            print(f"Detected that {dir_path} was orphaned in shared memory. Cleaning up.")
            shutil.rmtree(dir_path)
        except OSError:
            # Notably PermissionError (EPERM): the process exists but we may
            # not signal it, so its directory must be kept.
            continue


def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
    """Create a temporary directory. The caller is responsible for cleanup.

    This function is conceptually similar to `tempfile.mkdtemp`, but with
    the key additional feature that it will use shared memory if the
    `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an
    implementation detail, but an important one for cases where many Callgrind
    measurements are collected at once. (Such as when collecting
    microbenchmarks.)

    This is an internal utility, and is exported solely so that microbenchmarks
    can reuse the util.

    Args:
        prefix: Leading component of the directory name; defaults to
            `tempfile.gettempprefix()`.
        gc_dev_shm: Only meaningful in shared-memory mode: first delete
            directories under the shared root whose owning process has exited.

    Returns:
        The path of the newly created directory.
    """
    use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true")
    if use_dev_shm:
        root = "/dev/shm/pytorch_benchmark_utils"
        assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}"
        assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)."
        os.makedirs(root, exist_ok=True)

        # Because we're working in shared memory, it is more important than
        # usual to clean up ALL intermediate files. However we don't want every
        # worker to walk over all outstanding directories, so instead we only
        # check when we are sure that it won't lead to contention.
        if gc_dev_shm:
            _gc_orphaned_dev_shm_dirs(root)
    else:
        root = tempfile.gettempdir()

    # We include the time so names sort by creation time, and add a UUID
    # to ensure we don't collide.
    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
    path = os.path.join(root, name)
    os.makedirs(path, exist_ok=False)

    if use_dev_shm:
        # Record ownership so a later `gc_dev_shm` pass can tell whether this
        # directory has been orphaned.
        with open(os.path.join(path, "owner.pid"), "w") as f:
            f.write(str(os.getpid()))

    return path