From a09c377da3fcdca18dc2de8812c5582770aeea5e Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 18 Feb 2024 09:50:04 -0500 Subject: [PATCH] [RUNTIME][METAL] Fix multithreading access of metal runtime This PR fixes a bug where the Metal runtime cannot be accessed from multiple threads. This is because the ThreadLocal entry initialization happens during global workspace initialization, meaning that other threads that try to use the Metal runtime later do not get their thread-local entry correctly initialized. This PR fixes the problem by always using nullptr as the fallback and looking up the default stream in the global workspace. --- src/runtime/metal/metal_common.h | 25 +++++++++++------ src/runtime/metal/metal_device_api.mm | 40 ++++++++------------------- src/runtime/metal/metal_module.mm | 9 +++--- 3 files changed, 32 insertions(+), 42 deletions(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index dad156bcdddc..d9154e0f7906 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -136,10 +136,7 @@ class MetalWorkspace final : public DeviceAPI { std::vector> devices; // Warp size constant std::vector warp_size; - // Whether it is initialized. - bool initialized_{false}; - // the mutex for initialization - std::mutex mutex; + MetalWorkspace(); // Destructor ~MetalWorkspace(); // Get device for given device @@ -149,9 +146,6 @@ class MetalWorkspace final : public DeviceAPI { << "Invalid Metal device_id=" << dev.device_id; return devices[dev.device_id]; } - // Initialize workspace - // Return false if already initialized, otherwise return true. 
- void Init(); // override device API void SetDevice(Device dev) final; void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; @@ -163,7 +157,16 @@ class MetalWorkspace final : public DeviceAPI { void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; - void ReinitializeStreams(); + void ReinitializeDefaultStreams(); + + /** + * Cast stream to the right metal stream data structure + * if stream is nullptr , return the default stream of device_id + * \param stream the input stream handle + * \param device_id The device id of interest + * \returns The stream used in this function. + */ + Stream* CastStreamOrGetDefault(TVMStreamHandle stream, int device_id); // get the global workspace static MetalWorkspace* Global(); @@ -184,7 +187,7 @@ class MetalThreadEntry { /*! \brief The current device */ Device device; /*! \brief The current stream */ - std::vector stream; + std::vector stream; /*! \brief The shared buffer used for copy. */ std::vector> temp_buffer_; /*! \brief workspace pool */ @@ -193,6 +196,10 @@ class MetalThreadEntry { MetalThreadEntry() : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { device.device_id = 0; device.device_type = static_cast(kDLMetal); + MetalWorkspace* global_ws = MetalWorkspace::Global(); + // by default, set the stream to nullptr, which indicate + // that we are using default stream + this->stream.resize(global_ws->devices.size(), nullptr); } ~MetalThreadEntry(); // Get temp buffer with at least size under dev. 
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index c4ffc8943c01..e3853ef6d62a 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -42,7 +42,6 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { AUTORELEASEPOOL { - this->Init(); size_t index = static_cast(dev.device_id); if (kind == kExist) { *rv = int(index < devices.size()); @@ -142,29 +141,18 @@ int GetWarpSize(id dev) { } } -void MetalWorkspace::ReinitializeStreams() { - std::vector& threadStreams = MetalThreadEntry::ThreadLocal()->stream; - ICHECK_EQ(default_streams_.size(), threadStreams.size()); +void MetalWorkspace::ReinitializeDefaultStreams() { for (size_t i = 0; i < default_streams_.size(); ++i) { - if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i]) - delete threadStreams[i]; delete default_streams_[i]; } default_streams_.resize(devices.size()); - threadStreams.resize(devices.size()); for (size_t i = 0; i < devices.size(); ++i) { Stream* stream = new Stream(devices[i]); default_streams_[i] = stream; - threadStreams[i] = stream; } } -void MetalWorkspace::Init() { - if (initialized_) return; - std::lock_guard lock(this->mutex); - if (initialized_) return; - initialized_ = true; - if (devices.size() != 0) return; +MetalWorkspace::MetalWorkspace() { #if TARGET_OS_IPHONE // on iPhone id d = MTLCreateSystemDefaultDevice(); @@ -178,7 +166,7 @@ int GetWarpSize(id dev) { warp_size.push_back(GetWarpSize(d)); } #endif - ReinitializeStreams(); + this->ReinitializeDefaultStreams(); } void MetalWorkspace::SetDevice(Device dev) { @@ -189,7 +177,6 @@ int GetWarpSize(id dev) { DLDataType type_hint) { id buf; AUTORELEASEPOOL { - this->Init(); id dev = GetDevice(device); // GPU memory only MTLResourceOptions storage_mode = MTLResourceStorageModePrivate; @@ -220,20 +207,20 @@ int GetWarpSize(id dev) { }; } -Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) { 
+Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int device_id) { if (stream != nullptr) return static_cast(stream); - ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr); - return MetalThreadEntry::ThreadLocal()->stream[device_id]; + ICHECK_LT(static_cast(device_id), default_streams_.size()); + ICHECK(default_streams_[device_id] != nullptr); + return default_streams_[device_id]; } void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { AUTORELEASEPOOL { - this->Init(); Device dev = dev_from; if (dev_from.device_type == kDLCPU) dev = dev_to; - Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); + Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id); if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; } @@ -303,15 +290,12 @@ int GetWarpSize(id dev) { void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { ICHECK(stream != nullptr); ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; - Stream* s = static_cast(stream); - if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s) - MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr; - delete s; + delete static_cast(stream); } void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { - Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); + Stream* s = CastStreamOrGetDefault(stream, dev.device_id); // commit an empty command buffer and wait until it completes. 
id cb = s->GetCommandBuffer(); [cb commit]; @@ -325,7 +309,7 @@ int GetWarpSize(id dev) { void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; ICHECK(stream != nullptr); - MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream; } void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { @@ -374,7 +358,7 @@ int GetWarpSize(id dev) { }); TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { - MetalWorkspace::Global()->ReinitializeStreams(); + MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); } // namespace metal diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 98e32cdf9caa..01d107942664 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -191,7 +191,9 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons AUTORELEASEPOOL { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; - auto stream = static_cast(t->stream[device_id]); + // obtain the stream + auto stream = + metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id); if (stream->HasErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); @@ -265,10 +267,7 @@ Module MetalModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string fmt, std::string source) { ObjectPtr n; - AUTORELEASEPOOL { - metal::MetalWorkspace::Global()->Init(); - n = make_object(smap, fmap, fmt, source); - }; + AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; return Module(n); }