
Layers

Approximate Layer implementations

ApproxConv2d

Bases: ApproxLayer, Conv2d

Approximate 2D Convolution layer implementation

Source code in src/torchapprox/layers/approx_conv2d.py
class ApproxConv2d(ApproxLayer, QATConv2d):
    """
    Approximate 2D Convolution layer implementation
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        QATConv2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            qconfig,
            device,
            dtype,
        )
        ApproxLayer.__init__(self)
        assert (
            padding_mode == "zeros"
        ), f"Unsupported padding_mode {padding_mode}, only zero-padding is supported"
        self._opcount = None
        self.to(self.weight.device)

    @staticmethod
    def from_super(cls_instance: torch.nn.Conv2d):
        """
        Alias for from_conv2d
        """
        return ApproxConv2d.from_conv2d(cls_instance)

    @staticmethod
    def from_conv2d(conv2d: torch.nn.Conv2d):
        """
        Construct ApproxConv2d from torch.nn.Conv2d layer
        """
        has_bias = conv2d.bias is not None
        approx_instance = ApproxConv2d(
            conv2d.in_channels,
            conv2d.out_channels,
            conv2d.kernel_size,
            stride=conv2d.stride,
            padding=conv2d.padding,
            dilation=conv2d.dilation,
            groups=conv2d.groups,
            bias=has_bias,
            padding_mode=conv2d.padding_mode,
        )

        with torch.no_grad():
            approx_instance.weight = conv2d.weight
            if has_bias:
                approx_instance.bias = conv2d.bias

        return approx_instance

    def output_dims(self, x):
        """
        Output width and height
        """

        def dim(idx):
            # Copied from
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            return math.floor(
                (
                    x.size(idx + 2)
                    + 2 * self.padding[idx]
                    - self.dilation[idx] * (self.kernel_size[idx] - 1)
                    - 1
                )
                / self.stride[idx]
                + 1
            )

        return (dim(0), dim(1))

    @property
    def opcount(self) -> int:
        if self._opcount is None:
            raise ValueError(
                "Conv layer Opcount not populated. Run forward pass first."
            )
        return self._opcount

    @property
    def fan_in(self) -> int:
        """
        Number of incoming connections for a single neuron
        """
        return self.in_channels * math.prod(self.kernel_size)

    @property
    def conv_args(self) -> Conv2dArgs:
        """
        Wrap layer configuration in dataclass for more convenient passing around
        """
        args = Conv2dArgs(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return args

    def quant_fwd(self, x_q, w_q):
        y = torch.nn.functional.conv2d(
            x_q,
            w_q,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return y

    def approx_fwd(self, x_q, w_q, quant_params: QuantizationParameters):
        y = ApproxConv2dOp.apply(
            x_q,
            w_q,
            quant_params,
            self.conv_args,
            self.htp_model,
            self.output_dims(x_q),
            self.lut,
            self.traced_inputs,
        )

        return y

    # pylint: disable=arguments-renamed
    def forward(
        self,
        x_q: torch.Tensor,
        x_scale: Optional[torch.Tensor] = None,
        x_zero_point: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # calculate opcount for this layer using the input tensor size on first forward pass
        if self._opcount is None:
            self._opcount = int(
                math.prod(self.kernel_size)
                * math.prod(self.output_dims(x_q))
                * (self.in_channels / self.groups)
                * self.out_channels
            )

        # Reshape bias tensor to make it broadcastable
        bias = None if self.bias is None else self.bias[:, None, None]
        return ApproxLayer.forward(self, x_q, x_scale, x_zero_point, bias)
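
As a quick illustration of the opcount formula used in forward() above (the layer dimensions here are made up for the example):

import math

kernel_size, output_dims = (3, 3), (32, 32)
in_channels, groups, out_channels = 16, 1, 32

# prod(kernel_size) * prod(output_dims) * (in_channels / groups) * out_channels
opcount = math.prod(kernel_size) * math.prod(output_dims) * (in_channels // groups) * out_channels
print(opcount)  # 4718592 multiplications per forward pass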

conv_args: Conv2dArgs property

Wrap layer configuration in dataclass for more convenient passing around

fan_in: int property

Number of incoming connections for a single neuron

from_conv2d(conv2d) staticmethod

Construct ApproxConv2d from torch.nn.Conv2d layer

Source code in src/torchapprox/layers/approx_conv2d.py
@staticmethod
def from_conv2d(conv2d: torch.nn.Conv2d):
    """
    Construct ApproxConv2d from torch.nn.Conv2d layer
    """
    has_bias = conv2d.bias is not None
    approx_instance = ApproxConv2d(
        conv2d.in_channels,
        conv2d.out_channels,
        conv2d.kernel_size,
        stride=conv2d.stride,
        padding=conv2d.padding,
        dilation=conv2d.dilation,
        groups=conv2d.groups,
        bias=has_bias,
        padding_mode=conv2d.padding_mode,
    )

    with torch.no_grad():
        approx_instance.weight = conv2d.weight
        if has_bias:
            approx_instance.bias = conv2d.bias

    return approx_instance
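
A minimal usage sketch (the import path and layer shape are assumptions for illustration; depending on the underlying QAT base class, a qconfig may also need to be configured before training):

import torch
from torchapprox.layers import ApproxConv2d  # assuming the class is re-exported here

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
approx_conv = ApproxConv2d.from_conv2d(conv)

# Parameters are shared with the original layer (see the torch.no_grad() block above),
# so pretrained weights and biases carry over directly.
assert approx_conv.weight is conv.weight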

from_super(cls_instance) staticmethod

Alias for from_conv2d

Source code in src/torchapprox/layers/approx_conv2d.py
@staticmethod
def from_super(cls_instance: torch.nn.Conv2d):
    """
    Alias for from_conv2d
    """
    return ApproxConv2d.from_conv2d(cls_instance)

output_dims(x)

Output width and height

Source code in src/torchapprox/layers/approx_conv2d.py
def output_dims(self, x):
    """
    Output width and height
    """

    def dim(idx):
        # Copied from
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        return math.floor(
            (
                x.size(idx + 2)
                + 2 * self.padding[idx]
                - self.dilation[idx] * (self.kernel_size[idx] - 1)
                - 1
            )
            / self.stride[idx]
            + 1
        )

    return (dim(0), dim(1))
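
A quick sanity check of the output-size formula above, written as a standalone helper (values chosen purely for illustration):

import math

def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # Same expression as dim() above, for a single spatial dimension
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

assert conv_out(32, kernel=3, padding=1) == 32             # "same"-style padding at stride 1
assert conv_out(32, kernel=3, stride=2, padding=1) == 16   # stride 2 halves the resolution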

ApproxLayer

Bases: ABC

Derivable Abstract Base Class for implementing Approximate Neural Network layers

Source code in src/torchapprox/layers/approx_layer.py
class ApproxLayer(ABC):
    """
    Derivable Abstract Base Class for implementing Approximate Neural Network layers
    """

    def __init__(
        self, qconfig: Optional[tq.QConfig] = None, learnable_noise: bool = False
    ):
        self.inference_mode: InferenceMode = InferenceMode.QUANTIZED

        self._lut: Optional[torch.IntTensor] = None
        self.lut = self.accurate_lut()

        self.htp_model: Optional[Callable] = None
        self.traced_inputs: Optional[TracedGeMMInputs] = None

        self._stdev: torch.Tensor = torch.tensor([0.0])
        self._mean: torch.Tensor = torch.tensor([0.0])

    @staticmethod
    def default_qconfig() -> tq.QConfig:
        act_qconfig = tq.FakeQuantize.with_args(
            observer=tq.HistogramObserver,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            quant_min=0,
            quant_max=127,
        )
        weight_qconfig = tq.FakeQuantize.with_args(
            observer=tq.HistogramObserver,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            quant_min=-128,
            quant_max=127,
        )
        return tq.QConfig(activation=act_qconfig, weight=weight_qconfig)

    @staticmethod
    def accurate_lut() -> npt.NDArray[np.int32]:
        x = np.arange(256)
        x[x >= 128] -= 256
        xx, yy = np.meshgrid(x, x)
        return (xx * yy).astype(np.int32)

    @property
    def lut(self) -> torch.Tensor:
        """
        The Lookup table to use for approximate multiplication. LUT can be:
        - `None`: An accurate product is used internally. This is much faster than passing
            operands through LUT kernels. Functionally equivalent to running the layer in
            `quant` mode, but useful when the unfolded inputs/outputs need to be traced at runtime.
        - `torch.Tensor` or `numpy.array`:
            - A 2D array of size 256x256 is required. Unused entries are ignored when simulating
                multiplication where the operand width is less than 8 bit.
            - When supplying a `torch.Tensor`, the datatype needs to be a signed 32-bit integer (`torch.int`).
        """
        return self._lut

    @lut.setter
    def lut(self, new_lut: Union[np.ndarray, torch.Tensor]):
        assert len(new_lut.shape) == 2, "LUT needs to be 2D square matrix"
        assert (
            new_lut.shape[0] == new_lut.shape[1] == 256
        ), "Only 8x8 Bit LUTs are currently supported."

        if isinstance(new_lut, torch.Tensor):
            assert new_lut.dtype == torch.int, "LUT needs to be signed 32 Bit Integer"
            self._lut = new_lut
        elif isinstance(new_lut, np.ndarray):
            self._lut = torch.from_numpy(new_lut).contiguous().int()
        else:
            raise ValueError(
                f"Unknown LUT input type: {type(new_lut)}, supported types: torch.Tensor, np.ndarray"
            )

    @property
    def stdev(self) -> float:
        """
        Perturbation Error Relative Standard Deviation

        Returns:
            Currently configured perturbation standard deviation
        """
        return self._stdev.item()

    @stdev.setter
    def stdev(self, val: float):
        self._stdev = torch.tensor([val], device=self.weight.device)  # type: ignore

    @property
    def mean(self) -> float:
        """
        Perturbation Error mean

        Returns:
            Currently configured perturbation mean
        """
        return self._mean.item()

    @mean.setter
    def mean(self, val: float):
        self._mean = torch.tensor([val], device=self.weight.device)  # type: ignore

    @property
    @abstractmethod
    def fan_in(self) -> int:
        """
        Number of incoming connections for a neuron in this layer
        """

    @property
    @abstractmethod
    def opcount(self) -> int:
        """
        Number of multiplications for a single
        forward pass of this layer
        """

    @abstractmethod
    def quant_fwd(
        self, x: torch.FloatTensor, w: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Quantized Forward Pass
        Performs the layer operation with an additional pass through the
        currently configured quantizer.

        `x_q` and `w_q` are expected to be **fake-quantized** tensors, i.e. floats that are
        discretized to a set of values, but not converted to their actual integer
        representation.

        Args:
            x_q: Fake-quantized activations
            w_q: Fake-quantized weights

        Returns:
            Layer output
        """

    @abstractmethod
    def approx_fwd(
        self,
        x: torch.CharTensor,
        w: torch.CharTensor,
        quant_params: QuantizationParameters,
    ):
        """Approximate Product Forward Pass
        Performs the layer operation using the currently configured
        approximate product Lookup Table.

        Args:
            x: Layer input

        Returns:
            Layer output
        """

    @no_type_check
    def noise_fwd(
        self, x_q: torch.FloatTensor, w_q: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Quantized Forward Pass that is perturbed
        with Gaussian Noise

        The standard deviation of the additive noise
        is derived from the `stdev` parameter and scaled
        with the standard deviation of the current batch

        Args:
            x: Layer input

        Returns:
            Layer output
        """
        y = self.quant_fwd(x_q, w_q)
        if self.training:
            noise = torch.randn_like(y) * torch.std(y) * self.stdev + self.mean
            y = y + noise
        return y

    @no_type_check
    def forward(
        self,
        x: torch.Tensor,
        x_scale: Optional[torch.Tensor] = None,
        x_zero_point: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with currently selected mode applied

        Args:
            x: Layer input

        Returns:
            Layer output
        """
        assert hasattr(
            self, "weight_fake_quant"
        ), "QAT nodes not replaced. Run `prepare_qat` first."

        w = self.weight_fake_quant(self.weight)
        if self.inference_mode == InferenceMode.NOISE:
            y = self.noise_fwd(x, w)
        elif self.inference_mode == InferenceMode.APPROXIMATE:
            assert (x_scale is not None) and (
                x_zero_point is not None
            ), "Received no activation quantization information during approximate forward pass"
            assert (
                len(x_scale) == 1 and len(x_zero_point) == 1
            ), "Per-channel quantization only supported for weights"
            quant_params = QuantizationParameters(
                x_scale,
                x_zero_point,
                self.weight_fake_quant.scale,
                self.weight_fake_quant.zero_point,
            )
            y = self.approx_fwd(x, w, quant_params)
        else:
            y = self.quant_fwd(x, w)

        if bias is not None:
            y = y + bias
        elif self.bias is not None:
            y = y + self.bias

        return y
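
The accurate_lut() helper can be checked in isolation; row and column indices are the unsigned 8-bit encodings of the signed operands (import path assumed for illustration):

from torchapprox.layers import ApproxLayer  # assuming the class is re-exported here

lut = ApproxLayer.accurate_lut()   # 256x256 numpy table of exact signed 8-bit products
assert lut.shape == (256, 256)
assert lut[2, 3] == 2 * 3          # small positive operands map directly to their indices
assert lut[255, 2] == -1 * 2       # index 255 encodes the signed value -1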

fan_in: int abstractmethod property

Number of incoming connections for a neuron in this layer

lut: torch.Tensor property writable

The Lookup table to use for approximate multiplication. LUT can be:

- None: An accurate product is used internally. This is much faster than passing operands through LUT kernels. Functionally equivalent to running the layer in quant mode, but useful when the unfolded inputs/outputs need to be traced at runtime.
- torch.Tensor or numpy.array:
    - A 2D array of size 256x256 is required. Unused entries are ignored when simulating multiplication where the operand width is less than 8 bit.
    - When supplying a torch.Tensor, the datatype needs to be a signed 32-bit integer (torch.int).
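
As a sketch, a custom LUT for a hypothetical approximate multiplier (here a toy one that simply clears the least-significant bit of one operand) could be built and assigned like this; layer stands for any ApproxLayer instance:

import numpy as np

x = np.arange(256)
x[x >= 128] -= 256                       # reinterpret indices as signed 8-bit operand values
xx, yy = np.meshgrid(x, x)
lut = (xx * (yy & ~1)).astype(np.int32)  # toy approximation: drop bit 0 of one operand

layer.lut = lut  # numpy arrays are converted to int32 torch tensors by the setter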

mean: float property writable

Perturbation Error mean

Returns:

    float: Currently configured perturbation mean

opcount: int abstractmethod property

Number of multiplications for a single forward pass of this layer

stdev: float property writable

Perturbation Error Relative Standard Deviation

Returns:

    float: Currently configured perturbation standard deviation

approx_fwd(x, w, quant_params) abstractmethod

Approximate Product Forward Pass. Performs the layer operation using the currently configured approximate product Lookup Table.

Parameters:

    x (CharTensor): Layer input. Required.

Returns:

    Layer output

Source code in src/torchapprox/layers/approx_layer.py
@abstractmethod
def approx_fwd(
    self,
    x: torch.CharTensor,
    w: torch.CharTensor,
    quant_params: QuantizationParameters,
):
    """Approximate Product Forward Pass
    Performs the layer operation using the currently configured
    approximate product Lookup Table.

    Args:
        x: Layer input

    Returns:
        Layer output
    """

forward(x, x_scale=None, x_zero_point=None, bias=None)

Forward pass with currently selected mode applied

Parameters:

    x (Tensor): Layer input. Required.

Returns:

    Tensor: Layer output

Source code in src/torchapprox/layers/approx_layer.py
@no_type_check
def forward(
    self,
    x: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    x_zero_point: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass with currently selected mode applied

    Args:
        x: Layer input

    Returns:
        Layer output
    """
    assert hasattr(
        self, "weight_fake_quant"
    ), "QAT nodes not replaced. Run `prepare_qat` first."

    w = self.weight_fake_quant(self.weight)
    if self.inference_mode == InferenceMode.NOISE:
        y = self.noise_fwd(x, w)
    elif self.inference_mode == InferenceMode.APPROXIMATE:
        assert (x_scale is not None) and (
            x_zero_point is not None
        ), "Received no activation quantization information during approximate forward pass"
        assert (
            len(x_scale) == 1 and len(x_zero_point) == 1
        ), "Per-channel quantization only supported for weights"
        quant_params = QuantizationParameters(
            x_scale,
            x_zero_point,
            self.weight_fake_quant.scale,
            self.weight_fake_quant.zero_point,
        )
        y = self.approx_fwd(x, w, quant_params)
    else:
        y = self.quant_fwd(x, w)

    if bias is not None:
        y = y + bias
    elif self.bias is not None:
        y = y + self.bias

    return y

noise_fwd(x_q, w_q)

Quantized Forward Pass that is perturbed with Gaussian Noise

The standard deviation of the additive noise is derived from the stdev parameter and scaled with the standard deviation of the current batch.

Parameters:

    x: Layer input. Required.

Returns:

    FloatTensor: Layer output

Source code in src/torchapprox/layers/approx_layer.py
@no_type_check
def noise_fwd(
    self, x_q: torch.FloatTensor, w_q: torch.FloatTensor
) -> torch.FloatTensor:
    """Quantized Forward Pass that is perturbed
    with Gaussian Noise

    The standard deviation of the additive noise
    is derived from the `stdev` parameter and scaled
    with the standard deviation of the current batch

    Args:
        x: Layer input

    Returns:
        Layer output
    """
    y = self.quant_fwd(x_q, w_q)
    if self.training:
        noise = torch.randn_like(y) * torch.std(y) * self.stdev + self.mean
        y = y + noise
    return y
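
A sketch of how a single layer might be configured for noise-mode training (layer is a placeholder for any prepared ApproxLayer instance):

layer.inference_mode = InferenceMode.NOISE
layer.stdev = 0.1    # additive noise std = 0.1 * std(y) of the current output batch
layer.mean = 0.0
layer.train()        # noise is only injected while the layer is in training mode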

quant_fwd(x, w) abstractmethod

Quantized Forward Pass. Performs the layer operation with an additional pass through the currently configured quantizer.

x_q and w_q are expected to be fake-quantized tensors, i.e. floats that are discretized to a set of values, but not converted to their actual integer representation.

Parameters:

    x_q: Fake-quantized activations. Required.
    w_q: Fake-quantized weights. Required.

Returns:

    FloatTensor: Layer output

Source code in src/torchapprox/layers/approx_layer.py
@abstractmethod
def quant_fwd(
    self, x: torch.FloatTensor, w: torch.FloatTensor
) -> torch.FloatTensor:
    """Quantized Forward Pass
    Performs the layer operation with an additional pass through the
    currently configured quantizer.

    `x_q` and `w_q` are expected to be **fake-quantized** tensors, i.e. floats that are
    discretized to a set of values, but not converted to their actual integer
    representation.

    Args:
        x_q: Fake-quantized activations
        w_q: Fake-quantized weights

    Returns:
        Layer output
    """

ApproxLinear

Bases: ApproxLayer, Linear

Approximate Linear Layer implementation

Source code in src/torchapprox/layers/approx_linear.py
class ApproxLinear(ApproxLayer, QATLinear):
    """
    Approximate Linear Layer implementation
    """

    _FLOAT_MODULE = nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        qconfig=None,
        device=None,
        dtype=None,
    ):
        QATLinear.__init__(
            self, in_features, out_features, bias, qconfig, device, dtype
        )
        ApproxLayer.__init__(self)
        self._opcount = torch.tensor(self.in_features * self.out_features).float()
        self.to(self.weight.device)

    @property
    def fan_in(self) -> int:
        return int(self.in_features)

    @property
    def opcount(self) -> int:
        return int(self._opcount)

    def quant_fwd(self, x, w):
        return torch.nn.functional.linear(x, w)

    def approx_fwd(self, x, w, quant_params: QuantizationParameters):
        return ApproxGeMM.apply(
            x,
            w,
            self.lut,
            quant_params,
            self.htp_model,
            self.traced_inputs,
        )
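
A minimal construction sketch (import path assumed for illustration; the qconfig is supplied explicitly via default_qconfig()):

from torchapprox.layers import ApproxLayer, ApproxLinear  # assuming both are re-exported here

lin = ApproxLinear(128, 64, bias=True, qconfig=ApproxLayer.default_qconfig())
assert lin.fan_in == 128          # one incoming connection per input feature
assert lin.opcount == 128 * 64    # one multiplication per weight and forward pass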

ApproxWrapper

Bases: Module

Wrapper for adding quant/dequant stubs to a linear layer in a model.

PyTorch provides the option to wrap modules in quantizers automatically; however, a custom module is necessary so that we can forward the activation quantization scale and zero point to the approximate layer in the forward function.

The wrapped instance of torch.nn.Module is meant to be replaced with an instance of torchapprox.layers.ApproxLayer in a separate call to torch.ao.quantization.prepare() after it has been wrapped here.

Source code in src/torchapprox/layers/approx_wrapper.py
class ApproxWrapper(torch.nn.Module):
    """
    Wrapper for adding quant/dequant stubs to a linear layer in a model.

    PyTorch provides the option to wrap modules in quantizers automatically;
    however, a custom module is necessary so that we can forward the activation
    quantization scale and zero point to the approximate layer in the forward function.

    The wrapped instance of `torch.nn.Module` is meant to be replaced with an instance of
    `torchapprox.layers.ApproxLayer` in a separate call to
    `torch.ao.quantization.prepare()` after it has been wrapped here.
    """

    def __init__(
        self,
        wrapped: Union[torch.nn.Linear, torch.nn.Conv2d],
        qconfig: Optional[tq.QConfig] = None,
    ):
        """
        Wrap a torch.nn.Linear or torch.nn.Conv2d layer with quantization stubs

        Args:
            wrapped: the layer to be wrapped
            qconfig: Quantization configuration. Defaults to None.
        """
        torch.nn.Module.__init__(self)
        self.quant_stub = tq.QuantStub()
        self.dequant_stub = tq.DeQuantStub()

        assert isinstance(wrapped, torch.nn.Linear) or isinstance(
            wrapped, torch.nn.Conv2d
        ), f"Received unknown layer type for wrapping: {type(wrapped)}"
        self.wrapped = wrapped

        if not qconfig:
            qconfig = ApproxLayer.default_qconfig()

        self.qconfig = qconfig

    @staticmethod
    def from_float(wrapped):
        return ApproxWrapper(wrapped)

    def forward(self, x):
        x_q = self.quant_stub(x)
        x_scale = getattr(self.quant_stub.activation_post_process, "scale", None)
        x_zero_point = getattr(
            self.quant_stub.activation_post_process, "zero_point", None
        )
        y_q = self.wrapped(x_q, x_scale, x_zero_point)
        y = self.dequant_stub(y_q)
        return y

__init__(wrapped, qconfig=None)

Wrap a torch.nn.Linear or torch.nn.Conv2d layer with quantization stubs

Parameters:

    wrapped (Union[Linear, Conv2d]): the layer to be wrapped. Required.
    qconfig (Optional[QConfig]): Quantization configuration. Defaults to None.
Source code in src/torchapprox/layers/approx_wrapper.py
def __init__(
    self,
    wrapped: Union[torch.nn.Linear, torch.nn.Conv2d],
    qconfig: Optional[tq.QConfig] = None,
):
    """
    Wrap a torch.nn.Linear or torch.nn.Conv2d layer with quantization stubs

    Args:
        wrapped: the layer to be wrapped
        qconfig: Quantization configuration. Defaults to None.
    """
    torch.nn.Module.__init__(self)
    self.quant_stub = tq.QuantStub()
    self.dequant_stub = tq.DeQuantStub()

    assert isinstance(wrapped, torch.nn.Linear) or isinstance(
        wrapped, torch.nn.Conv2d
    ), f"Received unknown layer type for wrapping: {type(wrapped)}"
    self.wrapped = wrapped

    if not qconfig:
        qconfig = ApproxLayer.default_qconfig()

    self.qconfig = qconfig
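
A sketch of the intended wrapping step for a small model (the model itself is a placeholder; the subsequent replacement of the wrapped layers happens in the torch.ao.quantization.prepare() call described in the class docstring above):

import torch
from torchapprox.layers import ApproxWrapper  # assuming the class is re-exported here

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
)

# Wrap every Conv2d/Linear so that activation scale and zero point reach the approximate layer.
for idx, child in enumerate(model):
    if isinstance(child, (torch.nn.Conv2d, torch.nn.Linear)):
        model[idx] = ApproxWrapper(child)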

InferenceMode

Bases: Enum

Layer inference mode. Can be any of:

- quant: Run inference using the layer's quantizer
- approx: Run inference using approximate product LUT
- noise: Run inference that is perturbed with additive Gaussian noise

Source code in src/torchapprox/layers/approx_layer.py
class InferenceMode(enum.Enum):
    """
    Layer inference mode. Can be any of:
    - `quant`: Run inference using the layer's quantizer
    - `approx`: Run inference using approximate product LUT
    - `noise`: Run inference that is perturbed with additive Gaussian noise
    """

    QUANTIZED = "Quantized Mode"
    NOISE = "Noise Mode"
    APPROXIMATE = "Approximate Mode"
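
Once a model has been wrapped and prepared, the mode of every approximate layer can be switched in one pass (model and my_lut are placeholders; my_lut is a 256x256 signed 32-bit product table as described under lut above):

from torchapprox.layers import ApproxLayer, InferenceMode  # assuming both are re-exported here

for module in model.modules():
    if isinstance(module, ApproxLayer):
        module.inference_mode = InferenceMode.APPROXIMATE
        module.lut = my_lut  # hypothetical LUT, see the lut property documentation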