Skip to content

Operators

Low-level NN operator implementations for GPU & CPU

ApproxConv2dOp

Bases: `torch.autograd.Function`

Autograd wrapper around the Im2Col/ApproxGeMM Conv2d operator.

Source code in src/torchapprox/operators/conv2d.py
class ApproxConv2dOp(torch.autograd.Function):
    """
    Autograd wrapper around Im2Col/ApproxGeMM Conv2d operator

    Forward: quantize ``x`` and ``w``, run the convolution through one of
    three backends (HTP model, depthwise CUDA kernel, or im2col/GeMM kernel),
    requantize, and reshape to
    ``(x.size(0), conv_args.out_channels, out_dims[0], out_dims[1])``.

    Backward: gradients for ``x`` and ``w`` are computed by ``_conv_bwd_ste``
    (presumably a straight-through estimator, judging by the name — TODO
    confirm); all other inputs receive ``None``.
    """

    @staticmethod
    def forward(
        x: torch.FloatTensor,
        w: torch.FloatTensor,
        quant_params: "QuantizationParameters",
        conv_args: Conv2dArgs,
        htp_model: Optional[Callable],
        out_dims: Tuple[int, int],
        lut: torch.ShortTensor,
        traced_inputs: Optional["TracedGeMMInputs"],
    ):
        """
        Compute the approximate, quantized Conv2d output.

        Args:
            x: Input activations. Quantized as
                ``round(x / x_scale + x_zero_point)``.
            w: Weights. Quantized with per-output-channel scale/zero point,
                broadcast via ``[:, None, None, None]``. Assumes ``w`` is 4-D
                with output channels on dim 0 — TODO confirm.
            quant_params: Supplies ``x_scale``, ``x_zero_point``,
                ``w_scale`` and ``w_zero_point``.
            conv_args: Convolution configuration (stride, padding,
                out_channels, fast-depthwise eligibility).
            htp_model: Optional callable. When given and not tracing, it
                replaces the approximate kernels and is called as
                ``htp_model(F.conv2d, x_q, w_q, conv_args.backward_args())``.
            out_dims: Spatial output size used for the final reshape.
            lut: Lookup table forwarded to the approximate kernels.
            traced_inputs: When not ``None``, tracing is active. This forces
                the im2col/GeMM path, and the object is passed to it.

        Returns:
            The requantized result viewed as
            ``(x.size(0), conv_args.out_channels, out_dims[0], out_dims[1])``.
        """
        x_q = torch.round((x / quant_params.x_scale) + quant_params.x_zero_point)
        w_q = torch.round(
            (w / quant_params.w_scale[:, None, None, None])
            + quant_params.w_zero_point[:, None, None, None]
        )

        trace = traced_inputs is not None
        if htp_model is not None and not trace:
            # HTP model
            y_q = htp_model(
                torch.nn.functional.conv2d, x_q, w_q, conv_args.backward_args()
            )
            # torch.round is not in-place: the result must be assigned,
            # otherwise the rounding is silently discarded.
            y_q = torch.round(y_q)
        elif (conv_args.use_fast_dwconv() and x.is_cuda and w.is_cuda) and not trace:
            # Depthwise Conv CUDA Kernel (only when both tensors live on GPU;
            # never while tracing, since tracing needs the im2col path).
            y_q = dwconv2d(x_q, w_q, lut, conv_args.stride, conv_args.padding)
        else:
            # im2col & gemm kernel (supports CPU & GPU)
            y_q = _im2col_conv2d(x_q, w_q, conv_args, lut, out_dims, traced_inputs)

        # All-zero zero points allow the cheaper symmetric requantization;
        # otherwise the affine variant also needs the quantized operands.
        if quant_params.x_zero_point == 0 and torch.all(quant_params.w_zero_point == 0):
            y_q = _symmetric_requantize(y_q, quant_params)
        else:
            y_q = _affine_requantize(
                x_q,
                w_q,
                y_q,
                quant_params,
                conv_args,
                out_dims,
            )

        y_q = y_q.view(
            x_q.size(0),
            conv_args.out_channels,
            out_dims[0],
            out_dims[1],
        )

        return y_q

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        """
        Save what ``backward`` needs: the float inputs ``x``/``w`` and the
        convolution configuration from ``conv_args.backward_args()``.
        """
        x, w, _, conv_args, _, _, _, _ = inputs
        ctx.save_for_backward(x, w)
        ctx.conf = conv_args.backward_args()

    @staticmethod
    def backward(ctx, grad):
        """
        Propagate ``grad`` to ``x`` and ``w`` via ``_conv_bwd_ste``.

        Returns one entry per ``forward`` input (8 in total). Only ``x`` and
        ``w`` receive gradients; the remaining non-differentiable inputs get
        ``None``.
        """
        x, w = ctx.saved_tensors
        conf = ctx.conf
        grad_input, grad_weight = _conv_bwd_ste(
            grad, x, w, conf, ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        )
        # Order matches forward: x, w, quant_params, conv_args, htp_model,
        # out_dims, lut, traced_inputs.
        return grad_input, grad_weight, None, None, None, None, None, None