class ApproxConv2dOp(torch.autograd.Function):
"""
Autograd wrapper around Im2Col/ApproxGeMM Conv2d operator
"""

    @staticmethod
    def forward(
        x: torch.FloatTensor,
        w: torch.FloatTensor,
        quant_params: "QuantizationParameters",
        conv_args: Conv2dArgs,
        htp_model: Optional[Callable],
        out_dims: Tuple[int, int],
        lut: torch.ShortTensor,
        traced_inputs: Optional["TracedGeMMInputs"],
    ):
        # Quantize activations and weights to the integer domain
        x_q = torch.round((x / quant_params.x_scale) + quant_params.x_zero_point)
        w_q = torch.round(
            (w / quant_params.w_scale[:, None, None, None])
            + quant_params.w_zero_point[:, None, None, None]
        )

        trace = traced_inputs is not None
        if htp_model is not None and not trace:
            # HTP model
            y_q = htp_model(
                torch.nn.functional.conv2d, x_q, w_q, conv_args.backward_args()
            )
            y_q = torch.round(y_q)
        elif (conv_args.use_fast_dwconv() and x.is_cuda and w.is_cuda) and not trace:
            # Depthwise Conv CUDA kernel
            y_q = dwconv2d(x_q, w_q, lut, conv_args.stride, conv_args.padding)
        else:
            # im2col & GeMM kernel (supports CPU & GPU)
            y_q = _im2col_conv2d(x_q, w_q, conv_args, lut, out_dims, traced_inputs)

        # Requantize the accumulator: fast path when both zero points are zero
        # (symmetric quantization), otherwise fall back to the affine scheme
        if quant_params.x_zero_point == 0 and torch.all(quant_params.w_zero_point == 0):
            y_q = _symmetric_requantize(y_q, quant_params)
        else:
            y_q = _affine_requantize(
                x_q,
                w_q,
                y_q,
                quant_params,
                conv_args,
                out_dims,
            )

        # Restore the (N, C_out, H_out, W_out) output layout
        y_q = y_q.view(
            x_q.size(0),
            conv_args.out_channels,
            out_dims[0],
            out_dims[1],
        )
        return y_q

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        x, w, _, conv_args, _, _, _, _ = inputs
        ctx.save_for_backward(x, w)
        ctx.conf = conv_args.backward_args()

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        conf = ctx.conf
        # Straight-through estimator: gradients are computed w.r.t. the
        # unquantized activations and weights
        grad_input, grad_weight = _conv_bwd_ste(
            grad, x, w, conf, ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        )
        # One gradient per forward() argument; only x and w receive gradients
        return grad_input, grad_weight, None, None, None, None, None, None
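
# Illustrative usage sketch (not part of the original module): it shows how the op
# would be invoked through `torch.autograd.Function.apply`, matching the forward()
# signature above. The `quant_params`, `conv_args` and `lut` objects are assumed to
# be constructed elsewhere (their constructors are not shown in this file), so the
# call is kept as a commented example rather than executable code.
#
#     x = torch.randn(1, 16, 32, 32)   # activations (N, C_in, H, W)
#     w = torch.randn(32, 16, 3, 3)    # weights (C_out, C_in, kH, kW)
#     out_dims = (30, 30)              # output H/W for a 3x3 kernel, stride 1, no padding
#     y = ApproxConv2dOp.apply(
#         x, w, quant_params, conv_args,
#         None,        # htp_model: skip the HTP error model
#         out_dims,
#         lut,         # approximate-multiplier lookup table (torch.ShortTensor)
#         None,        # traced_inputs: GeMM input tracing disabled
#     )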