Skip to content
Prev Previous commit
Next Next commit
Perf imprv - Map conv2D to depthwiseConv3D
  • Loading branch information
DenisVieriu97 committed Nov 30, 2023
commit 9853d9325baabffb51d790fa5cb802a3c148acd0
58 changes: 33 additions & 25 deletions backends/apple/mps/operations/ConvolutionOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -41,35 +41,43 @@
}

if(is_depthwise){
MPSGraphDepthwiseConvolution2DOpDescriptor* desc = [MPSGraphDepthwiseConvolution2DOpDescriptor
descriptorWithStrideInX:stride[0]
strideInY:stride[1]
dilationRateInX:dilation[0]
dilationRateInY:dilation[1]
paddingLeft:padding[1]
paddingRight:padding[1]
paddingTop:padding[0]
paddingBottom:padding[0]
paddingStyle:MPSGraphPaddingStyleExplicit
dataLayout:MPSGraphTensorNamedDataLayoutNCHW
weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
depthWiseConv3dDescriptor.strides =
@[ @1, [[NSNumber alloc] initWithInteger:stride[0]], [[NSNumber alloc] initWithInteger:stride[1]] ];
depthWiseConv3dDescriptor.dilationRates =
@[ @1, [[NSNumber alloc] initWithInteger:dilation[0]], [[NSNumber alloc] initWithInteger:dilation[1]] ];

MPSGraphTensor* depthwiseConv2DTensor = [mpsGraph depthwiseConvolution2DWithSourceTensor:primaryTensor
weightsTensor:secondaryTensor
descriptor:desc
name:@"depthwiseConv2D"];
depthWiseConv3dDescriptor.paddingStyle = MPSGraphPaddingStyleExplicit;
depthWiseConv3dDescriptor.paddingValues = @[
@0,
@0,
[[NSNumber alloc] initWithInteger:padding[0]],
[[NSNumber alloc] initWithInteger:padding[0]],
[[NSNumber alloc] initWithInteger:padding[1]],
[[NSNumber alloc] initWithInteger:padding[1]]
];
depthWiseConv3dDescriptor.channelDimensionIndex = -3LL;
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:secondaryTensor
dimension:-3
withDimension:-4
name:nil];
MPSGraphTensor* depthwiseConvTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:primaryTensor
weightsTensor:weightTransposeTensor
descriptor:depthWiseConv3dDescriptor
name:nil];
//Can be a nullptr
if(biasTensor){
//Need to add correct dimension to bias to avoid broadcasting issues
biasTensor = [mpsGraph expandDimsOfTensor:biasTensor
axes:@[@0, @2, @3]
name:nil];
depthwiseConv2DTensor = [mpsGraph additionWithPrimaryTensor:depthwiseConv2DTensor
depthwiseConvTensor = [mpsGraph additionWithPrimaryTensor:depthwiseConvTensor
secondaryTensor:biasTensor
name:@"depthwiseConv2DWithBiasAdd"];
}

return depthwiseConv2DTensor;
return depthwiseConvTensor;
} else {
MPSGraphConvolution2DOpDescriptor* desc = [MPSGraphConvolution2DOpDescriptor
descriptorWithStrideInX:stride[1]
Expand All @@ -90,7 +98,7 @@
descriptor:desc
name:@"conv2D"];

//Can be a nullptr
// Can be a nullptr
if(biasTensor){
//Need to add correct dimension to bias to avoid broadcasting issues
biasTensor = [mpsGraph expandDimsOfTensor:biasTensor
Expand All @@ -101,13 +109,13 @@
name:@"conv2DWithBiasAdd"];
}

if (isConv1D) {
conv2DTensor = [mpsGraph squeezeTensor:conv2DTensor
axis:2
name:@"squeeze"];
}
if (isConv1D) {
conv2DTensor = [mpsGraph squeezeTensor:conv2DTensor
axis:2
name:@"squeeze"];
}

return conv2DTensor;
return conv2DTensor;
}
}
} //namespace mps