Skip to content

Commit d0e7738

Browse files
[MPS] Add MPS backend plugin for GPU computing on Macbook (PaddlePaddle#485)
1 parent 094616b commit d0e7738

17 files changed

Lines changed: 1036 additions & 1 deletion

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ repos:
6565
description: Format files with ClangFormat.
6666
entry: bash ./tools/codestyle/clang_format.hook -i
6767
language: system
68-
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
68+
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|mm|m)$
6969
- repo: local
7070
hooks:
7171
- id: cpplint-cpp-source

backends/mps/.clang-format

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
Language: ObjC
3+
BasedOnStyle: Google
4+
IndentWidth: 2
5+
TabWidth: 2
6+
ContinuationIndentWidth: 4
7+
AccessModifierOffset: -1 # The private/protected/public has no indent in class
8+
Standard: Cpp11
9+
AllowAllParametersOfDeclarationOnNextLine: true
10+
BinPackParameters: false
11+
BinPackArguments: false
12+
...

backends/mps/CMakeLists.txt

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use
4+
# this file except in compliance with the License. You may obtain a copy of the
5+
# License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
15+
cmake_minimum_required(VERSION 3.10)
16+
17+
project(paddle-mps CXX C)
18+
19+
set(SIGN_IDENTITY
20+
""
21+
CACHE STRING "Code signing identity for the dylib")
22+
23+
if(SIGN_IDENTITY STREQUAL "")
24+
message(FATAL_ERROR "SIGN_IDENTITY must be set")
25+
endif()
26+
27+
set(CMAKE_CXX_STANDARD 14)
28+
set(CMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY ${SIGN_IDENTITY})
29+
set(CMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED "YES")
30+
31+
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
32+
33+
option(WITH_TESTING "compile with unit testing" OFF)
34+
option(ON_INFER "compile with inference c++ lib" OFF)
35+
36+
set(PLUGIN_NAME "paddle-mps")
37+
set(PLUGIN_VERSION "0.0.1")
38+
39+
include(paddle)
40+
41+
include_directories(${PADDLE_INC_DIR} ${CMAKE_SOURCE_DIR}
42+
${CMAKE_SOURCE_DIR}/kernels ${CMAKE_SOURCE_DIR}/runtime)
43+
link_directories(${PADDLE_LIB_DIR})
44+
45+
# Collect the plugin sources as paths relative to the source root:
#   * every Objective-C++ (.mm) and C++ (.cc) file under kernels/
#   * every Objective-C++ (.mm) file under runtime/
# The original pattern list also carried a bare ${CMAKE_SOURCE_DIR} entry,
# which is not a file pattern and has been dropped.
file(
  GLOB_RECURSE PLUGIN_SRCS
  RELATIVE ${CMAKE_SOURCE_DIR}
  kernels/*.mm kernels/*.cc
  ${CMAKE_SOURCE_DIR}/runtime/*.mm)
# runtime.cc is the only C++ runtime source and is listed explicitly.
list(APPEND PLUGIN_SRCS runtime/runtime.cc)
51+
52+
# build shared library
53+
add_library(${PLUGIN_NAME} SHARED ${PLUGIN_SRCS})
54+
if(ON_INFER)
55+
target_link_directories(${PLUGIN_NAME} PRIVATE ${PADDLE_INFERENCE_LIB_DIR})
56+
target_link_libraries(${PLUGIN_NAME} PRIVATE paddle_inference)
57+
else()
58+
target_link_libraries(${PLUGIN_NAME} PRIVATE ${PADDLE_CORE_LIB})
59+
endif()
60+
61+
# Locate the Apple frameworks the plugin links against.
#
# find_library(... REQUIRED) only exists in CMake >= 3.18, while this project
# declares cmake_minimum_required(VERSION 3.10); on older versions the word
# REQUIRED is read as one more library name to search for. Check the results
# explicitly so that a missing framework fails at configure time on every
# supported CMake version. Foundation is checked too because it is linked
# unconditionally below.
find_library(FOUNDATION_LIBRARY Foundation)
find_library(METAL_LIBRARY Metal)
find_library(MPS_LIBRARY MetalPerformanceShaders)
find_library(MPS_GRAPH_LIBRARY MetalPerformanceShadersGraph)
foreach(framework_var FOUNDATION_LIBRARY METAL_LIBRARY MPS_LIBRARY
                      MPS_GRAPH_LIBRARY)
  if(NOT ${framework_var})
    message(FATAL_ERROR "Required framework not found (${framework_var})")
  endif()
endforeach()
target_link_libraries(
  ${PLUGIN_NAME} PRIVATE ${METAL_LIBRARY} ${MPS_LIBRARY} ${FOUNDATION_LIBRARY}
                         ${MPS_GRAPH_LIBRARY})
68+
69+
# packing wheel package
70+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
71+
${CMAKE_CURRENT_BINARY_DIR}/setup.py)
72+
73+
# After every (re)link of the plugin, rebuild the python/paddle-plugins/
# staging tree from scratch and copy the freshly built dylib into it.
#
# `cmake -E remove` only deletes files, so pointing it at the python/ directory
# was a silent no-op (because of -f); `remove_directory` actually clears the
# previous staging tree.
add_custom_command(
  TARGET ${PLUGIN_NAME}
  POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E remove_directory
          ${CMAKE_CURRENT_BINARY_DIR}/python/
  COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/python/
  COMMAND ${CMAKE_COMMAND} -E make_directory
          ${CMAKE_CURRENT_BINARY_DIR}/python/paddle-plugins/
  COMMAND
    ${CMAKE_COMMAND} -E copy_if_different
    ${CMAKE_CURRENT_BINARY_DIR}/lib${PLUGIN_NAME}.dylib
    ${CMAKE_CURRENT_BINARY_DIR}/python/paddle-plugins/
  # Rewrite the install name recorded in the staged copy of the dylib.
  COMMAND
    install_name_tool -change @loader_path/../libs/ ${PADDLE_CORE_LIB}
    ${CMAKE_CURRENT_BINARY_DIR}/python/paddle-plugins/lib${PLUGIN_NAME}.dylib
  COMMENT "Creating plugin directories------>>>")
88+
89+
find_package(
90+
Python
91+
COMPONENTS Interpreter
92+
REQUIRED)
93+
94+
add_custom_command(
95+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/python/.timestamp
96+
COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/setup.py bdist_wheel
97+
DEPENDS ${PLUGIN_NAME}
98+
COMMENT "Packing whl packages------>>>")
99+
100+
add_custom_target(python_package ALL
101+
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/python/.timestamp)
102+
103+
if(WITH_TESTING)
104+
set(PYTHON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../Paddle")
105+
enable_testing()
106+
add_subdirectory(tests)
107+
add_custom_command(
108+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/tests/.timestamp
109+
COMMAND cp -r ${CMAKE_SOURCE_DIR}/tests ${CMAKE_CURRENT_BINARY_DIR})
110+
add_custom_target(python_tests ALL
111+
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/tests/.timestamp)
112+
endif()

backends/mps/README.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# PaddlePaddle Custom Device Implementation for MPS
2+
3+
4+
Please refer to the following steps to compile, install and verify the custom device implementation for the MPS backend.
5+
6+
## Prepare environment and source code
7+
8+
```bash
9+
# 1. clone the source code recursively along with Paddle source code
10+
git clone --recursive https://github.com/PaddlePaddle/PaddleCustomDevice
11+
cd PaddleCustomDevice
12+
13+
# 2. execute the following commands to update submodule
14+
git submodule sync
15+
git submodule update --remote --init --recursive
16+
```
17+
18+
## Compile and Install
19+
20+
```bash
21+
# navigate to the implementation for the MPS backend
22+
cd backends/mps
23+
24+
#before compiling, ensure that Paddle is installed, you can run the following command
25+
pip install paddlepaddle
26+
#create the build directory and navigate in
27+
mkdir build && cd build
28+
29+
#Currently, a SIGN_IDENTITY is required to sign the dynamic library(.dylib)
30+
cmake .. -DSIGN_IDENTITY="<Your Identity>"
31+
make -j8
32+
33+
#using pip to install the output
34+
pip install dist/paddle_mps*.whl
35+
```
36+
37+
## Verification
38+
39+
```bash
40+
#list available hardware backends
41+
python -c "import paddle; print(paddle.device.get_all_custom_device_type())"
42+
43+
#expected output
44+
['mps']
45+
46+
#run a simple model
47+
python -c "import paddle; paddle.set_device('mps'); print(paddle.nn.functional.softmax(paddle.ones([2])))"
48+
49+
#expected similar output
50+
... ...
51+
Tensor(shape=[2], dtype=float32, place=Place(mps:0), stop_gradient=True,
52+
[0.50000000, 0.50000000])
53+
```

backends/mps/cmake/paddle.cmake

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
# use this file except in compliance with the License. You may obtain a copy of
5+
# the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
15+
find_package(Python ${PYTHON_VERSION} REQUIRED COMPONENTS Interpreter
16+
Development)
17+
18+
if(DEFINED ENV{PADDLE_CUSTOM_PATH})
19+
set(PADDLE_DIR $ENV{PADDLE_CUSTOM_PATH})
20+
else()
21+
execute_process(
22+
COMMAND
23+
"${Python_EXECUTABLE}" "-c"
24+
"import re, paddle; print(re.compile('/__init__.py.*').sub('',paddle.__file__))"
25+
OUTPUT_VARIABLE PADDLE_DIR
26+
OUTPUT_STRIP_TRAILING_WHITESPACE)
27+
endif()
28+
29+
if(NOT EXISTS ${PADDLE_DIR})
30+
message(FATAL_ERROR "NO Installed Paddle Found in ${PADDLE_DIR}")
31+
endif()
32+
33+
set(PADDLE_INC_DIR "${PADDLE_DIR}/include/")
34+
set(PADDLE_LIB_DIR "${PADDLE_DIR}/fluid/")
35+
36+
include_directories(${PADDLE_INC_DIR})
37+
38+
# Choose the Paddle core library to link against, in order of preference:
# libpaddle.so, then core_avx.so, and finally core_noavx.so as the fallback.
if(EXISTS "${PADDLE_LIB_DIR}/libpaddle.so")
  set(paddle_lib_name libpaddle.so)
elseif(EXISTS "${PADDLE_LIB_DIR}/core_avx.so")
  set(paddle_lib_name core_avx.so)
else()
  set(paddle_lib_name core_noavx.so)
  # The mode keyword was misspelled ("WANRING"), so CMake treated it as part
  # of the message text instead of raising a warning.
  message(WARNING "Cannot find core_avx.so, using core_noavx.so instead.")
endif()
46+
47+
# Resolve the selected core library to a full path inside PADDLE_LIB_DIR and
# abort configuration if it is missing.
find_library(PADDLE_CORE_LIB ${paddle_lib_name} PATHS ${PADDLE_LIB_DIR})
if(NOT PADDLE_CORE_LIB)
  # "FATAL" is not a message() mode; only FATAL_ERROR stops the configure step.
  message(FATAL_ERROR "${paddle_lib_name} NOT found in ${PADDLE_LIB_DIR}")
else()
  message(STATUS "Found PADDLE_CORE_LIB: ${PADDLE_CORE_LIB}")
endif()

backends/mps/kernels/op_utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
18+
#include "runtime/mps_stream.h"
19+
20+
// Small helpers shared by the MPS kernels for building and running MPSGraph
// objects.
namespace mps {
21+
22+
// Creates a new MPSGraph and returns it autoreleased, so the caller does not
// own the returned object.
MPSGraph *make_mps_graph();
23+
24+
// Adds an unnamed placeholder tensor with the given shape and data type to
// `mpsGraph` and returns it.
MPSGraphTensor *mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph,
25+
MPSDataType dataType,
26+
MPSShape *mpsShape);
27+
28+
// Executes `mpsGraph` on `mpsStream` with the given `feeds` and `results`
// dictionaries, using SyncType::COMMIT_AND_WAIT.
void runMPSGraph(MPSStream *mpsStream,
29+
MPSGraph *mpsGraph,
30+
NSDictionary *feeds,
31+
NSDictionary *results);
32+
33+
} // namespace mps

backends/mps/kernels/op_utils.mm

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "op_utils.h"
16+
17+
namespace mps {

// Builds an empty MPSGraph. The object is autoreleased, so it lives until the
// enclosing autorelease pool drains unless the caller retains it.
MPSGraph *make_mps_graph() {
  return [[[MPSGraph alloc] init] autorelease];
}

// Registers an unnamed placeholder of the requested shape and element type on
// `mpsGraph` and hands the resulting tensor back to the caller.
MPSGraphTensor *mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph,
                                          MPSDataType dataType,
                                          MPSShape *mpsShape) {
  MPSGraphTensor *placeholder = [mpsGraph placeholderWithShape:mpsShape
                                                      dataType:dataType
                                                          name:nil];
  return placeholder;
}

// Forwards the graph, its inputs (`feeds`) and its outputs (`results`) to the
// stream, always requesting the COMMIT_AND_WAIT synchronisation mode.
void runMPSGraph(MPSStream *mpsStream,
                 MPSGraph *mpsGraph,
                 NSDictionary *feeds,
                 NSDictionary *results) {
  const auto syncMode = SyncType::COMMIT_AND_WAIT;
  mpsStream->executeMPSGraph(mpsGraph, feeds, results, syncMode);
}

}  // namespace mps
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <vector>
18+
19+
// C++-callable entry points for kernels implemented with MPSGraph.
namespace mps_kernel {
20+
21+
// Runs a float32 softmax along `axis` through MPSGraph.
//
// `in` and `out` are not dereferenced as host arrays: the implementation casts
// both to id<MTLBuffer> and binds them as the graph's input and output.
// `x_shape` describes the tensor; `out_shape` is accepted but not read by the
// implementation.
void Softmax(const float* in,
22+
float* out,
23+
std::vector<int64_t> x_shape,
24+
std::vector<int64_t> out_shape,
25+
int axis);
26+
27+
} // namespace mps_kernel
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include "softmax_impl.h"
2+
#include <Foundation/Foundation.h>
3+
#include "mps_stream.h"
4+
#include "op_utils.h"
5+
6+
namespace mps_kernel {

// Runs a float32 softmax along `axis` through MPSGraph on the current MPS
// stream.
//
// Parameters:
//   in        - input handle; cast to id<MTLBuffer>, not read as host memory.
//   out       - output handle; cast to id<MTLBuffer> in the same way.
//   x_shape   - full shape of the input tensor. Every dimension is forwarded
//               to MPSGraph (the previous code forwarded only x_shape[0],
//               truncated to int, so N-D inputs were described as 1-D).
//   out_shape - unused; the output is bound with the input's shape.
//   axis      - passed unchanged to softMaxWithTensor:axis:name:.
void Softmax(const float *in,
             float *out,
             std::vector<int64_t> x_shape,
             std::vector<int64_t> out_shape,
             int axis) {
  mps::MPSStream *stream = mps::getCurrentMPSStream();
  @autoreleasepool {
    MPSGraph *mpsGraph = mps::make_mps_graph();

    // Translate the whole shape into an MPSShape, keeping 64-bit dimensions.
    NSMutableArray<NSNumber *> *tensor_shape =
        [NSMutableArray arrayWithCapacity:x_shape.size()];
    for (int64_t dim : x_shape) {
      [tensor_shape addObject:@(dim)];
    }
    if (tensor_shape.count == 0) {
      // An empty shape used to hit x_shape[0] on an empty vector (undefined
      // behaviour); describe it as a single-element vector instead.
      [tensor_shape addObject:@1];
    }

    MPSGraphTensor *inputTensor =
        mps::mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, tensor_shape);
    MPSGraphTensor *outputTensor = [mpsGraph softMaxWithTensor:inputTensor
                                                          axis:axis
                                                          name:nil];

    // The raw pointers handed in by the caller are treated as Metal buffers.
    id<MTLBuffer> in_buffer = (id<MTLBuffer>)in;
    id<MTLBuffer> out_buffer = (id<MTLBuffer>)out;

    MPSGraphTensorData *inputData =
        [[[MPSGraphTensorData alloc] initWithMTLBuffer:in_buffer
                                                 shape:tensor_shape
                                              dataType:MPSDataTypeFloat32] autorelease];
    MPSGraphTensorData *outputData =
        [[[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
                                                 shape:tensor_shape
                                              dataType:MPSDataTypeFloat32] autorelease];

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{inputTensor : inputData};
    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{outputTensor : outputData};

    mps::runMPSGraph(stream, mpsGraph, feeds, results);
  }
}

}  // namespace mps_kernel

0 commit comments

Comments
 (0)